risingwave_stream/executor/exchange/
input.rs1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use anyhow::anyhow;
19use either::Either;
20use local_input::LocalInputStreamInner;
21use pin_project::pin_project;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use tokio::sync::mpsc;
24
25use super::permit::Receiver;
26use crate::executor::prelude::*;
27use crate::executor::{
28 BarrierInner, DispatcherBarrier, DispatcherMessage, DispatcherMessageBatch,
29 DispatcherMessageStreamItem,
30};
31use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
32
33pub trait Input: Stream + Send {
36 type InputId;
37 fn id(&self) -> Self::InputId;
39
40 fn boxed_input(self) -> BoxedInput<Self::InputId, Self::Item>
41 where
42 Self: Sized + 'static,
43 {
44 Box::pin(self)
45 }
46}
47
48pub type BoxedInput<InputId, Item> = Pin<Box<dyn Input<InputId = InputId, Item = Item>>>;
49
50pub trait ActorInput = Input<Item = DispatcherMessageStreamItem, InputId = ActorId>;
54
55pub type BoxedActorInput = Pin<Box<dyn ActorInput>>;
56
57impl std::fmt::Debug for dyn ActorInput {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("Input")
60 .field("actor_id", &self.id())
61 .finish_non_exhaustive()
62 }
63}
64
65#[pin_project]
67pub struct LocalInput {
68 #[pin]
69 inner: LocalInputStreamInner,
70
71 actor_id: ActorId,
72}
73
74pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
75 first: &BarrierInner<M1>,
76 second: &BarrierInner<M2>,
77) {
78 assert_eq!(first.epoch, second.epoch);
79 assert_eq!(first.kind, second.kind);
80}
81
82pub(crate) fn apply_dispatcher_barrier(
83 recv_barrier: &mut Barrier,
84 dispatcher_barrier: DispatcherBarrier,
85) {
86 assert_equal_dispatcher_barrier(recv_barrier, &dispatcher_barrier);
87 recv_barrier
88 .passed_actors
89 .extend(dispatcher_barrier.passed_actors);
90}
91
92pub(crate) async fn process_dispatcher_msg(
93 dispatcher_msg: DispatcherMessage,
94 barrier_rx: &mut mpsc::UnboundedReceiver<Barrier>,
95) -> StreamExecutorResult<Message> {
96 let msg = match dispatcher_msg {
97 DispatcherMessage::Chunk(chunk) => Message::Chunk(chunk),
98 DispatcherMessage::Barrier(barrier) => {
99 let mut recv_barrier = barrier_rx
100 .recv()
101 .await
102 .ok_or_else(|| anyhow!("end of barrier recv"))?;
103 apply_dispatcher_barrier(&mut recv_barrier, barrier);
104 Message::Barrier(recv_barrier)
105 }
106 DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
107 };
108 Ok(msg)
109}
110
111impl LocalInput {
112 pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
113 Self {
114 inner: local_input::run(channel, upstream_actor_id),
115 actor_id: upstream_actor_id,
116 }
117 }
118}
119
120mod local_input {
121 use await_tree::InstrumentAwait;
122 use either::Either;
123
124 use crate::executor::exchange::error::ExchangeChannelClosed;
125 use crate::executor::exchange::permit::Receiver;
126 use crate::executor::prelude::try_stream;
127 use crate::executor::{DispatcherMessage, StreamExecutorError};
128 use crate::task::ActorId;
129
130 pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
131
132 #[define_opaque(LocalInputStreamInner)]
133 pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
134 run_inner(channel, upstream_actor_id)
135 }
136
137 #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
138 async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
139 let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
140 while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
141 match msg.into_messages() {
142 Either::Left(barriers) => {
143 for b in barriers {
144 yield b;
145 }
146 }
147 Either::Right(m) => {
148 yield m;
149 }
150 }
151 }
152 Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
155 }
156}
157
158impl Stream for LocalInput {
159 type Item = DispatcherMessageStreamItem;
160
161 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162 self.project().inner.poll_next(cx)
164 }
165}
166
167impl Input for LocalInput {
168 type InputId = ActorId;
169
170 fn id(&self) -> Self::InputId {
171 self.actor_id
172 }
173}
174
175#[pin_project]
177pub struct RemoteInput {
178 #[pin]
179 inner: RemoteInputStreamInner,
180
181 actor_id: ActorId,
182}
183
184use remote_input::RemoteInputStreamInner;
185use risingwave_pb::common::ActorInfo;
186
187impl RemoteInput {
188 pub async fn new(
191 local_barrier_manager: &LocalBarrierManager,
192 upstream_addr: HostAddr,
193 up_down_ids: UpDownActorIds,
194 up_down_frag: UpDownFragmentIds,
195 metrics: Arc<StreamingMetrics>,
196 ) -> StreamExecutorResult<Self> {
197 let actor_id = up_down_ids.0;
198
199 let client = local_barrier_manager
200 .env
201 .client_pool()
202 .get_by_addr(upstream_addr)
203 .await?;
204 let (stream, permits_tx) = client
205 .get_stream(
206 up_down_ids.0,
207 up_down_ids.1,
208 up_down_frag.0,
209 up_down_frag.1,
210 local_barrier_manager.database_id,
211 local_barrier_manager.term_id.clone(),
212 )
213 .await?;
214
215 Ok(Self {
216 actor_id,
217 inner: remote_input::run(
218 stream,
219 permits_tx,
220 up_down_ids,
221 up_down_frag,
222 metrics,
223 local_barrier_manager
224 .env
225 .config()
226 .developer
227 .exchange_batched_permits,
228 ),
229 })
230 }
231}
232
233mod remote_input {
234 use std::sync::Arc;
235
236 use anyhow::Context;
237 use await_tree::InstrumentAwait;
238 use either::Either;
239 use risingwave_pb::task_service::{GetStreamResponse, permits};
240 use tokio::sync::mpsc;
241 use tonic::Streaming;
242
243 use crate::executor::exchange::error::ExchangeChannelClosed;
244 use crate::executor::monitor::StreamingMetrics;
245 use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
246 use crate::executor::{DispatcherMessage, StreamExecutorError};
247 use crate::task::{UpDownActorIds, UpDownFragmentIds};
248
249 pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
250
251 #[define_opaque(RemoteInputStreamInner)]
252 pub(super) fn run(
253 stream: Streaming<GetStreamResponse>,
254 permits_tx: mpsc::UnboundedSender<permits::Value>,
255 up_down_ids: UpDownActorIds,
256 up_down_frag: UpDownFragmentIds,
257 metrics: Arc<StreamingMetrics>,
258 batched_permits_limit: usize,
259 ) -> RemoteInputStreamInner {
260 run_inner(
261 stream,
262 permits_tx,
263 up_down_ids,
264 up_down_frag,
265 metrics,
266 batched_permits_limit,
267 )
268 }
269
270 #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
271 async fn run_inner(
272 stream: Streaming<GetStreamResponse>,
273 permits_tx: mpsc::UnboundedSender<permits::Value>,
274 up_down_ids: UpDownActorIds,
275 up_down_frag: UpDownFragmentIds,
276 metrics: Arc<StreamingMetrics>,
277 batched_permits_limit: usize,
278 ) {
279 let up_actor_id = up_down_ids.0.to_string();
280 let up_fragment_id = up_down_frag.0.to_string();
281 let down_fragment_id = up_down_frag.1.to_string();
282 let exchange_frag_recv_size_metrics = metrics
283 .exchange_frag_recv_size
284 .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
285
286 let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
287
288 let mut batched_permits_accumulated = 0;
289
290 pin_mut!(stream);
291 while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
292 match data_res {
293 Ok(GetStreamResponse { message, permits }) => {
294 use crate::executor::DispatcherMessageBatch;
295 let msg = message.unwrap();
296 let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
297
298 exchange_frag_recv_size_metrics.inc_by(bytes as u64);
299
300 let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
301 if let Some(add_back_permits) = match permits.unwrap().value {
302 Some(permits::Value::Record(p)) => {
305 batched_permits_accumulated += p;
306 if batched_permits_accumulated >= batched_permits_limit as u32 {
307 let permits = std::mem::take(&mut batched_permits_accumulated);
308 Some(permits::Value::Record(permits))
309 } else {
310 None
311 }
312 }
313 Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
315 None => None,
316 } {
317 permits_tx
318 .send(add_back_permits)
319 .context("RemoteInput backward permits channel closed.")?;
320 }
321
322 let msg = msg_res.context("RemoteInput decode message error")?;
323 match msg.into_messages() {
324 Either::Left(barriers) => {
325 for b in barriers {
326 yield b;
327 }
328 }
329 Either::Right(m) => {
330 yield m;
331 }
332 }
333 }
334
335 Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
336 }
337 }
338
339 Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
342 }
343}
344
345impl Stream for RemoteInput {
346 type Item = DispatcherMessageStreamItem;
347
348 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349 self.project().inner.poll_next(cx)
350 }
351}
352
353impl Input for RemoteInput {
354 type InputId = ActorId;
355
356 fn id(&self) -> Self::InputId {
357 self.actor_id
358 }
359}
360
361pub(crate) async fn new_input(
364 local_barrier_manager: &LocalBarrierManager,
365 metrics: Arc<StreamingMetrics>,
366 actor_id: ActorId,
367 fragment_id: FragmentId,
368 upstream_actor_info: &ActorInfo,
369 upstream_fragment_id: FragmentId,
370) -> StreamExecutorResult<BoxedActorInput> {
371 let upstream_actor_id = upstream_actor_info.actor_id;
372 let upstream_addr = upstream_actor_info.get_host()?.into();
373
374 let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) {
375 LocalInput::new(
376 local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id),
377 upstream_actor_id,
378 )
379 .boxed_input()
380 } else {
381 RemoteInput::new(
382 local_barrier_manager,
383 upstream_addr,
384 (upstream_actor_id, actor_id),
385 (upstream_fragment_id, fragment_id),
386 metrics,
387 )
388 .await?
389 .boxed_input()
390 };
391
392 Ok(input)
393}
394
395impl DispatcherMessageBatch {
396 fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
397 match self {
398 DispatcherMessageBatch::BarrierBatch(barriers) => {
399 Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
400 }
401 DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
402 DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
403 }
404 }
405}