// risingwave_stream/executor/merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, VecDeque};
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use anyhow::Context as _;
20use futures::stream::{FusedStream, FuturesUnordered, StreamFuture};
21use prometheus::Histogram;
22use risingwave_common::array::StreamChunkBuilder;
23use risingwave_common::config::MetricLevel;
24use risingwave_common::metrics::LabelGuardedMetric;
25use tokio::sync::mpsc;
26use tokio::time::Instant;
27
28use super::exchange::input::BoxedInput;
29use super::watermark::*;
30use super::*;
31use crate::executor::exchange::input::{
32    assert_equal_dispatcher_barrier, new_input, process_dispatcher_msg,
33};
34use crate::executor::prelude::*;
35use crate::task::SharedContext;
36
/// The upstream side of a merge, decided before the executor itself is built.
///
/// See [`MergeExecutorInput::into_executor`] for how each variant is turned into an executor.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input; handled by the lighter-weight `ReceiverExecutor`.
    Singleton(BoxedInput),
    /// Multiple upstream inputs to be merged; handled by `MergeExecutor`.
    Merge(SelectReceivers),
}
41
/// Bundles the upstream(s) of a merge together with the contextual information needed to
/// build the final executor via [`Self::into_executor`]. Also implements [`Stream`] by
/// delegating to the upstream(s), so it can be consumed before the executor is built.
pub(crate) struct MergeExecutorInput {
    upstream: MergeExecutorUpstream,
    actor_context: ActorContextRef,
    upstream_fragment_id: UpstreamFragmentId,
    shared_context: Arc<SharedContext>,
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
    // Max chunk size used by `BufferChunks` when re-batching upstream chunks.
    chunk_size: usize,
}
51
52impl MergeExecutorInput {
53    pub(crate) fn new(
54        upstream: MergeExecutorUpstream,
55        actor_context: ActorContextRef,
56        upstream_fragment_id: UpstreamFragmentId,
57        shared_context: Arc<SharedContext>,
58        executor_stats: Arc<StreamingMetrics>,
59        info: ExecutorInfo,
60        chunk_size: usize,
61    ) -> Self {
62        Self {
63            upstream,
64            actor_context,
65            upstream_fragment_id,
66            shared_context,
67            executor_stats,
68            info,
69            chunk_size,
70        }
71    }
72
73    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
74        let fragment_id = self.actor_context.fragment_id;
75        let executor = match self.upstream {
76            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
77                self.actor_context,
78                fragment_id,
79                self.upstream_fragment_id,
80                input,
81                self.shared_context,
82                self.executor_stats,
83                barrier_rx,
84            )
85            .boxed(),
86            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
87                self.actor_context,
88                fragment_id,
89                self.upstream_fragment_id,
90                inputs,
91                self.shared_context,
92                self.executor_stats,
93                barrier_rx,
94                self.chunk_size,
95                self.info.schema.clone(),
96            )
97            .boxed(),
98        };
99        (self.info, executor).into()
100    }
101}
102
103impl Stream for MergeExecutorInput {
104    type Item = DispatcherMessageStreamItem;
105
106    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
107        match &mut self.get_mut().upstream {
108            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
109            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
110        }
111    }
112}
113
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Shared context of the stream manager.
    context: Arc<SharedContext>,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Used by `process_dispatcher_msg` to turn dispatcher barriers into full barriers.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
