risingwave_stream/executor/exchange/
input.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use anyhow::anyhow;
19use either::Either;
20use local_input::LocalInputStreamInner;
21use pin_project::pin_project;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use tokio::sync::mpsc;
24
25use super::permit::Receiver;
26use crate::executor::prelude::*;
27use crate::executor::{
28    BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
29    DispatcherMessageStream, DispatcherMessageStreamItem,
30};
31use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
32
33/// `Input` provides an interface for [`MergeExecutor`](crate::executor::MergeExecutor) and
34/// [`ReceiverExecutor`](crate::executor::ReceiverExecutor) to receive data from upstream actors.
35pub trait Input: DispatcherMessageStream {
36    /// The upstream actor id.
37    fn actor_id(&self) -> ActorId;
38
39    fn boxed_input(self) -> BoxedInput
40    where
41        Self: Sized + 'static,
42    {
43        Box::pin(self)
44    }
45}
46
47pub type BoxedInput = Pin<Box<dyn Input>>;
48
49impl std::fmt::Debug for dyn Input {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("Input")
52            .field("actor_id", &self.actor_id())
53            .finish_non_exhaustive()
54    }
55}
56
57/// `LocalInput` receives data from a local channel.
58#[pin_project]
59pub struct LocalInput {
60    #[pin]
61    inner: LocalInputStreamInner,
62
63    actor_id: ActorId,
64}
65
66pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
67    first: &BarrierInner<M1>,
68    second: &BarrierInner<M2>,
69) {
70    assert_eq!(first.epoch, second.epoch);
71    assert_eq!(first.kind, second.kind);
72}
73
74pub(crate) fn apply_dispatcher_barrier(
75    recv_barrier: &mut Barrier,
76    dispatcher_barrier: DispatcherBarrier,
77) {
78    assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
79    recv_barrier
80        .passed_actors
81        .extend(dispatcher_barrier.passed_actors);
82}
83
84pub(crate) async fn process_dispatcher_msg(
85    dispatcher_msg: DispatcherMessage,
86    barrier_rx: &mut mpsc::UnboundedReceiver<Barrier>,
87) -> StreamExecutorResult<Message> {
88    let msg = match dispatcher_msg {
89        DispatcherMessage::Chunk(chunk) => Message::Chunk(chunk),
90        DispatcherMessage::Barrier(barrier) => {
91            let mut recv_barrier = barrier_rx
92                .recv()
93                .await
94                .ok_or_else(|| anyhow!("end of barrier recv"))?;
95            apply_dispatcher_barrier(&mut recv_barrier, barrier);
96            Message::Barrier(recv_barrier)
97        }
98        DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
99    };
100    Ok(msg)
101}
102
103impl LocalInput {
104    pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
105        Self {
106            inner: local_input::run(channel, upstream_actor_id),
107            actor_id: upstream_actor_id,
108        }
109    }
110}
111
112mod local_input {
113    use await_tree::InstrumentAwait;
114    use either::Either;
115
116    use crate::executor::exchange::error::ExchangeChannelClosed;
117    use crate::executor::exchange::permit::Receiver;
118    use crate::executor::prelude::try_stream;
119    use crate::executor::{DispatcherMessage, StreamExecutorError};
120    use crate::task::ActorId;
121
122    pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
123
124    pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
125        run_inner(channel, upstream_actor_id)
126    }
127
128    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
129    async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
130        let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
131        while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
132            match msg.into_messages() {
133                Either::Left(barriers) => {
134                    for b in barriers {
135                        yield b;
136                    }
137                }
138                Either::Right(m) => {
139                    yield m;
140                }
141            }
142        }
143        // Always emit an error outside the loop. This is because we use barrier as the control
144        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
145        Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
146    }
147}
148
149impl Stream for LocalInput {
150    type Item = DispatcherMessageStreamItem;
151
152    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
153        // TODO: shall we pass the error with local exchange?
154        self.project().inner.poll_next(cx)
155    }
156}
157
158impl Input for LocalInput {
159    fn actor_id(&self) -> ActorId {
160        self.actor_id
161    }
162}
163
164/// `RemoteInput` connects to the upstream exchange server and receives data with `gRPC`.
165#[pin_project]
166pub struct RemoteInput {
167    #[pin]
168    inner: RemoteInputStreamInner,
169
170    actor_id: ActorId,
171}
172
173use remote_input::RemoteInputStreamInner;
174use risingwave_pb::common::ActorInfo;
175
176impl RemoteInput {
177    /// Create a remote input from compute client and related info. Should provide the corresponding
178    /// compute client of where the actor is placed.
179    pub async fn new(
180        local_barrier_manager: &LocalBarrierManager,
181        upstream_addr: HostAddr,
182        up_down_ids: UpDownActorIds,
183        up_down_frag: UpDownFragmentIds,
184        metrics: Arc<StreamingMetrics>,
185    ) -> StreamExecutorResult<Self> {
186        let actor_id = up_down_ids.0;
187
188        let client = local_barrier_manager
189            .env
190            .client_pool()
191            .get_by_addr(upstream_addr)
192            .await?;
193        let (stream, permits_tx) = client
194            .get_stream(
195                up_down_ids.0,
196                up_down_ids.1,
197                up_down_frag.0,
198                up_down_frag.1,
199                local_barrier_manager.database_id,
200                local_barrier_manager.term_id.clone(),
201            )
202            .await?;
203
204        Ok(Self {
205            actor_id,
206            inner: remote_input::run(
207                stream,
208                permits_tx,
209                up_down_ids,
210                up_down_frag,
211                metrics,
212                local_barrier_manager
213                    .env
214                    .config()
215                    .developer
216                    .exchange_batched_permits,
217            ),
218        })
219    }
220}
221
222mod remote_input {
223    use std::sync::Arc;
224
225    use anyhow::Context;
226    use await_tree::InstrumentAwait;
227    use either::Either;
228    use risingwave_pb::task_service::{GetStreamResponse, permits};
229    use tokio::sync::mpsc;
230    use tonic::Streaming;
231
232    use crate::executor::exchange::error::ExchangeChannelClosed;
233    use crate::executor::monitor::StreamingMetrics;
234    use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
235    use crate::executor::{DispatcherMessage, StreamExecutorError};
236    use crate::task::{UpDownActorIds, UpDownFragmentIds};
237
238    pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
239
240    pub(super) fn run(
241        stream: Streaming<GetStreamResponse>,
242        permits_tx: mpsc::UnboundedSender<permits::Value>,
243        up_down_ids: UpDownActorIds,
244        up_down_frag: UpDownFragmentIds,
245        metrics: Arc<StreamingMetrics>,
246        batched_permits_limit: usize,
247    ) -> RemoteInputStreamInner {
248        run_inner(
249            stream,
250            permits_tx,
251            up_down_ids,
252            up_down_frag,
253            metrics,
254            batched_permits_limit,
255        )
256    }
257
258    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
259    async fn run_inner(
260        stream: Streaming<GetStreamResponse>,
261        permits_tx: mpsc::UnboundedSender<permits::Value>,
262        up_down_ids: UpDownActorIds,
263        up_down_frag: UpDownFragmentIds,
264        metrics: Arc<StreamingMetrics>,
265        batched_permits_limit: usize,
266    ) {
267        let up_actor_id = up_down_ids.0.to_string();
268        let up_fragment_id = up_down_frag.0.to_string();
269        let down_fragment_id = up_down_frag.1.to_string();
270        let exchange_frag_recv_size_metrics = metrics
271            .exchange_frag_recv_size
272            .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
273
274        let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
275
276        let mut batched_permits_accumulated = 0;
277
278        pin_mut!(stream);
279        while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
280            match data_res {
281                Ok(GetStreamResponse { message, permits }) => {
282                    use crate::executor::DispatcherMessageBatch;
283                    let msg = message.unwrap();
284                    let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
285
286                    exchange_frag_recv_size_metrics.inc_by(bytes as u64);
287
288                    let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
289                    if let Some(add_back_permits) = match permits.unwrap().value {
290                        // For records, batch the permits we received to reduce the backward
291                        // `AddPermits` messages.
292                        Some(permits::Value::Record(p)) => {
293                            batched_permits_accumulated += p;
294                            if batched_permits_accumulated >= batched_permits_limit as u32 {
295                                let permits = std::mem::take(&mut batched_permits_accumulated);
296                                Some(permits::Value::Record(permits))
297                            } else {
298                                None
299                            }
300                        }
301                        // For barriers, always send it back immediately.
302                        Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
303                        None => None,
304                    } {
305                        permits_tx
306                            .send(add_back_permits)
307                            .context("RemoteInput backward permits channel closed.")?;
308                    }
309
310                    let msg = msg_res.context("RemoteInput decode message error")?;
311                    match msg.into_messages() {
312                        Either::Left(barriers) => {
313                            for b in barriers {
314                                yield b;
315                            }
316                        }
317                        Either::Right(m) => {
318                            yield m;
319                        }
320                    }
321                }
322
323                Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
324            }
325        }
326
327        // Always emit an error outside the loop. This is because we use barrier as the control
328        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
329        Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
330    }
331}
332
333impl Stream for RemoteInput {
334    type Item = DispatcherMessageStreamItem;
335
336    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
337        self.project().inner.poll_next(cx)
338    }
339}
340
341impl Input for RemoteInput {
342    fn actor_id(&self) -> ActorId {
343        self.actor_id
344    }
345}
346
347/// Create a [`LocalInput`] or [`RemoteInput`] instance with given info. Used by merge executors and
348/// receiver executors.
349pub(crate) async fn new_input(
350    local_barrier_manager: &LocalBarrierManager,
351    metrics: Arc<StreamingMetrics>,
352    actor_id: ActorId,
353    fragment_id: FragmentId,
354    upstream_actor_info: &ActorInfo,
355    upstream_fragment_id: FragmentId,
356) -> StreamExecutorResult<BoxedInput> {
357    let upstream_actor_id = upstream_actor_info.actor_id;
358    let upstream_addr = upstream_actor_info.get_host()?.into();
359
360    let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) {
361        LocalInput::new(
362            local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id),
363            upstream_actor_id,
364        )
365        .boxed_input()
366    } else {
367        RemoteInput::new(
368            local_barrier_manager,
369            upstream_addr,
370            (upstream_actor_id, actor_id),
371            (upstream_fragment_id, fragment_id),
372            metrics,
373        )
374        .await?
375        .boxed_input()
376    };
377
378    Ok(input)
379}
380
381impl DispatcherMessageBatch {
382    fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
383        match self {
384            DispatcherMessageBatch::BarrierBatch(barriers) => {
385                Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
386            }
387            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
388            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
389        }
390    }
391}