risingwave_stream/executor/
chain.rs1use crate::executor::prelude::*;
16use crate::task::CreateMviewProgressReporter;
17
18pub struct ChainExecutor {
23 snapshot: Executor,
24
25 upstream: Executor,
26
27 progress: CreateMviewProgressReporter,
28
29 actor_id: ActorId,
30
31 upstream_only: bool,
33}
34
35impl ChainExecutor {
36 pub fn new(
37 snapshot: Executor,
38 upstream: Executor,
39 progress: CreateMviewProgressReporter,
40 upstream_only: bool,
41 ) -> Self {
42 Self {
43 snapshot,
44 upstream,
45 actor_id: progress.actor_id(),
46 progress,
47 upstream_only,
48 }
49 }
50
51 #[try_stream(ok = Message, error = StreamExecutorError)]
52 async fn execute_inner(mut self) {
53 let mut upstream = self.upstream.execute();
54
55 let barrier = expect_first_barrier(&mut upstream).await?;
57 let prev_epoch = barrier.epoch.prev;
58
59 let to_consume_snapshot = barrier.is_newly_added(self.actor_id) && !self.upstream_only;
63
64 if barrier.is_newly_added(self.actor_id) && self.upstream_only {
67 self.progress.finish(barrier.epoch, 0);
68 }
69
70 yield Message::Barrier(barrier);
72
73 if to_consume_snapshot {
76 let snapshot = self.snapshot.execute_with_epoch(prev_epoch);
78
79 #[for_await]
80 for msg in snapshot {
81 yield msg?;
82 }
83 }
84
85 #[for_await]
88 for msg in upstream {
89 let msg = msg?;
90 if to_consume_snapshot && let Message::Barrier(barrier) = &msg {
91 self.progress.finish(barrier.epoch, 0);
92 }
93 yield msg;
94 }
95 }
96}
97
98impl Execute for ChainExecutor {
99 fn execute(self: Box<Self>) -> super::BoxedMessageStream {
100 self.execute_inner().boxed()
101 }
102}
103
#[cfg(test)]
mod test {
    use futures::StreamExt;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::stream_plan::Dispatcher;

    use super::ChainExecutor;
    use crate::executor::test_utils::MockSource;
    use crate::executor::{AddMutation, Barrier, Execute, Message, Mutation, PkIndices};
    use crate::task::{CreateMviewProgressReporter, LocalBarrierManager};

    /// Snapshot chunks must be emitted after the first barrier and before upstream chunks.
    #[tokio::test]
    async fn test_basic() {
        let barrier_manager = LocalBarrierManager::for_test();
        let progress = CreateMviewProgressReporter::for_test(barrier_manager);
        let actor_id = progress.actor_id();

        let schema = Schema::new(vec![Field::unnamed(DataType::Int64)]);

        // Snapshot side: chunks 1 and 2.
        let first = MockSource::with_chunks(vec![
            StreamChunk::from_pretty("I\n + 1"),
            StreamChunk::from_pretty("I\n + 2"),
        ])
        .stop_on_finish(false)
        .into_executor(schema.clone(), PkIndices::new());

        // Upstream side: a barrier adding this actor, then chunks 3 and 4.
        let second = MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1)).with_mutation(
                Mutation::Add(AddMutation {
                    adds: maplit::hashmap! {
                        0 => vec![Dispatcher {
                            downstream_actor_id: vec![actor_id],
                            ..Default::default()
                        }],
                    },
                    added_actors: maplit::hashset! { actor_id },
                    splits: Default::default(),
                    pause: false,
                    subscriptions_to_add: vec![],
                }),
            )),
            Message::Chunk(StreamChunk::from_pretty("I\n + 3")),
            Message::Chunk(StreamChunk::from_pretty("I\n + 4")),
        ])
        .into_executor(schema.clone(), PkIndices::new());

        let chain = ChainExecutor::new(first, second, progress, false);

        let mut chain = chain.boxed().execute();

        // The executor must forward the initial barrier before any data.
        assert!(matches!(
            chain.next().await,
            Some(Ok(Message::Barrier(_)))
        ));

        // Chunks must arrive in order 1..=4: snapshot first, then upstream. The loop ends at the
        // first non-chunk message or at the end of the stream.
        let mut count = 0;
        while let Some(Message::Chunk(ck)) = chain.next().await.transpose().unwrap() {
            count += 1;
            assert_eq!(ck, StreamChunk::from_pretty(&format!("I\n + {count}")));
        }
        assert_eq!(count, 4);
    }
}