risingwave_stream/executor/exchange/
input.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use anyhow::anyhow;
19use either::Either;
20use local_input::LocalInputStreamInner;
21use pin_project::pin_project;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use tokio::sync::mpsc;
24
25use super::permit::Receiver;
26use crate::executor::prelude::*;
27use crate::executor::{
28    BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
29    DispatcherMessageStreamItem,
30};
31use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
32
33/// `Input` is a more abstract upstream input type, used for `DynamicReceivers` type
34/// handling of multiple upstream inputs
35pub trait Input: Stream + Send {
36    type InputId;
37    /// The upstream input id.
38    fn id(&self) -> Self::InputId;
39
40    fn boxed_input(self) -> BoxedInput<Self::InputId, Self::Item>
41    where
42        Self: Sized + 'static,
43    {
44        Box::pin(self)
45    }
46}
47
48pub type BoxedInput<InputId, Item> = Pin<Box<dyn Input<InputId = InputId, Item = Item>>>;
49
50/// `ActorInput` provides an interface for [`MergeExecutor`](crate::executor::MergeExecutor) and
51/// [`ReceiverExecutor`](crate::executor::ReceiverExecutor) to receive data from upstream actors.
52/// Only used for actor inputs.
53pub trait ActorInput = Input<Item = DispatcherMessageStreamItem, InputId = ActorId>;
54
55pub type BoxedActorInput = Pin<Box<dyn ActorInput>>;
56
57impl std::fmt::Debug for dyn ActorInput {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("Input")
60            .field("actor_id", &self.id())
61            .finish_non_exhaustive()
62    }
63}
64
65/// `LocalInput` receives data from a local channel.
66#[pin_project]
67pub struct LocalInput {
68    #[pin]
69    inner: LocalInputStreamInner,
70
71    actor_id: ActorId,
72}
73
74pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
75    first: &BarrierInner<M1>,
76    second: &BarrierInner<M2>,
77) {
78    assert_eq!(first.epoch, second.epoch);
79    assert_eq!(first.kind, second.kind);
80}
81
82pub(crate) fn apply_dispatcher_barrier(
83    recv_barrier: &mut Barrier,
84    dispatcher_barrier: DispatcherBarrier,
85) {
86    assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
87    recv_barrier
88        .passed_actors
89        .extend(dispatcher_barrier.passed_actors);
90}
91
92pub(crate) async fn process_dispatcher_msg(
93    dispatcher_msg: DispatcherMessage,
94    barrier_rx: &mut mpsc::UnboundedReceiver<Barrier>,
95) -> StreamExecutorResult<Message> {
96    let msg = match dispatcher_msg {
97        DispatcherMessage::Chunk(chunk) => Message::Chunk(chunk),
98        DispatcherMessage::Barrier(barrier) => {
99            let mut recv_barrier = barrier_rx
100                .recv()
101                .await
102                .ok_or_else(|| anyhow!("end of barrier recv"))?;
103            apply_dispatcher_barrier(&mut recv_barrier, barrier);
104            Message::Barrier(recv_barrier)
105        }
106        DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
107    };
108    Ok(msg)
109}
110
111impl LocalInput {
112    pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
113        Self {
114            inner: local_input::run(channel, upstream_actor_id),
115            actor_id: upstream_actor_id,
116        }
117    }
118}
119
120mod local_input {
121    use await_tree::InstrumentAwait;
122    use either::Either;
123
124    use crate::executor::exchange::error::ExchangeChannelClosed;
125    use crate::executor::exchange::permit::Receiver;
126    use crate::executor::prelude::try_stream;
127    use crate::executor::{DispatcherMessage, StreamExecutorError};
128    use crate::task::ActorId;
129
130    pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
131
132    #[define_opaque(LocalInputStreamInner)]
133    pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
134        run_inner(channel, upstream_actor_id)
135    }
136
137    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
138    async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
139        let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
140        while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
141            match msg.into_messages() {
142                Either::Left(barriers) => {
143                    for b in barriers {
144                        yield b;
145                    }
146                }
147                Either::Right(m) => {
148                    yield m;
149                }
150            }
151        }
152        // Always emit an error outside the loop. This is because we use barrier as the control
153        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
154        Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
155    }
156}
157
158impl Stream for LocalInput {
159    type Item = DispatcherMessageStreamItem;
160
161    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162        // TODO: shall we pass the error with local exchange?
163        self.project().inner.poll_next(cx)
164    }
165}
166
167impl Input for LocalInput {
168    type InputId = ActorId;
169
170    fn id(&self) -> Self::InputId {
171        self.actor_id
172    }
173}
174
175/// `RemoteInput` connects to the upstream exchange server and receives data with `gRPC`.
176#[pin_project]
177pub struct RemoteInput {
178    #[pin]
179    inner: RemoteInputStreamInner,
180
181    actor_id: ActorId,
182}
183
184use remote_input::RemoteInputStreamInner;
185use risingwave_pb::common::ActorInfo;
186
187impl RemoteInput {
188    /// Create a remote input from compute client and related info. Should provide the corresponding
189    /// compute client of where the actor is placed.
190    pub async fn new(
191        local_barrier_manager: &LocalBarrierManager,
192        upstream_addr: HostAddr,
193        up_down_ids: UpDownActorIds,
194        up_down_frag: UpDownFragmentIds,
195        metrics: Arc<StreamingMetrics>,
196    ) -> StreamExecutorResult<Self> {
197        let actor_id = up_down_ids.0;
198
199        let client = local_barrier_manager
200            .env
201            .client_pool()
202            .get_by_addr(upstream_addr)
203            .await?;
204        let (stream, permits_tx) = client
205            .get_stream(
206                up_down_ids.0,
207                up_down_ids.1,
208                up_down_frag.0,
209                up_down_frag.1,
210                local_barrier_manager.database_id,
211                local_barrier_manager.term_id.clone(),
212            )
213            .await?;
214
215        Ok(Self {
216            actor_id,
217            inner: remote_input::run(
218                stream,
219                permits_tx,
220                up_down_ids,
221                up_down_frag,
222                metrics,
223                local_barrier_manager
224                    .env
225                    .config()
226                    .developer
227                    .exchange_batched_permits,
228            ),
229        })
230    }
231}
232
233mod remote_input {
234    use std::sync::Arc;
235
236    use anyhow::Context;
237    use await_tree::InstrumentAwait;
238    use either::Either;
239    use risingwave_pb::task_service::{GetStreamResponse, permits};
240    use tokio::sync::mpsc;
241    use tonic::Streaming;
242
243    use crate::executor::exchange::error::ExchangeChannelClosed;
244    use crate::executor::monitor::StreamingMetrics;
245    use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
246    use crate::executor::{DispatcherMessage, StreamExecutorError};
247    use crate::task::{UpDownActorIds, UpDownFragmentIds};
248
249    pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
250
251    #[define_opaque(RemoteInputStreamInner)]
252    pub(super) fn run(
253        stream: Streaming<GetStreamResponse>,
254        permits_tx: mpsc::UnboundedSender<permits::Value>,
255        up_down_ids: UpDownActorIds,
256        up_down_frag: UpDownFragmentIds,
257        metrics: Arc<StreamingMetrics>,
258        batched_permits_limit: usize,
259    ) -> RemoteInputStreamInner {
260        run_inner(
261            stream,
262            permits_tx,
263            up_down_ids,
264            up_down_frag,
265            metrics,
266            batched_permits_limit,
267        )
268    }
269
270    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
271    async fn run_inner(
272        stream: Streaming<GetStreamResponse>,
273        permits_tx: mpsc::UnboundedSender<permits::Value>,
274        up_down_ids: UpDownActorIds,
275        up_down_frag: UpDownFragmentIds,
276        metrics: Arc<StreamingMetrics>,
277        batched_permits_limit: usize,
278    ) {
279        let up_actor_id = up_down_ids.0.to_string();
280        let up_fragment_id = up_down_frag.0.to_string();
281        let down_fragment_id = up_down_frag.1.to_string();
282        let exchange_frag_recv_size_metrics = metrics
283            .exchange_frag_recv_size
284            .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
285
286        let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
287
288        let mut batched_permits_accumulated = 0;
289
290        pin_mut!(stream);
291        while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
292            match data_res {
293                Ok(GetStreamResponse { message, permits }) => {
294                    use crate::executor::DispatcherMessageBatch;
295                    let msg = message.unwrap();
296                    let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
297
298                    exchange_frag_recv_size_metrics.inc_by(bytes as u64);
299
300                    let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
301                    if let Some(add_back_permits) = match permits.unwrap().value {
302                        // For records, batch the permits we received to reduce the backward
303                        // `AddPermits` messages.
304                        Some(permits::Value::Record(p)) => {
305                            batched_permits_accumulated += p;
306                            if batched_permits_accumulated >= batched_permits_limit as u32 {
307                                let permits = std::mem::take(&mut batched_permits_accumulated);
308                                Some(permits::Value::Record(permits))
309                            } else {
310                                None
311                            }
312                        }
313                        // For barriers, always send it back immediately.
314                        Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
315                        None => None,
316                    } {
317                        permits_tx
318                            .send(add_back_permits)
319                            .context("RemoteInput backward permits channel closed.")?;
320                    }
321
322                    let msg = msg_res.context("RemoteInput decode message error")?;
323                    match msg.into_messages() {
324                        Either::Left(barriers) => {
325                            for b in barriers {
326                                yield b;
327                            }
328                        }
329                        Either::Right(m) => {
330                            yield m;
331                        }
332                    }
333                }
334
335                Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
336            }
337        }
338
339        // Always emit an error outside the loop. This is because we use barrier as the control
340        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
341        Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
342    }
343}
344
345impl Stream for RemoteInput {
346    type Item = DispatcherMessageStreamItem;
347
348    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349        self.project().inner.poll_next(cx)
350    }
351}
352
353impl Input for RemoteInput {
354    type InputId = ActorId;
355
356    fn id(&self) -> Self::InputId {
357        self.actor_id
358    }
359}
360
361/// Create a [`LocalInput`] or [`RemoteInput`] instance with given info. Used by merge executors and
362/// receiver executors.
363pub(crate) async fn new_input(
364    local_barrier_manager: &LocalBarrierManager,
365    metrics: Arc<StreamingMetrics>,
366    actor_id: ActorId,
367    fragment_id: FragmentId,
368    upstream_actor_info: &ActorInfo,
369    upstream_fragment_id: FragmentId,
370) -> StreamExecutorResult<BoxedActorInput> {
371    let upstream_actor_id = upstream_actor_info.actor_id;
372    let upstream_addr = upstream_actor_info.get_host()?.into();
373
374    let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) {
375        LocalInput::new(
376            local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id),
377            upstream_actor_id,
378        )
379        .boxed_input()
380    } else {
381        RemoteInput::new(
382            local_barrier_manager,
383            upstream_addr,
384            (upstream_actor_id, actor_id),
385            (upstream_fragment_id, fragment_id),
386            metrics,
387        )
388        .await?
389        .boxed_input()
390    };
391
392    Ok(input)
393}
394
395impl DispatcherMessageBatch {
396    fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
397        match self {
398            DispatcherMessageBatch::BarrierBatch(barriers) => {
399                Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
400            }
401            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
402            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
403        }
404    }
405}