risingwave_frontend/optimizer/plan_node/logical_union.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::max;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19use risingwave_common::types::{DataType, Scalar};
20
21use super::utils::impl_distill_by_unit;
22use super::{
23    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PredicatePushdown, ToBatch, ToStream,
24};
25use crate::Explain;
26use crate::error::Result;
27use crate::expr::{ExprImpl, InputRef, Literal};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::generic::GenericPlanRef;
30use crate::optimizer::plan_node::stream_union::StreamUnion;
31use crate::optimizer::plan_node::{
32    BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
33    PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
34};
35use crate::optimizer::property::RequiredDist;
36use crate::utils::{ColIndexMapping, Condition};
37
38/// `LogicalUnion` returns the union of the rows of its inputs.
39/// If `all` is false, it needs to eliminate duplicates.
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct LogicalUnion {
42    pub base: PlanBase<Logical>,
43    core: generic::Union<PlanRef>,
44}
45
46impl LogicalUnion {
47    pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
48        assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
49        Self::new_with_source_col(all, inputs, None)
50    }
51
52    /// It is used by streaming processing. We need to use `source_col` to identify the record came
53    /// from which source input.
54    pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
55        let core = generic::Union {
56            all,
57            inputs,
58            source_col,
59        };
60        let base = PlanBase::new_logical_with_core(&core);
61        LogicalUnion { base, core }
62    }
63
64    pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
65        LogicalUnion::new(all, inputs).into()
66    }
67
68    pub fn all(&self) -> bool {
69        self.core.all
70    }
71
72    pub fn source_col(&self) -> Option<usize> {
73        self.core.source_col
74    }
75}
76
77impl PlanTreeNode for LogicalUnion {
78    fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
79        self.core.inputs.clone().into_iter().collect()
80    }
81
82    fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
83        Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
84    }
85}
86
87impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");
88
89impl ColPrunable for LogicalUnion {
90    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
91        let new_inputs = self
92            .inputs()
93            .iter()
94            .map(|input| input.prune_col(required_cols, ctx))
95            .collect_vec();
96        self.clone_with_inputs(&new_inputs)
97    }
98}
99
100impl ExprRewritable for LogicalUnion {}
101
102impl ExprVisitable for LogicalUnion {}
103
104impl PredicatePushdown for LogicalUnion {
105    fn predicate_pushdown(
106        &self,
107        predicate: Condition,
108        ctx: &mut PredicatePushdownContext,
109    ) -> PlanRef {
110        let new_inputs = self
111            .inputs()
112            .iter()
113            .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
114            .collect_vec();
115        self.clone_with_inputs(&new_inputs)
116    }
117}
118
119impl ToBatch for LogicalUnion {
120    fn to_batch(&self) -> Result<PlanRef> {
121        let new_inputs = self
122            .inputs()
123            .iter()
124            .map(|input| input.to_batch())
125            .try_collect()?;
126        let new_logical = generic::Union {
127            all: true,
128            inputs: new_inputs,
129            source_col: None,
130        };
131        // We still need to handle !all even if we already have `UnionToDistinctRule`, because it
132        // can be generated by index selection which is an optimization during the `to_batch`.
133        // Convert union to union all + agg
134        if !self.all() {
135            let batch_union = BatchUnion::new(new_logical).into();
136            Ok(BatchHashAgg::new(
137                generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
138                    .with_enable_two_phase(false),
139            )
140            .into())
141        } else {
142            Ok(BatchUnion::new(new_logical).into())
143        }
144    }
145}
146
147impl ToStream for LogicalUnion {
148    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
149        // TODO: use round robin distribution instead of using hash distribution of all inputs.
150        let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
151            panic!(
152                "should always have a stream key in the stream plan but not, sub plan: {}",
153                PlanRef::from(self.clone()).explain_to_string()
154            )
155        }));
156        let new_inputs: Result<Vec<_>> = self
157            .inputs()
158            .iter()
159            .map(|input| input.to_stream_with_dist_required(&dist, ctx))
160            .collect();
161        let new_logical = generic::Union {
162            all: true,
163            inputs: new_inputs?,
164            ..self.core
165        };
166        assert!(
167            self.all(),
168            "After UnionToDistinctRule, union should become union all"
169        );
170        Ok(StreamUnion::new(new_logical).into())
171    }
172
173    fn logical_rewrite_for_stream(
174        &self,
175        ctx: &mut RewriteStreamContext,
176    ) -> Result<(PlanRef, ColIndexMapping)> {
177        type FixedState = std::hash::BuildHasherDefault<std::hash::DefaultHasher>;
178        type TypeMap<T> = std::collections::HashMap<DataType, T, FixedState>;
179
180        let original_schema = self.base.schema().clone();
181        let original_schema_len = original_schema.len();
182        let mut rewrites = vec![];
183        for input in &self.core.inputs {
184            rewrites.push(input.logical_rewrite_for_stream(ctx)?);
185        }
186
187        let original_schema_contain_all_input_stream_keys =
188            rewrites.iter().all(|(new_input, col_index_mapping)| {
189                let original_schema_new_pos = (0..original_schema_len)
190                    .map(|x| col_index_mapping.map(x))
191                    .collect_vec();
192                new_input
193                    .expect_stream_key()
194                    .iter()
195                    .all(|x| original_schema_new_pos.contains(x))
196            });
197
198        if original_schema_contain_all_input_stream_keys {
199            // Add one more column at the end of the original schema to identify the record came
200            // from which input. [original_schema + source_col]
201            let new_inputs = rewrites
202                .into_iter()
203                .enumerate()
204                .map(|(i, (new_input, col_index_mapping))| {
205                    // original_schema
206                    let mut exprs = (0..original_schema_len)
207                        .map(|x| {
208                            ExprImpl::InputRef(
209                                InputRef::new(
210                                    col_index_mapping.map(x),
211                                    original_schema.fields[x].data_type.clone(),
212                                )
213                                .into(),
214                            )
215                        })
216                        .collect_vec();
217                    // source_col
218                    exprs.push(ExprImpl::Literal(
219                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
220                    ));
221                    LogicalProject::create(new_input, exprs)
222                })
223                .collect_vec();
224            let new_union = LogicalUnion::new_with_source_col(
225                self.all(),
226                new_inputs,
227                Some(original_schema_len),
228            );
229            // We have already used project to map rewrite input to the origin schema, so we can use
230            // identity with the new schema len.
231            let out_col_change =
232                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
233            Ok((new_union.into(), out_col_change))
234        } else {
235            // In order to ensure all inputs have the same schema for new union, we construct new
236            // schema like that: [original_schema + merged_stream_key + source_col]
237            // where merged_stream_key is merged by the types of each input stream key.
238            // If all inputs have the same stream key column types, we have a small merged_stream_key. Otherwise, we will have a large merged_stream_key.
239
240            let (merged_stream_key_types, types_offset) = {
241                let mut max_types_counter = TypeMap::default();
242                for (new_input, _) in &rewrites {
243                    let mut types_counter = TypeMap::default();
244                    for x in new_input.expect_stream_key() {
245                        types_counter
246                            .entry(new_input.schema().fields[*x].data_type())
247                            .and_modify(|x| *x += 1)
248                            .or_insert(1);
249                    }
250                    for (key, val) in types_counter {
251                        max_types_counter
252                            .entry(key)
253                            .and_modify(|x| *x = max(*x, val))
254                            .or_insert(val);
255                    }
256                }
257
258                let mut merged_stream_key_types = vec![];
259                let mut types_offset = TypeMap::default();
260                let mut offset = 0;
261                for (key, val) in max_types_counter {
262                    let _ = types_offset.insert(key.clone(), offset);
263                    offset += val;
264                    merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
265                }
266
267                (merged_stream_key_types, types_offset)
268            };
269
270            let input_stream_key_nulls = merged_stream_key_types
271                .iter()
272                .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
273                .collect_vec();
274
275            let new_inputs = rewrites
276                .into_iter()
277                .enumerate()
278                .map(|(i, (new_input, col_index_mapping))| {
279                    // original_schema
280                    let mut exprs = (0..original_schema_len)
281                        .map(|x| {
282                            ExprImpl::InputRef(
283                                InputRef::new(
284                                    col_index_mapping.map(x),
285                                    original_schema.fields[x].data_type.clone(),
286                                )
287                                .into(),
288                            )
289                        })
290                        .collect_vec();
291                    // merged_stream_key
292                    let mut input_stream_keys = input_stream_key_nulls.clone();
293                    let mut types_counter = TypeMap::default();
294                    for stream_key_idx in new_input.expect_stream_key() {
295                        let data_type =
296                            new_input.schema().fields[*stream_key_idx].data_type.clone();
297                        let count = *types_counter
298                            .entry(data_type.clone())
299                            .and_modify(|x| *x += 1)
300                            .or_insert(1);
301                        let type_start_offset = *types_offset.get(&data_type).unwrap();
302
303                        input_stream_keys[type_start_offset + count - 1] =
304                            ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
305                    }
306                    exprs.extend(input_stream_keys);
307                    // source_col
308                    exprs.push(ExprImpl::Literal(
309                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
310                    ));
311                    LogicalProject::create(new_input, exprs)
312                })
313                .collect_vec();
314
315            let new_union = LogicalUnion::new_with_source_col(
316                self.all(),
317                new_inputs,
318                Some(original_schema_len + merged_stream_key_types.len()),
319            );
320            // We have already used project to map rewrite input to the origin schema, so we can use
321            // identity with the new schema len.
322            let out_col_change =
323                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
324            Ok((new_union.into(), out_col_change))
325        }
326    }
327}
328
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    /// Builds a row-less `LogicalValues` with three `Int32` columns named `v1`, `v2`, `v3`.
    async fn three_int_column_values() -> LogicalValues {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        LogicalValues::new(vec![], Schema { fields }, ctx)
    }

    /// Pruning a union to columns `[1, 2]` yields a two-column union.
    #[tokio::test]
    async fn test_prune_union() {
        let values = three_int_column_values().await;
        let union: PlanRef =
            LogicalUnion::new(false, vec![values.clone().into(), values.into()]).into();

        let required_cols = [1, 2];
        let mut prune_ctx = ColumnPruningContext::new(union.clone());
        let pruned = union.prune_col(&required_cols, &mut prune_ctx);

        let pruned_union = pruned.as_logical_union().unwrap();
        assert_eq!(pruned_union.base.schema().len(), 2);
    }

    /// A distinct union becomes a `BatchHashAgg` on top of a two-input `BatchUnion`.
    #[tokio::test]
    async fn test_union_to_batch() {
        let values = three_int_column_values().await;
        let union = LogicalUnion::new(false, vec![values.clone().into(), values.into()]);

        let plan = union.to_batch().unwrap();
        let agg: &BatchHashAgg = plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let batch_union = agg_input.as_batch_union().unwrap();

        assert_eq!(batch_union.inputs().len(), 2);
    }
}