143
144impl MergeExecutor {
145    #[allow(clippy::too_many_arguments)]
146    pub fn new(
147        ctx: ActorContextRef,
148        fragment_id: FragmentId,
149        upstream_fragment_id: FragmentId,
150        upstreams: SelectReceivers,
151        context: Arc<SharedContext>,
152        metrics: Arc<StreamingMetrics>,
153        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
154        chunk_size: usize,
155        schema: Schema,
156    ) -> Self {
157        Self {
158            actor_context: ctx,
159            upstreams,
160            fragment_id,
161            upstream_fragment_id,
162            context,
163            metrics,
164            barrier_rx,
165            chunk_size,
166            schema,
167        }
168    }
169
170    #[cfg(test)]
171    pub fn for_test(
172        actor_id: ActorId,
173        inputs: Vec<super::exchange::permit::Receiver>,
174        shared_context: Arc<SharedContext>,
175        local_barrier_manager: crate::task::LocalBarrierManager,
176        schema: Schema,
177    ) -> Self {
178        use super::exchange::input::LocalInput;
179        use crate::executor::exchange::input::Input;
180
181        let barrier_rx = local_barrier_manager.subscribe_barrier(actor_id);
182
183        let metrics = StreamingMetrics::unused();
184        let actor_ctx = ActorContext::for_test(actor_id);
185        let upstream = Self::new_select_receiver(
186            inputs
187                .into_iter()
188                .enumerate()
189                .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input())
190                .collect(),
191            &metrics,
192            &actor_ctx,
193        );
194
195        Self::new(
196            actor_ctx,
197            514,
198            1919,
199            upstream,
200            shared_context,
201            metrics.into(),
202            barrier_rx,
203            100,
204            schema,
205        )
206    }
207
208    pub(crate) fn new_select_receiver(
209        upstreams: Vec<BoxedInput>,
210        metrics: &StreamingMetrics,
211        actor_context: &ActorContext,
212    ) -> SelectReceivers {
213        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
214            Some(
215                metrics
216                    .merge_barrier_align_duration
217                    .with_guarded_label_values(&[
218                        &actor_context.id.to_string(),
219                        &actor_context.fragment_id.to_string(),
220                    ]),
221            )
222        } else {
223            None
224        };
225
226        // Futures of all active upstreams.
227        SelectReceivers::new(
228            actor_context.id,
229            upstreams,
230            merge_barrier_align_duration.clone(),
231        )
232    }
233
234    #[try_stream(ok = Message, error = StreamExecutorError)]
235    async fn execute_inner(mut self: Box<Self>) {
236        let select_all = self.upstreams;
237        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
238        let actor_id = self.actor_context.id;
239
240        let mut metrics = self.metrics.new_actor_input_metrics(
241            actor_id,
242            self.fragment_id,
243            self.upstream_fragment_id,
244        );
245
246        // Channels that're blocked by the barrier to align.
247        let mut start_time = Instant::now();
248        pin_mut!(select_all);
249        while let Some(msg) = select_all.next().await {
250            metrics
251                .actor_input_buffer_blocking_duration_ns
252                .inc_by(start_time.elapsed().as_nanos() as u64);
253            let msg: DispatcherMessage = msg?;
254            let mut msg: Message = process_dispatcher_msg(msg, &mut self.barrier_rx).await?;
255
256            match &mut msg {
257                Message::Watermark(_) => {
258                    // Do nothing.
259                }
260                Message::Chunk(chunk) => {
261                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
262                }
263                Message::Barrier(barrier) => {
264                    tracing::debug!(
265                        target: "events::stream::barrier::path",
266                        actor_id = actor_id,
267                        "receiver receives barrier from path: {:?}",
268                        barrier.passed_actors
269                    );
270                    barrier.passed_actors.push(actor_id);
271
272                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
273                        barrier.mutation.as_deref()
274                    {
275                        if select_all
276                            .upstream_actor_ids()
277                            .iter()
278                            .any(|actor_id| dispatchers.contains_key(actor_id))
279                        {
280                            // `Watermark` of upstream may become stale after downstream scaling.
281                            select_all
282                                .buffered_watermarks
283                                .values_mut()
284                                .for_each(|buffers| buffers.clear());
285                        }
286                    }
287
288                    if let Some(update) =
289                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
290                    {
291                        let new_upstream_fragment_id = update
292                            .new_upstream_fragment_id
293                            .unwrap_or(self.upstream_fragment_id);
294                        let added_upstream_actor_id = update.added_upstream_actor_id.clone();
295                        let removed_upstream_actor_id: HashSet<_> =
296                            if update.new_upstream_fragment_id.is_some() {
297                                select_all.upstream_actor_ids().iter().copied().collect()
298                            } else {
299                                update.removed_upstream_actor_id.iter().copied().collect()
300                            };
301
302                        // `Watermark` of upstream may become stale after upstream scaling.
303                        select_all
304                            .buffered_watermarks
305                            .values_mut()
306                            .for_each(|buffers| buffers.clear());
307
308                        if !added_upstream_actor_id.is_empty() {
309                            // Create new upstreams receivers.
310                            let new_upstreams: Vec<_> = added_upstream_actor_id
311                                .iter()
312                                .map(|&upstream_actor_id| {
313                                    new_input(
314                                        &self.context,
315                                        self.metrics.clone(),
316                                        self.actor_context.id,
317                                        self.fragment_id,
318                                        upstream_actor_id,
319                                        new_upstream_fragment_id,
320                                    )
321                                })
322                                .try_collect()
323                                .context("failed to create upstream receivers")?;
324
325                            // Poll the first barrier from the new upstreams. It must be the same as
326                            // the one we polled from original upstreams.
327                            let mut select_new = SelectReceivers::new(
328                                self.actor_context.id,
329                                new_upstreams,
330                                select_all.merge_barrier_align_duration(),
331                            );
332                            let new_barrier = expect_first_barrier(&mut select_new).await?;
333                            assert_equal_dispatcher_barrier(barrier, &new_barrier);
334
335                            // Add the new upstreams to select.
336                            select_all.add_upstreams_from(select_new);
337
338                            // Add buffers to the buffered watermarks for all cols
339                            select_all
340                                .buffered_watermarks
341                                .values_mut()
342                                .for_each(|buffers| {
343                                    buffers.add_buffers(added_upstream_actor_id.clone())
344                                });
345                        }
346
347                        if !removed_upstream_actor_id.is_empty() {
348                            // Remove upstreams.
349                            select_all.remove_upstreams(&removed_upstream_actor_id);
350
351                            for buffers in select_all.buffered_watermarks.values_mut() {
352                                // Call `check_heap` in case the only upstream(s) that does not have
353                                // watermark in heap is removed
354                                buffers.remove_buffer(removed_upstream_actor_id.clone());
355                            }
356                        }
357
358                        self.upstream_fragment_id = new_upstream_fragment_id;
359                        metrics = self.metrics.new_actor_input_metrics(
360                            actor_id,
361                            self.fragment_id,
362                            self.upstream_fragment_id,
363                        );
364
365                        select_all.update_actor_ids();
366                    }
367
368                    if barrier.is_stop(actor_id) {
369                        yield msg;
370                        break;
371                    }
372                }
373            }
374
375            yield msg;
376            start_time = Instant::now();
377        }
378    }
379}
380
impl Execute for MergeExecutor {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        // Delegate to the `#[try_stream]`-generated implementation.
        self.execute_inner().boxed()
    }
}
386
/// A stream for merging messages from multiple upstreams.
///
/// Chunks and watermarks are forwarded as they arrive; barriers are aligned, i.e. an upstream
/// that yields a barrier is blocked until all upstreams have yielded the same barrier, which is
/// then emitted once.
pub struct SelectReceivers {
    /// The barrier we're aligning to. If this is `None`, then `blocked_upstreams` is empty.
    barrier: Option<DispatcherBarrier>,
    /// The upstreams that're blocked by the `barrier`.
    blocked: Vec<BoxedInput>,
    /// The upstreams that're not blocked and can be polled.
    active: FuturesUnordered<StreamFuture<BoxedInput>>,
    /// All upstream actor ids.
    upstream_actor_ids: Vec<ActorId>,

