risingwave_stream/executor/exchange/
input.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use anyhow::anyhow;
19use either::Either;
20use local_input::LocalInputStreamInner;
21use pin_project::pin_project;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use risingwave_rpc_client::ComputeClientPool;
24use tokio::sync::mpsc;
25
26use super::permit::Receiver;
27use crate::executor::prelude::*;
28use crate::executor::{
29    BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
30    DispatcherMessageStream, DispatcherMessageStreamItem,
31};
32use crate::task::{FragmentId, SharedContext, UpDownActorIds, UpDownFragmentIds};
33
34/// `Input` provides an interface for [`MergeExecutor`](crate::executor::MergeExecutor) and
35/// [`ReceiverExecutor`](crate::executor::ReceiverExecutor) to receive data from upstream actors.
36pub trait Input: DispatcherMessageStream {
37    /// The upstream actor id.
38    fn actor_id(&self) -> ActorId;
39
40    fn boxed_input(self) -> BoxedInput
41    where
42        Self: Sized + 'static,
43    {
44        Box::pin(self)
45    }
46}
47
48pub type BoxedInput = Pin<Box<dyn Input>>;
49
50impl std::fmt::Debug for dyn Input {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("Input")
53            .field("actor_id", &self.actor_id())
54            .finish_non_exhaustive()
55    }
56}
57
58/// `LocalInput` receives data from a local channel.
59#[pin_project]
60pub struct LocalInput {
61    #[pin]
62    inner: LocalInputStreamInner,
63
64    actor_id: ActorId,
65}
66
67pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
68    first: &BarrierInner<M1>,
69    second: &BarrierInner<M2>,
70) {
71    assert_eq!(first.epoch, second.epoch);
72    assert_eq!(first.kind, second.kind);
73}
74
75pub(crate) fn apply_dispatcher_barrier(
76    recv_barrier: &mut Barrier,
77    dispatcher_barrier: DispatcherBarrier,
78) {
79    assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
80    recv_barrier
81        .passed_actors
82        .extend(dispatcher_barrier.passed_actors);
83}
84
85pub(crate) async fn process_dispatcher_msg(
86    dispatcher_msg: DispatcherMessage,
87    barrier_rx: &mut mpsc::UnboundedReceiver<Barrier>,
88) -> StreamExecutorResult<Message> {
89    let msg = match dispatcher_msg {
90        DispatcherMessage::Chunk(chunk) => Message::Chunk(chunk),
91        DispatcherMessage::Barrier(barrier) => {
92            let mut recv_barrier = barrier_rx
93                .recv()
94                .await
95                .ok_or_else(|| anyhow!("end of barrier recv"))?;
96            apply_dispatcher_barrier(&mut recv_barrier, barrier);
97            Message::Barrier(recv_barrier)
98        }
99        DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
100    };
101    Ok(msg)
102}
103
104impl LocalInput {
105    pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
106        Self {
107            inner: local_input::run(channel, upstream_actor_id),
108            actor_id: upstream_actor_id,
109        }
110    }
111}
112
113mod local_input {
114    use await_tree::InstrumentAwait;
115    use either::Either;
116
117    use crate::executor::exchange::error::ExchangeChannelClosed;
118    use crate::executor::exchange::permit::Receiver;
119    use crate::executor::prelude::try_stream;
120    use crate::executor::{DispatcherMessage, StreamExecutorError};
121    use crate::task::ActorId;
122
123    pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
124
125    pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
126        run_inner(channel, upstream_actor_id)
127    }
128
129    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
130    async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
131        let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
132        while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
133            match msg.into_messages() {
134                Either::Left(barriers) => {
135                    for b in barriers {
136                        yield b;
137                    }
138                }
139                Either::Right(m) => {
140                    yield m;
141                }
142            }
143        }
144        // Always emit an error outside the loop. This is because we use barrier as the control
145        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
146        Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
147    }
148}
149
150impl Stream for LocalInput {
151    type Item = DispatcherMessageStreamItem;
152
153    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154        // TODO: shall we pass the error with local exchange?
155        self.project().inner.poll_next(cx)
156    }
157}
158
159impl Input for LocalInput {
160    fn actor_id(&self) -> ActorId {
161        self.actor_id
162    }
163}
164
165/// `RemoteInput` connects to the upstream exchange server and receives data with `gRPC`.
166#[pin_project]
167pub struct RemoteInput {
168    #[pin]
169    inner: RemoteInputStreamInner,
170
171    actor_id: ActorId,
172}
173
174use remote_input::RemoteInputStreamInner;
175use risingwave_common::catalog::DatabaseId;
176
177impl RemoteInput {
178    /// Create a remote input from compute client and related info. Should provide the corresponding
179    /// compute client of where the actor is placed.
180    #[expect(clippy::too_many_arguments)]
181    pub fn new(
182        client_pool: ComputeClientPool,
183        upstream_addr: HostAddr,
184        up_down_ids: UpDownActorIds,
185        up_down_frag: UpDownFragmentIds,
186        database_id: DatabaseId,
187        metrics: Arc<StreamingMetrics>,
188        batched_permits: usize,
189        term_id: String,
190    ) -> Self {
191        let actor_id = up_down_ids.0;
192
193        Self {
194            actor_id,
195            inner: remote_input::run(
196                client_pool,
197                upstream_addr,
198                up_down_ids,
199                up_down_frag,
200                database_id,
201                metrics,
202                batched_permits,
203                term_id,
204            ),
205        }
206    }
207}
208
209mod remote_input {
210    use std::sync::Arc;
211
212    use anyhow::Context;
213    use await_tree::InstrumentAwait;
214    use either::Either;
215    use risingwave_common::catalog::DatabaseId;
216    use risingwave_common::util::addr::HostAddr;
217    use risingwave_pb::task_service::{GetStreamResponse, permits};
218    use risingwave_rpc_client::ComputeClientPool;
219
220    use crate::executor::exchange::error::ExchangeChannelClosed;
221    use crate::executor::monitor::StreamingMetrics;
222    use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
223    use crate::executor::{DispatcherMessage, StreamExecutorError};
224    use crate::task::{UpDownActorIds, UpDownFragmentIds};
225
226    pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
227
228    #[expect(clippy::too_many_arguments)]
229    pub(super) fn run(
230        client_pool: ComputeClientPool,
231        upstream_addr: HostAddr,
232        up_down_ids: UpDownActorIds,
233        up_down_frag: UpDownFragmentIds,
234        database_id: DatabaseId,
235        metrics: Arc<StreamingMetrics>,
236        batched_permits_limit: usize,
237        term_id: String,
238    ) -> RemoteInputStreamInner {
239        run_inner(
240            client_pool,
241            upstream_addr,
242            up_down_ids,
243            up_down_frag,
244            database_id,
245            metrics,
246            batched_permits_limit,
247            term_id,
248        )
249    }
250
251    #[expect(clippy::too_many_arguments)]
252    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
253    async fn run_inner(
254        client_pool: ComputeClientPool,
255        upstream_addr: HostAddr,
256        up_down_ids: UpDownActorIds,
257        up_down_frag: UpDownFragmentIds,
258        database_id: DatabaseId,
259        metrics: Arc<StreamingMetrics>,
260        batched_permits_limit: usize,
261        term_id: String,
262    ) {
263        let client = client_pool.get_by_addr(upstream_addr).await?;
264        let (stream, permits_tx) = client
265            .get_stream(
266                up_down_ids.0,
267                up_down_ids.1,
268                up_down_frag.0,
269                up_down_frag.1,
270                database_id,
271                term_id,
272            )
273            .await?;
274
275        let up_actor_id = up_down_ids.0.to_string();
276        let up_fragment_id = up_down_frag.0.to_string();
277        let down_fragment_id = up_down_frag.1.to_string();
278        let exchange_frag_recv_size_metrics = metrics
279            .exchange_frag_recv_size
280            .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
281
282        let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
283
284        let mut batched_permits_accumulated = 0;
285
286        pin_mut!(stream);
287        while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
288            match data_res {
289                Ok(GetStreamResponse { message, permits }) => {
290                    use crate::executor::DispatcherMessageBatch;
291                    let msg = message.unwrap();
292                    let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
293
294                    exchange_frag_recv_size_metrics.inc_by(bytes as u64);
295
296                    let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
297                    if let Some(add_back_permits) = match permits.unwrap().value {
298                        // For records, batch the permits we received to reduce the backward
299                        // `AddPermits` messages.
300                        Some(permits::Value::Record(p)) => {
301                            batched_permits_accumulated += p;
302                            if batched_permits_accumulated >= batched_permits_limit as u32 {
303                                let permits = std::mem::take(&mut batched_permits_accumulated);
304                                Some(permits::Value::Record(permits))
305                            } else {
306                                None
307                            }
308                        }
309                        // For barriers, always send it back immediately.
310                        Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
311                        None => None,
312                    } {
313                        permits_tx
314                            .send(add_back_permits)
315                            .context("RemoteInput backward permits channel closed.")?;
316                    }
317
318                    let msg = msg_res.context("RemoteInput decode message error")?;
319                    match msg.into_messages() {
320                        Either::Left(barriers) => {
321                            for b in barriers {
322                                yield b;
323                            }
324                        }
325                        Either::Right(m) => {
326                            yield m;
327                        }
328                    }
329                }
330
331                Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
332            }
333        }
334
335        // Always emit an error outside the loop. This is because we use barrier as the control
336        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
337        Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
338    }
339}
340
341impl Stream for RemoteInput {
342    type Item = DispatcherMessageStreamItem;
343
344    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
345        self.project().inner.poll_next(cx)
346    }
347}
348
349impl Input for RemoteInput {
350    fn actor_id(&self) -> ActorId {
351        self.actor_id
352    }
353}
354
355/// Create a [`LocalInput`] or [`RemoteInput`] instance with given info. Used by merge executors and
356/// receiver executors.
357pub(crate) fn new_input(
358    context: &SharedContext,
359    metrics: Arc<StreamingMetrics>,
360    actor_id: ActorId,
361    fragment_id: FragmentId,
362    upstream_actor_id: ActorId,
363    upstream_fragment_id: FragmentId,
364) -> StreamResult<BoxedInput> {
365    let upstream_addr = context
366        .get_actor_info(&upstream_actor_id)?
367        .get_host()?
368        .into();
369
370    let input = if is_local_address(&context.addr, &upstream_addr) {
371        LocalInput::new(
372            context.take_receiver((upstream_actor_id, actor_id))?,
373            upstream_actor_id,
374        )
375        .boxed_input()
376    } else {
377        RemoteInput::new(
378            context.compute_client_pool.as_ref().to_owned(),
379            upstream_addr,
380            (upstream_actor_id, actor_id),
381            (upstream_fragment_id, fragment_id),
382            context.database_id,
383            metrics,
384            context.config.developer.exchange_batched_permits,
385            context.term_id(),
386        )
387        .boxed_input()
388    };
389
390    Ok(input)
391}
392
393impl DispatcherMessageBatch {
394    fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
395        match self {
396            DispatcherMessageBatch::BarrierBatch(barriers) => {
397                Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
398            }
399            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
400            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
401        }
402    }
403}