risingwave_stream/executor/exchange/
input.rs1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use anyhow::anyhow;
19use either::Either;
20use local_input::LocalInputStreamInner;
21use pin_project::pin_project;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use risingwave_rpc_client::ComputeClientPool;
24use tokio::sync::mpsc;
25
26use super::permit::Receiver;
27use crate::executor::prelude::*;
28use crate::executor::{
29 BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
30 DispatcherMessageStream, DispatcherMessageStreamItem,
31};
32use crate::task::{FragmentId, SharedContext, UpDownActorIds, UpDownFragmentIds};
33
34pub trait Input: DispatcherMessageStream {
37 fn actor_id(&self) -> ActorId;
39
40 fn boxed_input(self) -> BoxedInput
41 where
42 Self: Sized + 'static,
43 {
44 Box::pin(self)
45 }
46}
47
48pub type BoxedInput = Pin<Box<dyn Input>>;
49
50impl std::fmt::Debug for dyn Input {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("Input")
53 .field("actor_id", &self.actor_id())
54 .finish_non_exhaustive()
55 }
56}
57
58#[pin_project]
60pub struct LocalInput {
61 #[pin]
62 inner: LocalInputStreamInner,
63
64 actor_id: ActorId,
65}
66
67pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
68 first: &BarrierInner<M1>,
69 second: &BarrierInner<M2>,
70) {
71 assert_eq!(first.epoch, second.epoch);
72 assert_eq!(first.kind, second.kind);
73}
74
75pub(crate) fn apply_dispatcher_barrier(
76 recv_barrier: &mut Barrier,
77 dispatcher_barrier: DispatcherBarrier,
78) {
79 assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
80 recv_barrier
81 .passed_actors
82 .extend(dispatcher_barrier.passed_actors);
83}
84
85pub(crate) async fn process_dispatcher_msg(
86 dispatcher_msg: DispatcherMessage,
87 barrier_rx: &mut mpsc::UnboundedReceiver<Barrier>,
88) -> StreamExecutorResult<Message> {
89 let msg = match dispatcher_msg {
90 DispatcherMessage::Chunk(chunk) => Message::Chunk(chunk),
91 DispatcherMessage::Barrier(barrier) => {
92 let mut recv_barrier = barrier_rx
93 .recv()
94 .await
95 .ok_or_else(|| anyhow!("end of barrier recv"))?;
96 apply_dispatcher_barrier(&mut recv_barrier, barrier);
97 Message::Barrier(recv_barrier)
98 }
99 DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
100 };
101 Ok(msg)
102}
103
104impl LocalInput {
105 pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
106 Self {
107 inner: local_input::run(channel, upstream_actor_id),
108 actor_id: upstream_actor_id,
109 }
110 }
111}
112
113mod local_input {
114 use await_tree::InstrumentAwait;
115 use either::Either;
116
117 use crate::executor::exchange::error::ExchangeChannelClosed;
118 use crate::executor::exchange::permit::Receiver;
119 use crate::executor::prelude::try_stream;
120 use crate::executor::{DispatcherMessage, StreamExecutorError};
121 use crate::task::ActorId;
122
123 pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
124
125 pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
126 run_inner(channel, upstream_actor_id)
127 }
128
129 #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
130 async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
131 let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
132 while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
133 match msg.into_messages() {
134 Either::Left(barriers) => {
135 for b in barriers {
136 yield b;
137 }
138 }
139 Either::Right(m) => {
140 yield m;
141 }
142 }
143 }
144 Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
147 }
148}
149
150impl Stream for LocalInput {
151 type Item = DispatcherMessageStreamItem;
152
153 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154 self.project().inner.poll_next(cx)
156 }
157}
158
159impl Input for LocalInput {
160 fn actor_id(&self) -> ActorId {
161 self.actor_id
162 }
163}
164
165#[pin_project]
167pub struct RemoteInput {
168 #[pin]
169 inner: RemoteInputStreamInner,
170
171 actor_id: ActorId,
172}
173
174use remote_input::RemoteInputStreamInner;
175use risingwave_common::catalog::DatabaseId;
176
177impl RemoteInput {
178 #[expect(clippy::too_many_arguments)]
181 pub fn new(
182 client_pool: ComputeClientPool,
183 upstream_addr: HostAddr,
184 up_down_ids: UpDownActorIds,
185 up_down_frag: UpDownFragmentIds,
186 database_id: DatabaseId,
187 metrics: Arc<StreamingMetrics>,
188 batched_permits: usize,
189 term_id: String,
190 ) -> Self {
191 let actor_id = up_down_ids.0;
192
193 Self {
194 actor_id,
195 inner: remote_input::run(
196 client_pool,
197 upstream_addr,
198 up_down_ids,
199 up_down_frag,
200 database_id,
201 metrics,
202 batched_permits,
203 term_id,
204 ),
205 }
206 }
207}
208
209mod remote_input {
210 use std::sync::Arc;
211
212 use anyhow::Context;
213 use await_tree::InstrumentAwait;
214 use either::Either;
215 use risingwave_common::catalog::DatabaseId;
216 use risingwave_common::util::addr::HostAddr;
217 use risingwave_pb::task_service::{GetStreamResponse, permits};
218 use risingwave_rpc_client::ComputeClientPool;
219
220 use crate::executor::exchange::error::ExchangeChannelClosed;
221 use crate::executor::monitor::StreamingMetrics;
222 use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
223 use crate::executor::{DispatcherMessage, StreamExecutorError};
224 use crate::task::{UpDownActorIds, UpDownFragmentIds};
225
226 pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
227
228 #[expect(clippy::too_many_arguments)]
229 pub(super) fn run(
230 client_pool: ComputeClientPool,
231 upstream_addr: HostAddr,
232 up_down_ids: UpDownActorIds,
233 up_down_frag: UpDownFragmentIds,
234 database_id: DatabaseId,
235 metrics: Arc<StreamingMetrics>,
236 batched_permits_limit: usize,
237 term_id: String,
238 ) -> RemoteInputStreamInner {
239 run_inner(
240 client_pool,
241 upstream_addr,
242 up_down_ids,
243 up_down_frag,
244 database_id,
245 metrics,
246 batched_permits_limit,
247 term_id,
248 )
249 }
250
251 #[expect(clippy::too_many_arguments)]
252 #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
253 async fn run_inner(
254 client_pool: ComputeClientPool,
255 upstream_addr: HostAddr,
256 up_down_ids: UpDownActorIds,
257 up_down_frag: UpDownFragmentIds,
258 database_id: DatabaseId,
259 metrics: Arc<StreamingMetrics>,
260 batched_permits_limit: usize,
261 term_id: String,
262 ) {
263 let client = client_pool.get_by_addr(upstream_addr).await?;
264 let (stream, permits_tx) = client
265 .get_stream(
266 up_down_ids.0,
267 up_down_ids.1,
268 up_down_frag.0,
269 up_down_frag.1,
270 database_id,
271 term_id,
272 )
273 .await?;
274
275 let up_actor_id = up_down_ids.0.to_string();
276 let up_fragment_id = up_down_frag.0.to_string();
277 let down_fragment_id = up_down_frag.1.to_string();
278 let exchange_frag_recv_size_metrics = metrics
279 .exchange_frag_recv_size
280 .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
281
282 let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
283
284 let mut batched_permits_accumulated = 0;
285
286 pin_mut!(stream);
287 while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
288 match data_res {
289 Ok(GetStreamResponse { message, permits }) => {
290 use crate::executor::DispatcherMessageBatch;
291 let msg = message.unwrap();
292 let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
293
294 exchange_frag_recv_size_metrics.inc_by(bytes as u64);
295
296 let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
297 if let Some(add_back_permits) = match permits.unwrap().value {
298 Some(permits::Value::Record(p)) => {
301 batched_permits_accumulated += p;
302 if batched_permits_accumulated >= batched_permits_limit as u32 {
303 let permits = std::mem::take(&mut batched_permits_accumulated);
304 Some(permits::Value::Record(permits))
305 } else {
306 None
307 }
308 }
309 Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
311 None => None,
312 } {
313 permits_tx
314 .send(add_back_permits)
315 .context("RemoteInput backward permits channel closed.")?;
316 }
317
318 let msg = msg_res.context("RemoteInput decode message error")?;
319 match msg.into_messages() {
320 Either::Left(barriers) => {
321 for b in barriers {
322 yield b;
323 }
324 }
325 Either::Right(m) => {
326 yield m;
327 }
328 }
329 }
330
331 Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
332 }
333 }
334
335 Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
338 }
339}
340
341impl Stream for RemoteInput {
342 type Item = DispatcherMessageStreamItem;
343
344 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
345 self.project().inner.poll_next(cx)
346 }
347}
348
349impl Input for RemoteInput {
350 fn actor_id(&self) -> ActorId {
351 self.actor_id
352 }
353}
354
355pub(crate) fn new_input(
358 context: &SharedContext,
359 metrics: Arc<StreamingMetrics>,
360 actor_id: ActorId,
361 fragment_id: FragmentId,
362 upstream_actor_id: ActorId,
363 upstream_fragment_id: FragmentId,
364) -> StreamResult<BoxedInput> {
365 let upstream_addr = context
366 .get_actor_info(&upstream_actor_id)?
367 .get_host()?
368 .into();
369
370 let input = if is_local_address(&context.addr, &upstream_addr) {
371 LocalInput::new(
372 context.take_receiver((upstream_actor_id, actor_id))?,
373 upstream_actor_id,
374 )
375 .boxed_input()
376 } else {
377 RemoteInput::new(
378 context.compute_client_pool.as_ref().to_owned(),
379 upstream_addr,
380 (upstream_actor_id, actor_id),
381 (upstream_fragment_id, fragment_id),
382 context.database_id,
383 metrics,
384 context.config.developer.exchange_batched_permits,
385 context.term_id(),
386 )
387 .boxed_input()
388 };
389
390 Ok(input)
391}
392
393impl DispatcherMessageBatch {
394 fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
395 match self {
396 DispatcherMessageBatch::BarrierBatch(barriers) => {
397 Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
398 }
399 DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
400 DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
401 }
402 }
403}