    /// The actor id of this fragment.
    actor_id: u32,
    /// watermark column index -> `BufferedWatermarks`
    buffered_watermarks: BTreeMap<usize, BufferedWatermarks<ActorId>>,
    /// If None, then we don't take `Instant::now()` and `observe` during `poll_next`
    merge_barrier_align_duration: Option<LabelGuardedMetric<Histogram, 2>>,
}
405
impl Stream for SelectReceivers {
    type Item = DispatcherMessageStreamItem;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.active.is_terminated() {
            // This only happens if we've been asked to stop.
            assert!(self.blocked.is_empty());
            return Poll::Ready(None);
        }

        // Time when the first upstream became blocked on a barrier, for the alignment metric.
        let mut start = None;
        loop {
            match futures::ready!(self.active.poll_next_unpin(cx)) {
                // Directly forward the error.
                Some((Some(Err(e)), _)) => {
                    return Poll::Ready(Some(Err(e)));
                }
                // Handle the message from some upstream.
                Some((Some(Ok(message)), remaining)) => {
                    let actor_id = remaining.actor_id();
                    match message {
                        DispatcherMessage::Chunk(chunk) => {
                            // Continue polling this upstream by pushing it back to `active`.
                            self.active.push(remaining.into_future());
                            return Poll::Ready(Some(Ok(DispatcherMessage::Chunk(chunk))));
                        }
                        DispatcherMessage::Watermark(watermark) => {
                            // Continue polling this upstream by pushing it back to `active`.
                            self.active.push(remaining.into_future());
                            // Only yield a watermark when the per-column bookkeeping decides
                            // one can be emitted; otherwise it stays buffered.
                            if let Some(watermark) = self.handle_watermark(actor_id, watermark) {
                                return Poll::Ready(Some(Ok(DispatcherMessage::Watermark(
                                    watermark,
                                ))));
                            }
                        }
                        DispatcherMessage::Barrier(barrier) => {
                            // Block this upstream by pushing it to `blocked`.
                            if self.blocked.is_empty()
                                && self.merge_barrier_align_duration.is_some()
                            {
                                start = Some(Instant::now());
                            }
                            self.blocked.push(remaining);
                            if let Some(current_barrier) = self.barrier.as_ref() {
                                // All upstreams must deliver the same epoch, or alignment fails.
                                if current_barrier.epoch != barrier.epoch {
                                    return Poll::Ready(Some(Err(
                                        StreamExecutorError::align_barrier(
                                            current_barrier.clone().map_mutation(|_| None),
                                            barrier.map_mutation(|_| None),
                                        ),
                                    )));
                                }
                            } else {
                                self.barrier = Some(barrier);
                            }
                        }
                    }
                }
                // We use barrier as the control message of the stream. That is, we always stop the
                // actors actively when we receive a `Stop` mutation, instead of relying on the stream
                // termination.
                //
                // Besides, in abnormal cases when the other side of the `Input` closes unexpectedly,
                // we also yield an `Err(ExchangeChannelClosed)`, which will hit the `Err` arm above.
                // So this branch will never be reached in all cases.
                Some((None, _)) => unreachable!(),
                // There's no active upstreams. Process the barrier and resume the blocked ones.
                None => {
                    if let Some(start) = start
                        && let Some(merge_barrier_align_duration) =
                            &self.merge_barrier_align_duration
                    {
                        // Observe did a few atomic operation inside, we want to avoid the overhead.
                        merge_barrier_align_duration.observe(start.elapsed().as_secs_f64())
                    }
                    break;
                }
            }
        }

