risingwave_batch_executors/executor/
union.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use futures::StreamExt;
16use futures_async_stream::try_stream;
17use itertools::Itertools;
18use risingwave_common::array::DataChunk;
19use risingwave_common::catalog::Schema;
20use risingwave_pb::batch_plan::plan_node::NodeBody;
21use rw_futures_util::select_all;
22
23use crate::error::{BatchError, Result};
24use crate::executor::{
25    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
26};
27
28pub struct UnionExecutor {
29    inputs: Vec<BoxedExecutor>,
30    identity: String,
31}
32
33impl Executor for UnionExecutor {
34    fn schema(&self) -> &Schema {
35        self.inputs[0].schema()
36    }
37
38    fn identity(&self) -> &str {
39        &self.identity
40    }
41
42    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
43        self.do_execute()
44    }
45}
46
47impl UnionExecutor {
48    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
49    async fn do_execute(self: Box<Self>) {
50        let mut stream = select_all(
51            self.inputs
52                .into_iter()
53                .map(|input| input.execute())
54                .collect_vec(),
55        )
56        .boxed();
57
58        while let Some(data_chunk) = stream.next().await {
59            let data_chunk = data_chunk?;
60            yield data_chunk
61        }
62    }
63}
64
65impl BoxedExecutorBuilder for UnionExecutor {
66    async fn new_boxed_executor(
67        source: &ExecutorBuilder<'_>,
68        inputs: Vec<BoxedExecutor>,
69    ) -> Result<BoxedExecutor> {
70        let _union_node =
71            try_match_expand!(source.plan_node().get_node_body().unwrap(), NodeBody::Union)?;
72
73        Ok(Box::new(Self::new(
74            inputs,
75            source.plan_node().get_identity().clone(),
76        )))
77    }
78}
79
80impl UnionExecutor {
81    pub fn new(inputs: Vec<BoxedExecutor>, identity: String) -> Self {
82        Self { inputs, identity }
83    }
84}
85
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::stream::StreamExt;
    use risingwave_common::array::{Array, DataChunk};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::DataType;

    use crate::executor::test_utils::MockExecutor;
    use crate::executor::{Executor, UnionExecutor};

    /// Asserts that column `col_idx` of `chunk` is an `Int32` column whose values are exactly
    /// `expected`, in order and with no nulls.
    fn assert_int32_column(chunk: &DataChunk, col_idx: usize, expected: &[i32]) {
        let column = chunk.column_at(col_idx).as_int32();
        assert_eq!(column.len(), expected.len());
        for (row_idx, &value) in expected.iter().enumerate() {
            assert_eq!(column.value_at(row_idx), Some(value));
        }
    }

    #[tokio::test]
    async fn test_union_executor() {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
            ],
        };

        let mut first_input = MockExecutor::new(schema.clone());
        first_input.add(DataChunk::from_pretty(
            "i i
             1 10
             2 20
             3 30
             4 40",
        ));

        let mut second_input = MockExecutor::new(schema);
        second_input.add(DataChunk::from_pretty(
            "i i
             5 50
             6 60
             7 70
             8 80",
        ));

        let union_executor = Box::new(UnionExecutor {
            inputs: vec![Box::new(first_input), Box::new(second_input)],
            identity: "UnionExecutor".to_owned(),
        });
        // The union reports the schema of its first input.
        for field in &union_executor.schema().fields[..2] {
            assert_eq!(field.data_type, DataType::Int32);
        }

        let mut stream = union_executor.execute();

        let first = stream.next().await.unwrap();
        assert_matches!(first, Ok(_));
        let first = first.unwrap();
        assert_int32_column(&first, 0, &[1, 2, 3, 4]);
        assert_int32_column(&first, 1, &[10, 20, 30, 40]);

        let second = stream.next().await.unwrap();
        assert_matches!(second, Ok(_));
        let second = second.unwrap();
        assert_int32_column(&second, 0, &[5, 6, 7, 8]);
        assert_int32_column(&second, 1, &[50, 60, 70, 80]);

        assert_matches!(stream.next().await, None);
    }
}