risingwave_stream/executor/
lookup_union.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use futures::channel::mpsc;
17use futures::future::{Either, join_all, select};
18use futures::{FutureExt, SinkExt};
19use itertools::Itertools;
20
21use crate::executor::prelude::*;
22
23/// Merges data from multiple inputs with order. If `order = [2, 1, 0]`, then
24/// it will first pipe data from the third input; after the third input gets a barrier, it will then
25/// pipe the second, and finally the first. In the future we could have more efficient
26/// implementation.
27pub struct LookupUnionExecutor {
28    inputs: Vec<Executor>,
29    order: Vec<usize>,
30}
31
32impl std::fmt::Debug for LookupUnionExecutor {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("LookupUnionExecutor").finish()
35    }
36}
37
38impl LookupUnionExecutor {
39    pub fn new(inputs: Vec<Executor>, order: Vec<u32>) -> Self {
40        Self {
41            inputs,
42            order: order.iter().map(|x| *x as _).collect(),
43        }
44    }
45}
46
47#[async_trait]
48impl Execute for LookupUnionExecutor {
49    fn execute(self: Box<Self>) -> BoxedMessageStream {
50        self.execute_inner().boxed()
51    }
52}
53
54impl LookupUnionExecutor {
55    #[try_stream(ok = Message, error = StreamExecutorError)]
56    async fn execute_inner(self) {
57        let mut inputs = self.inputs.into_iter().map(Some).collect_vec();
58        let mut futures = vec![];
59        let mut rxs = vec![];
60        for idx in self.order {
61            let mut stream = inputs[idx].take().unwrap().execute();
62            let (mut tx, rx) = mpsc::channel(1024); // set buffer size to control back pressure
63            rxs.push(rx);
64            futures.push(
65                // construct a future that drives input stream until it is exhausted.
66                // the input elements are sent over bounded channel.
67                async move {
68                    while let Some(ret) = stream.next().await {
69                        tx.send(ret).await.unwrap();
70                    }
71                }
72                .boxed(),
73            );
74        }
75        // This future is used to drive all inputs.
76        let mut drive_inputs = join_all(futures).fuse();
77        let mut end = false;
78        while !end {
79            end = true; // no message on this turn?
80            let mut this_barrier: Option<Barrier> = None;
81            for rx in &mut rxs {
82                loop {
83                    let msg = match select(rx.next(), &mut drive_inputs).await {
84                        Either::Left((Some(msg), _)) => msg?,
85                        Either::Left((None, _)) => break, // input end
86                        Either::Right(_) => continue,
87                    };
88                    end = false;
89                    match msg {
90                        Message::Watermark(_) => {
91                            // TODO: https://github.com/risingwavelabs/risingwave/issues/6042
92                        }
93
94                        msg @ Message::Chunk(_) => yield msg,
95                        Message::Barrier(barrier) => {
96                            if let Some(this_barrier) = &this_barrier {
97                                if this_barrier.epoch != barrier.epoch {
98                                    return Err(StreamExecutorError::align_barrier(
99                                        this_barrier.clone(),
100                                        barrier,
101                                    ));
102                                }
103                            } else {
104                                this_barrier = Some(barrier);
105                            }
106                            break; // move to the next input
107                        }
108                    }
109                }
110            }
111            if end {
112                break;
113            } else {
114                yield Message::Barrier(this_barrier.take().unwrap());
115            }
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use futures::TryStreamExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// A single-row chunk built from the pretty format `I\n + {value}`.
    fn chunk(value: i64) -> Message {
        Message::Chunk(StreamChunk::from_pretty(&format!("I\n + {value}")))
    }

    /// A test barrier for `test_epoch(epoch)`.
    fn barrier(epoch: u64) -> Message {
        Message::Barrier(Barrier::new_test_barrier(test_epoch(epoch)))
    }

    /// A mock executor replaying `messages`, configured the same way for every input.
    fn mock_input(schema: Schema, messages: Vec<Message>) -> Executor {
        MockSource::with_messages(messages)
            .stop_on_finish(false)
            .into_executor(schema, vec![0])
    }

    #[tokio::test]
    async fn lookup_union() {
        let schema = Schema {
            fields: vec![Field::unnamed(DataType::Int64)],
        };
        // Input 0 runs one epoch longer than the other two.
        let input0 = mock_input(
            schema.clone(),
            vec![chunk(1), barrier(1), chunk(2), barrier(2), chunk(3), barrier(3)],
        );
        let input1 = mock_input(
            schema.clone(),
            vec![chunk(11), barrier(1), chunk(12), barrier(2)],
        );
        let input2 = mock_input(schema, vec![chunk(21), barrier(1), chunk(22), barrier(2)]);

        // Within every epoch, drain input 2 first, then input 1, then input 0.
        let outputs: Vec<_> =
            LookupUnionExecutor::new(vec![input0, input1, input2], vec![2, 1, 0])
                .boxed()
                .execute()
                .try_collect()
                .await
                .unwrap();

        let expected = vec![
            chunk(21),
            chunk(11),
            chunk(1),
            barrier(1),
            chunk(22),
            chunk(12),
            chunk(2),
            barrier(2),
            chunk(3),
            barrier(3),
        ];
        assert_eq!(outputs, expected);
    }
}