        // All upstreams are aligned on the same barrier: unblock them for the next epoch
        // and emit the barrier once.
        assert!(self.active.is_terminated());
        let barrier = self.barrier.take().unwrap();

        let upstreams = std::mem::take(&mut self.blocked);
        self.extend_active(upstreams);
        assert!(!self.active.is_terminated());

        Poll::Ready(Some(Ok(DispatcherMessage::Barrier(barrier))))
    }
}
496
497impl SelectReceivers {
498    fn new(
499        actor_id: u32,
500        upstreams: Vec<BoxedInput>,
501        merge_barrier_align_duration: Option<LabelGuardedMetric<Histogram, 2>>,
502    ) -> Self {
503        assert!(!upstreams.is_empty());
504        let upstream_actor_ids = upstreams.iter().map(|input| input.actor_id()).collect();
505        let mut this = Self {
506            blocked: Vec::with_capacity(upstreams.len()),
507            active: Default::default(),
508            actor_id,
509            barrier: None,
510            upstream_actor_ids,
511            buffered_watermarks: Default::default(),
512            merge_barrier_align_duration,
513        };
514        this.extend_active(upstreams);
515        this
516    }
517
518    /// Extend the active upstreams with the given upstreams. The current stream must be at the
519    /// clean state right after a barrier.
520    fn extend_active(&mut self, upstreams: impl IntoIterator<Item = BoxedInput>) {
521        assert!(self.blocked.is_empty() && self.barrier.is_none());
522
523        self.active
524            .extend(upstreams.into_iter().map(|s| s.into_future()));
525    }
526
527    fn upstream_actor_ids(&self) -> &[ActorId] {
528        &self.upstream_actor_ids
529    }
530
531    fn update_actor_ids(&mut self) {
532        self.upstream_actor_ids = self
533            .blocked
534            .iter()
535            .map(|input| input.actor_id())
536            .chain(
537                self.active
538                    .iter()
539                    .map(|input| input.get_ref().unwrap().actor_id()),
540            )
541            .collect();
542    }
543
544    /// Handle a new watermark message. Optionally returns the watermark message to emit.
545    fn handle_watermark(&mut self, actor_id: ActorId, watermark: Watermark) -> Option<Watermark> {
546        let col_idx = watermark.col_idx;
547        // Insert a buffer watermarks when first received from a column.
548        let watermarks = self
549            .buffered_watermarks
550            .entry(col_idx)
551            .or_insert_with(|| BufferedWatermarks::with_ids(self.upstream_actor_ids.clone()));
552        watermarks.handle_watermark(actor_id, watermark)
553    }
554
555    /// Consume `other` and add its upstreams to `self`. The two streams must be at the clean state
556    /// right after a barrier.
557    fn add_upstreams_from(&mut self, other: Self) {
558        assert!(self.blocked.is_empty() && self.barrier.is_none());
559        assert!(other.blocked.is_empty() && other.barrier.is_none());
560        assert_eq!(self.actor_id, other.actor_id);
561
562        self.active.extend(other.active);
563    }
564
565    /// Remove upstreams from `self` in `upstream_actor_ids`. The current stream must be at the
566    /// clean state right after a barrier.
567    fn remove_upstreams(&mut self, upstream_actor_ids: &HashSet<ActorId>) {
568        assert!(self.blocked.is_empty() && self.barrier.is_none());
569
570        let new_upstreams = std::mem::take(&mut self.active)
571            .into_iter()
572            .map(|s| s.into_inner().unwrap())
573            .filter(|u| !upstream_actor_ids.contains(&u.actor_id()));
574        self.extend_active(new_upstreams);
575    }
576
577    fn merge_barrier_align_duration(&self) -> Option<LabelGuardedMetric<Histogram, 2>> {
578        self.merge_barrier_align_duration.clone()
579    }
580}
581
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The underlying message stream.
    inner: S,
    /// Accumulates rows from incoming chunks up to the configured chunk size.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
