risingwave_stream/executor/
chain.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::executor::prelude::*;
16use crate::task::CreateMviewProgressReporter;
17
18/// [`ChainExecutor`] is an executor that enables synchronization between the existing stream and
19/// newly appended executors. Currently, [`ChainExecutor`] is mainly used to implement MV on MV
20/// feature. It pipes new data of existing MVs to newly created MV only all of the old data in the
21/// existing MVs are dispatched.
22pub struct ChainExecutor {
23    snapshot: Executor,
24
25    upstream: Executor,
26
27    progress: CreateMviewProgressReporter,
28
29    actor_id: ActorId,
30
31    /// Only consume upstream messages.
32    upstream_only: bool,
33}
34
35impl ChainExecutor {
36    pub fn new(
37        snapshot: Executor,
38        upstream: Executor,
39        progress: CreateMviewProgressReporter,
40        upstream_only: bool,
41    ) -> Self {
42        Self {
43            snapshot,
44            upstream,
45            actor_id: progress.actor_id(),
46            progress,
47            upstream_only,
48        }
49    }
50
51    #[try_stream(ok = Message, error = StreamExecutorError)]
52    async fn execute_inner(mut self) {
53        let mut upstream = self.upstream.execute();
54
55        // 1. Poll the upstream to get the first barrier.
56        let barrier = expect_first_barrier(&mut upstream).await?;
57        let prev_epoch = barrier.epoch.prev;
58
59        // If the barrier is a conf change of creating this mview, init snapshot from its epoch
60        // and begin to consume the snapshot.
61        // Otherwise, it means we've recovered and the snapshot is already consumed.
62        let to_consume_snapshot = barrier.is_newly_added(self.actor_id) && !self.upstream_only;
63
64        // If the barrier is a conf change of creating this mview, and the snapshot is not to be
65        // consumed, we can finish the progress immediately.
66        if barrier.is_newly_added(self.actor_id) && self.upstream_only {
67            self.progress.finish(barrier.epoch, 0);
68        }
69
70        // The first barrier message should be propagated.
71        yield Message::Barrier(barrier);
72
73        // 2. Consume the snapshot if needed. Note that the snapshot is already projected, so
74        // there's no mapping required.
75        if to_consume_snapshot {
76            // Init the snapshot with reading epoch.
77            let snapshot = self.snapshot.execute_with_epoch(prev_epoch);
78
79            #[for_await]
80            for msg in snapshot {
81                yield msg?;
82            }
83        }
84
85        // 3. Continuously consume the upstream. Report that we've finished the creation on the
86        // first barrier.
87        #[for_await]
88        for msg in upstream {
89            let msg = msg?;
90            if to_consume_snapshot && let Message::Barrier(barrier) = &msg {
91                self.progress.finish(barrier.epoch, 0);
92            }
93            yield msg;
94        }
95    }
96}
97
98impl Execute for ChainExecutor {
99    fn execute(self: Box<Self>) -> super::BoxedMessageStream {
100        self.execute_inner().boxed()
101    }
102}
103
#[cfg(test)]
mod test {

    use futures::StreamExt;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::stream_plan::Dispatcher;

    use super::ChainExecutor;
    use crate::executor::test_utils::MockSource;
    use crate::executor::{AddMutation, Barrier, Execute, Message, Mutation, PkIndices};
    use crate::task::CreateMviewProgressReporter;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;

    /// Chains a two-chunk snapshot in front of an upstream that starts with a barrier adding
    /// this actor, and checks that chunks 1 through 4 come out in order.
    #[tokio::test]
    async fn test_basic() {
        let env = LocalBarrierTestEnv::for_test().await;
        let progress = CreateMviewProgressReporter::for_test(env.local_barrier_manager.clone());
        let actor_id = progress.actor_id();

        let schema = Schema::new(vec![Field::unnamed(DataType::Int64)]);

        // Snapshot side: chunks 1 and 2.
        let snapshot_chunks = vec![
            StreamChunk::from_pretty("I\n + 1"),
            StreamChunk::from_pretty("I\n + 2"),
        ];
        let snapshot = MockSource::with_chunks(snapshot_chunks)
            .stop_on_finish(false)
            .into_executor(schema.clone(), PkIndices::new());

        // Upstream side: a barrier whose mutation adds this actor, followed by chunks 3 and 4.
        let add_actor = Mutation::Add(AddMutation {
            adds: maplit::hashmap! {
                0 => vec![Dispatcher {
                    downstream_actor_id: vec![actor_id],
                    ..Default::default()
                }],
            },
            added_actors: maplit::hashset! { actor_id },
            splits: Default::default(),
            pause: false,
            subscriptions_to_add: vec![],
            backfill_nodes_to_pause: Default::default(),
        });
        let upstream_messages = vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1)).with_mutation(add_actor)),
            Message::Chunk(StreamChunk::from_pretty("I\n + 3")),
            Message::Chunk(StreamChunk::from_pretty("I\n + 4")),
        ];
        let upstream = MockSource::with_messages(upstream_messages)
            .into_executor(schema.clone(), PkIndices::new());

        let mut stream = ChainExecutor::new(snapshot, upstream, progress, false)
            .boxed()
            .execute();

        // Skip the first message of the chain.
        stream.next().await;

        // Count chunks until the stream ends or yields something that is not a chunk.
        let mut seen = 0;
        while let Some(Message::Chunk(chunk)) = stream.next().await.transpose().unwrap() {
            seen += 1;
            assert_eq!(chunk, StreamChunk::from_pretty(&format!("I\n + {seen}")));
        }
        assert_eq!(seen, 4);
    }
}