risingwave_stream/executor/exchange/
input.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use either::Either;
19use local_input::LocalInputStreamInner;
20use pin_project::pin_project;
21use risingwave_common::util::addr::{HostAddr, is_local_address};
22
23use super::permit::Receiver;
24use crate::executor::prelude::*;
25use crate::executor::{
26    BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
27    DispatcherMessageStreamItem,
28};
29use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
30
31/// `Input` is a more abstract upstream input type, used for `DynamicReceivers` type
32/// handling of multiple upstream inputs
33pub trait Input: Stream + Send {
34    type InputId;
35    /// The upstream input id.
36    fn id(&self) -> Self::InputId;
37
38    fn boxed_input(self) -> BoxedInput<Self::InputId, Self::Item>
39    where
40        Self: Sized + 'static,
41    {
42        Box::pin(self)
43    }
44}
45
46pub type BoxedInput<InputId, Item> = Pin<Box<dyn Input<InputId = InputId, Item = Item>>>;
47
48/// `ActorInput` provides an interface for [`MergeExecutor`](crate::executor::MergeExecutor) and
49/// [`ReceiverExecutor`](crate::executor::ReceiverExecutor) to receive data from upstream actors.
50/// Only used for actor inputs.
51pub trait ActorInput = Input<Item = DispatcherMessageStreamItem, InputId = ActorId>;
52
53pub type BoxedActorInput = Pin<Box<dyn ActorInput>>;
54
55impl std::fmt::Debug for dyn ActorInput {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("Input")
58            .field("actor_id", &self.id())
59            .finish_non_exhaustive()
60    }
61}
62
63/// `LocalInput` receives data from a local channel.
64#[pin_project]
65pub struct LocalInput {
66    #[pin]
67    inner: LocalInputStreamInner,
68
69    actor_id: ActorId,
70}
71
72pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
73    first: &BarrierInner<M1>,
74    second: &BarrierInner<M2>,
75) {
76    assert_eq!(first.epoch, second.epoch);
77    assert_eq!(first.kind, second.kind);
78}
79
80pub(crate) fn apply_dispatcher_barrier(
81    recv_barrier: &mut Barrier,
82    dispatcher_barrier: DispatcherBarrier,
83) {
84    assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
85    recv_barrier
86        .passed_actors
87        .extend(dispatcher_barrier.passed_actors);
88}
89
90impl LocalInput {
91    pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
92        Self {
93            inner: local_input::run(channel, upstream_actor_id),
94            actor_id: upstream_actor_id,
95        }
96    }
97}
98
99mod local_input {
100    use await_tree::InstrumentAwait;
101    use either::Either;
102
103    use crate::executor::exchange::error::ExchangeChannelClosed;
104    use crate::executor::exchange::permit::Receiver;
105    use crate::executor::prelude::try_stream;
106    use crate::executor::{DispatcherMessage, StreamExecutorError};
107    use crate::task::ActorId;
108
109    pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
110
111    #[define_opaque(LocalInputStreamInner)]
112    pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
113        run_inner(channel, upstream_actor_id)
114    }
115
116    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
117    async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
118        let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
119        while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
120            match msg.into_messages() {
121                Either::Left(barriers) => {
122                    for b in barriers {
123                        yield b;
124                    }
125                }
126                Either::Right(m) => {
127                    yield m;
128                }
129            }
130        }
131        // Always emit an error outside the loop. This is because we use barrier as the control
132        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
133        Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
134    }
135}
136
137impl Stream for LocalInput {
138    type Item = DispatcherMessageStreamItem;
139
140    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141        // TODO: shall we pass the error with local exchange?
142        self.project().inner.poll_next(cx)
143    }
144}
145
146impl Input for LocalInput {
147    type InputId = ActorId;
148
149    fn id(&self) -> Self::InputId {
150        self.actor_id
151    }
152}
153
154/// `RemoteInput` connects to the upstream exchange server and receives data with `gRPC`.
155#[pin_project]
156pub struct RemoteInput {
157    #[pin]
158    inner: RemoteInputStreamInner,
159
160    actor_id: ActorId,
161}
162
163use remote_input::RemoteInputStreamInner;
164use risingwave_pb::common::ActorInfo;
165
166impl RemoteInput {
167    /// Create a remote input from compute client and related info. Should provide the corresponding
168    /// compute client of where the actor is placed.
169    pub async fn new(
170        local_barrier_manager: &LocalBarrierManager,
171        upstream_addr: HostAddr,
172        up_down_ids: UpDownActorIds,
173        up_down_frag: UpDownFragmentIds,
174        metrics: Arc<StreamingMetrics>,
175    ) -> StreamExecutorResult<Self> {
176        let actor_id = up_down_ids.0;
177
178        let client = local_barrier_manager
179            .env
180            .client_pool()
181            .get_by_addr(upstream_addr)
182            .await?;
183        let (stream, permits_tx) = client
184            .get_stream(
185                up_down_ids.0,
186                up_down_ids.1,
187                up_down_frag.0,
188                up_down_frag.1,
189                local_barrier_manager.database_id,
190                local_barrier_manager.term_id.clone(),
191            )
192            .await?;
193
194        Ok(Self {
195            actor_id,
196            inner: remote_input::run(
197                stream,
198                permits_tx,
199                up_down_ids,
200                up_down_frag,
201                metrics,
202                local_barrier_manager
203                    .env
204                    .config()
205                    .developer
206                    .exchange_batched_permits,
207            ),
208        })
209    }
210}
211
212mod remote_input {
213    use std::sync::Arc;
214
215    use anyhow::Context;
216    use await_tree::InstrumentAwait;
217    use either::Either;
218    use risingwave_pb::task_service::{GetStreamResponse, permits};
219    use tokio::sync::mpsc;
220    use tonic::Streaming;
221
222    use crate::executor::exchange::error::ExchangeChannelClosed;
223    use crate::executor::monitor::StreamingMetrics;
224    use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
225    use crate::executor::{DispatcherMessage, StreamExecutorError};
226    use crate::task::{UpDownActorIds, UpDownFragmentIds};
227
228    pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
229
230    #[define_opaque(RemoteInputStreamInner)]
231    pub(super) fn run(
232        stream: Streaming<GetStreamResponse>,
233        permits_tx: mpsc::UnboundedSender<permits::Value>,
234        up_down_ids: UpDownActorIds,
235        up_down_frag: UpDownFragmentIds,
236        metrics: Arc<StreamingMetrics>,
237        batched_permits_limit: usize,
238    ) -> RemoteInputStreamInner {
239        run_inner(
240            stream,
241            permits_tx,
242            up_down_ids,
243            up_down_frag,
244            metrics,
245            batched_permits_limit,
246        )
247    }
248
249    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
250    async fn run_inner(
251        stream: Streaming<GetStreamResponse>,
252        permits_tx: mpsc::UnboundedSender<permits::Value>,
253        up_down_ids: UpDownActorIds,
254        up_down_frag: UpDownFragmentIds,
255        metrics: Arc<StreamingMetrics>,
256        batched_permits_limit: usize,
257    ) {
258        let up_actor_id = up_down_ids.0.to_string();
259        let up_fragment_id = up_down_frag.0.to_string();
260        let down_fragment_id = up_down_frag.1.to_string();
261        let exchange_frag_recv_size_metrics = metrics
262            .exchange_frag_recv_size
263            .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
264
265        let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
266
267        let mut batched_permits_accumulated = 0;
268
269        pin_mut!(stream);
270        while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
271            match data_res {
272                Ok(GetStreamResponse { message, permits }) => {
273                    use crate::executor::DispatcherMessageBatch;
274                    let msg = message.unwrap();
275                    let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
276
277                    exchange_frag_recv_size_metrics.inc_by(bytes as u64);
278
279                    let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
280                    if let Some(add_back_permits) = match permits.unwrap().value {
281                        // For records, batch the permits we received to reduce the backward
282                        // `AddPermits` messages.
283                        Some(permits::Value::Record(p)) => {
284                            batched_permits_accumulated += p;
285                            if batched_permits_accumulated >= batched_permits_limit as u32 {
286                                let permits = std::mem::take(&mut batched_permits_accumulated);
287                                Some(permits::Value::Record(permits))
288                            } else {
289                                None
290                            }
291                        }
292                        // For barriers, always send it back immediately.
293                        Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
294                        None => None,
295                    } {
296                        permits_tx
297                            .send(add_back_permits)
298                            .context("RemoteInput backward permits channel closed.")?;
299                    }
300
301                    let msg = msg_res.context("RemoteInput decode message error")?;
302                    match msg.into_messages() {
303                        Either::Left(barriers) => {
304                            for b in barriers {
305                                yield b;
306                            }
307                        }
308                        Either::Right(m) => {
309                            yield m;
310                        }
311                    }
312                }
313
314                Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
315            }
316        }
317
318        // Always emit an error outside the loop. This is because we use barrier as the control
319        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
320        Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
321    }
322}
323
324impl Stream for RemoteInput {
325    type Item = DispatcherMessageStreamItem;
326
327    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
328        self.project().inner.poll_next(cx)
329    }
330}
331
332impl Input for RemoteInput {
333    type InputId = ActorId;
334
335    fn id(&self) -> Self::InputId {
336        self.actor_id
337    }
338}
339
340/// Create a [`LocalInput`] or [`RemoteInput`] instance with given info. Used by merge executors and
341/// receiver executors.
342pub(crate) async fn new_input(
343    local_barrier_manager: &LocalBarrierManager,
344    metrics: Arc<StreamingMetrics>,
345    actor_id: ActorId,
346    fragment_id: FragmentId,
347    upstream_actor_info: &ActorInfo,
348    upstream_fragment_id: FragmentId,
349) -> StreamExecutorResult<BoxedActorInput> {
350    let upstream_actor_id = upstream_actor_info.actor_id;
351    let upstream_addr = upstream_actor_info.get_host()?.into();
352
353    let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) {
354        LocalInput::new(
355            local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id),
356            upstream_actor_id,
357        )
358        .boxed_input()
359    } else {
360        RemoteInput::new(
361            local_barrier_manager,
362            upstream_addr,
363            (upstream_actor_id, actor_id),
364            (upstream_fragment_id, fragment_id),
365            metrics,
366        )
367        .await?
368        .boxed_input()
369    };
370
371    Ok(input)
372}
373
374impl DispatcherMessageBatch {
375    fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
376        match self {
377            DispatcherMessageBatch::BarrierBatch(barriers) => {
378                Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
379            }
380            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
381            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
382        }
383    }
384}