592
593impl<S: Stream> BufferChunks<S> {
594    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
595        assert!(chunk_size > 0);
596        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
597        Self {
598            inner,
599            chunk_builder,
600            pending_items: VecDeque::new(),
601        }
602    }
603}
604
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    // Expose the inner stream's methods (e.g. `upstream_actor_ids`) directly on the wrapper.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
612
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    // Mutable access to the inner stream, e.g. for mutating its buffered watermarks.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
618
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drain anything already queued before polling upstream again.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // Upstream has no ready item: flush the partially-built chunk (if any)
                    // rather than holding rows back until more input arrives.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Fold the chunk's rows into the builder; queue full chunks as they fill up.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Any non-chunk item (barrier/watermark/error): flush the pending chunk
                        // first so it's emitted before the item itself.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `SelectReceivers::poll_next`.
                    unreachable!("SelectReceivers should never return None");
                }
            }
        }
    }
}
666
667#[cfg(test)]
668mod tests {
669    use std::sync::atomic::{AtomicBool, Ordering};
670    use std::time::Duration;
671
672    use assert_matches::assert_matches;
673    use futures::FutureExt;
674    use risingwave_common::array::Op;
675    use risingwave_common::util::epoch::test_epoch;
676    use risingwave_pb::task_service::exchange_service_server::{
677        ExchangeService, ExchangeServiceServer,
678    };
679    use risingwave_pb::task_service::{
680        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
681    };
682    use risingwave_rpc_client::ComputeClientPool;
683    use tokio::time::sleep;
684    use tokio_stream::wrappers::ReceiverStream;
685    use tonic::{Request, Response, Status, Streaming};
686
687    use super::*;
688    use crate::executor::exchange::input::{Input, LocalInput, RemoteInput};
689    use crate::executor::exchange::permit::channel_for_test;
690    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
691    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
692    use crate::task::test_utils::helper_make_local_actor;
693
694    fn build_test_chunk(size: u64) -> StreamChunk {
695        let ops = vec![Op::Insert; size as usize];
696        StreamChunk::new(ops, vec![])
697    }
698
    /// Verifies `BufferChunks` re-batching: ready chunks are merged up to the chunk size,
    /// and watermarks/barriers flush any pending chunk before being emitted themselves.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
794
    /// Stress-tests `MergeExecutor` over `CHANNEL_NUMBER` local upstreams.
    ///
    /// Each upstream task sends, per epoch, either a 10-row chunk (every other
    /// epoch) or a watermark with a rotating column index, followed by the
    /// epoch's barrier. The consumer side then verifies that:
    /// - chunks from all channels are merged into chunk(s) totalling 100 rows,
    /// - watermarks are only emitted once every channel has advanced the
    ///   corresponding column (aligned to the minimum value),
    /// - each barrier is collected from all upstreams and emitted exactly once,
    /// - the final stop barrier carries its `Mutation::Stop` through.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        // One (tx, rx) pair per simulated upstream channel.
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233;
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Epochs 10, 20, ..., 990 — `idx` is the raw step, `epoch` the test epoch.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        // Pre-build and pre-inject one barrier per epoch, chained via prev_epoch.
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // Final barrier carries a Stop mutation to terminate the stream.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(HashSet::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Spawn one producer task per channel.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        // Even steps: a 10-row chunk from every channel.
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        // Odd steps: a watermark whose column index rotates with
                        // both the step and the channel id, so each column only
                        // becomes fully covered after CHANNEL_NUMBER - 1 rounds.
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    // Every step ends with the epoch's barrier.
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.shared_context.clone(),
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                // (10 rows from each of the 10 channels; the merger may split
                // or combine them into any number of output chunks).
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks: after CHANNEL_NUMBER - 1 rounds every
                // column has been advanced by all channels, so each round now
                // releases an aligned watermark per column. The emitted value
                // lags the current step by 20 * (CHANNEL_NUMBER - 1), i.e. it is
                // the minimum watermark across channels for that column.
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The final barrier must surface the Stop mutation.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
906
    /// Tests that `MergeExecutor` applies a `MergeUpdate` carried by a barrier:
    /// upstream `old` is removed and upstream `new` is added, while `untouched`
    /// keeps flowing. Also verifies barrier alignment — the config-change
    /// barrier is held back until *all* upstreams (including the newly added
    /// one) have delivered it.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233;
        let (untouched, old, new) = (234, 235, 238); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let ctx = barrier_test_env.shared_context.clone();
        let metrics = Arc::new(StreamingMetrics::unused());

        // 1. Register info in context.
        ctx.add_actors(
            [actor_id, untouched, old, new]
                .into_iter()
                .map(helper_make_local_actor),
        );
        // Upstream topology (all three dispatch to the merger's actor):
        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10, 18);

        // Initially only `untouched` and `old` are connected as inputs.
        let inputs: Vec<_> = [untouched, old]
            .into_iter()
            .map(|upstream_actor_id| {
                new_input(
                    &ctx,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    upstream_actor_id,
                    upstream_fragment_id,
                )
            })
            .try_collect()
            .unwrap();

        // The config change: add `new`, remove `old`, keep the fragment id.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actor_id: vec![new],
                removed_upstream_actor_id: vec![old],
            }
        };

        // Barrier b1 carries the update mutation above.
        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                dispatchers: Default::default(),
                merges: merge_updates,
                vnode_bitmaps: Default::default(),
                dropped_actors: Default::default(),
                actor_splits: Default::default(),
                actor_new_dispatchers: Default::default(),
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            ctx.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        // 2. Take downstream receivers.
        let txs = [untouched, old, new]
            .into_iter()
            .map(|id| (id, ctx.take_sender(&(id, actor_id)).unwrap()))
            .collect::<HashMap<_, _>>();
        // Send `$msg` from each of the given upstream actors.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Assert the merger has nothing ready to yield right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // The two 1-row chunks are merged into one chunk of cardinality 2.
        assert_recv_pending!();

        // 4. Send the config-change barrier from the current upstreams only.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the newly added upstream `new`.

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk — now from the post-update upstream set.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // Again merged into one chunk of cardinality 2.
        assert_recv_pending!();
    }
