risingwave_stream/executor/
chain.rs1use crate::executor::prelude::*;
16use crate::task::CreateMviewProgressReporter;
17
18pub struct ChainExecutor {
23 snapshot: Executor,
24
25 upstream: Executor,
26
27 progress: CreateMviewProgressReporter,
28
29 actor_id: ActorId,
30
31 upstream_only: bool,
33}
34
35impl ChainExecutor {
36 pub fn new(
37 snapshot: Executor,
38 upstream: Executor,
39 progress: CreateMviewProgressReporter,
40 upstream_only: bool,
41 ) -> Self {
42 Self {
43 snapshot,
44 upstream,
45 actor_id: progress.actor_id(),
46 progress,
47 upstream_only,
48 }
49 }
50
51 #[try_stream(ok = Message, error = StreamExecutorError)]
52 async fn execute_inner(mut self) {
53 let mut upstream = self.upstream.execute();
54
55 let barrier = expect_first_barrier(&mut upstream).await?;
57 let prev_epoch = barrier.epoch.prev;
58
59 let to_consume_snapshot = barrier.is_newly_added(self.actor_id) && !self.upstream_only;
63
64 if barrier.is_newly_added(self.actor_id) && self.upstream_only {
67 self.progress.finish(barrier.epoch, 0);
68 }
69
70 yield Message::Barrier(barrier);
72
73 if to_consume_snapshot {
76 let snapshot = self.snapshot.execute_with_epoch(prev_epoch);
78
79 #[for_await]
80 for msg in snapshot {
81 yield msg?;
82 }
83 }
84
85 #[for_await]
88 for msg in upstream {
89 let msg = msg?;
90 if to_consume_snapshot && let Message::Barrier(barrier) = &msg {
91 self.progress.finish(barrier.epoch, 0);
92 }
93 yield msg;
94 }
95 }
96}
97
98impl Execute for ChainExecutor {
99 fn execute(self: Box<Self>) -> super::BoxedMessageStream {
100 self.execute_inner().boxed()
101 }
102}
103
#[cfg(test)]
mod test {

    use futures::StreamExt;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::stream_plan::Dispatcher;

    use super::ChainExecutor;
    use crate::executor::test_utils::MockSource;
    use crate::executor::{AddMutation, Barrier, Execute, Message, Mutation, PkIndices};
    use crate::task::CreateMviewProgressReporter;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;

    /// A chain whose first upstream barrier newly adds its actor (with `upstream_only = false`)
    /// must emit messages in this order:
    /// 1. that barrier;
    /// 2. the snapshot chunks (1, 2);
    /// 3. the upstream chunks (3, 4).
    #[tokio::test]
    async fn test_basic() {
        let test_env = LocalBarrierTestEnv::for_test().await;
        let barrier_manager = test_env.local_barrier_manager.clone();
        let progress = CreateMviewProgressReporter::for_test(barrier_manager);
        let actor_id = progress.actor_id();

        let schema = Schema::new(vec![Field::unnamed(DataType::Int64)]);
        // Snapshot side: two chunks.
        let first = MockSource::with_chunks(vec![
            StreamChunk::from_pretty("I\n + 1"),
            StreamChunk::from_pretty("I\n + 2"),
        ])
        .stop_on_finish(false)
        .into_executor(schema.clone(), PkIndices::new());

        // Upstream side: a barrier that adds this actor, followed by two chunks.
        let second = MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1)).with_mutation(
                Mutation::Add(AddMutation {
                    adds: maplit::hashmap! {
                        0 => vec![Dispatcher {
                            downstream_actor_id: vec![actor_id],
                            ..Default::default()
                        }],
                    },
                    added_actors: maplit::hashset! { actor_id },
                    splits: Default::default(),
                    pause: false,
                    subscriptions_to_add: vec![],
                    backfill_nodes_to_pause: Default::default(),
                }),
            )),
            Message::Chunk(StreamChunk::from_pretty("I\n + 3")),
            Message::Chunk(StreamChunk::from_pretty("I\n + 4")),
        ])
        .into_executor(schema.clone(), PkIndices::new());

        let chain = ChainExecutor::new(first, second, progress, false);

        let mut chain = chain.boxed().execute();

        // The first upstream barrier must be forwarded before any snapshot data.
        let first_msg = chain.next().await.unwrap().unwrap();
        assert!(
            matches!(first_msg, Message::Barrier(_)),
            "first message of the chain must be the upstream barrier"
        );

        // Snapshot chunks (1, 2) must precede upstream chunks (3, 4).
        let mut count = 0;
        while let Some(Message::Chunk(ck)) = chain.next().await.transpose().unwrap() {
            count += 1;
            assert_eq!(ck, StreamChunk::from_pretty(&format!("I\n + {count}")));
        }
        assert_eq!(count, 4);
    }
}