1036
    /// A minimal fake `ExchangeService` used to exercise the client side of the
    /// exchange (see `test_stream_exchange_client`).
    struct FakeExchangeService {
        // Set to `true` when `get_stream` is invoked, so the test can assert the RPC happened.
        rpc_called: Arc<AtomicBool>,
    }
1040
1041    fn exchange_client_test_barrier() -> crate::executor::Barrier {
1042        Barrier::new_test_barrier(test_epoch(1))
1043    }
1044
1045    #[async_trait::async_trait]
1046    impl ExchangeService for FakeExchangeService {
1047        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
1048        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
1049
1050        async fn get_data(
1051            &self,
1052            _: Request<GetDataRequest>,
1053        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
1054            unimplemented!()
1055        }
1056
1057        async fn get_stream(
1058            &self,
1059            _request: Request<Streaming<GetStreamRequest>>,
1060        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
1061            let (tx, rx) = tokio::sync::mpsc::channel(10);
1062            self.rpc_called.store(true, Ordering::SeqCst);
1063            // send stream_chunk
1064            let stream_chunk = StreamChunk::default().to_protobuf();
1065            tx.send(Ok(GetStreamResponse {
1066                message: Some(PbStreamMessageBatch {
1067                    stream_message_batch: Some(
1068                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
1069                            stream_chunk,
1070                        ),
1071                    ),
1072                }),
1073                permits: Some(PbPermits::default()),
1074            }))
1075            .await
1076            .unwrap();
1077            // send barrier
1078            let barrier = exchange_client_test_barrier();
1079            tx.send(Ok(GetStreamResponse {
1080                message: Some(PbStreamMessageBatch {
1081                    stream_message_batch: Some(
1082                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
1083                            BarrierBatch {
1084                                barriers: vec![barrier.to_protobuf()],
1085                            },
1086                        ),
1087                    ),
1088                }),
1089                permits: Some(PbPermits::default()),
1090            }))
1091            .await
1092            .unwrap();
1093            Ok(Response::new(ReceiverStream::new(rx)))
1094        }
1095    }
1096
1097    #[tokio::test]
1098    async fn test_stream_exchange_client() {
1099        const BATCHED_PERMITS: usize = 1024;
1100        let rpc_called = Arc::new(AtomicBool::new(false));
1101        let server_run = Arc::new(AtomicBool::new(false));
1102        let addr = "127.0.0.1:12348".parse().unwrap();
1103
1104        // Start a server.
1105        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
1106        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
1107            rpc_called: rpc_called.clone(),
1108        });
1109        let cp_server_run = server_run.clone();
1110        let join_handle = tokio::spawn(async move {
1111            cp_server_run.store(true, Ordering::SeqCst);
1112            tonic::transport::Server::builder()
1113                .add_service(exchange_svc)
1114                .serve_with_shutdown(addr, async move {
1115                    shutdown_recv.await.unwrap();
1116                })
1117                .await
1118                .unwrap();
1119        });
1120
1121        sleep(Duration::from_secs(1)).await;
1122        assert!(server_run.load(Ordering::SeqCst));
1123
1124        let test_env = LocalBarrierTestEnv::for_test().await;
1125
1126        let remote_input = {
1127            let pool = ComputeClientPool::for_test();
1128            RemoteInput::new(
1129                pool,
1130                addr.into(),
1131                (0, 0),
1132                (0, 0),
1133                test_env.shared_context.database_id,
1134                Arc::new(StreamingMetrics::unused()),
1135                BATCHED_PERMITS,
1136                "for_test".into(),
1137            )
1138        };
1139
1140        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.actor_id()]);
1141
1142        pin_mut!(remote_input);
1143
1144        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
1145            let (ops, columns, visibility) = chunk.into_inner();
1146            assert!(ops.is_empty());
1147            assert!(columns.is_empty());
1148            assert!(visibility.is_empty());
1149        });
1150        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
1151            assert_eq!(barrier_epoch.curr, test_epoch(1));
1152        });
1153        assert!(rpc_called.load(Ordering::SeqCst));
1154
1155        shutdown_send.send(()).unwrap();
1156        join_handle.await.unwrap();
1157    }
1158}