risingwave_frontend/optimizer/plan_node/logical_union.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::max;
use std::collections::BTreeMap;

use itertools::Itertools;
use risingwave_common::catalog::Schema;
use risingwave_common::types::{DataType, Scalar};

use super::utils::impl_distill_by_unit;
use super::{
    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PredicatePushdown, ToBatch, ToStream,
};
use crate::Explain;
use crate::error::Result;
use crate::expr::{ExprImpl, InputRef, Literal};
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::stream_union::StreamUnion;
use crate::optimizer::plan_node::{
    BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
    PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
};
use crate::optimizer::property::RequiredDist;
use crate::utils::{ColIndexMapping, Condition};

/// `LogicalUnion` returns the union of the rows of its inputs.
/// If `all` is false, it needs to eliminate duplicates.
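/// For example, `SELECT v FROM a UNION SELECT v FROM b` keeps only distinct rows
/// (`all = false`), while `UNION ALL` keeps duplicates (`all = true`).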
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalUnion {
    pub base: PlanBase<Logical>,
    core: generic::Union<PlanRef>,
}

impl LogicalUnion {
    pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
        assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
        Self::new_with_source_col(all, inputs, None)
    }

    /// Used by streaming processing. `source_col` identifies which source input each
    /// record came from.
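    /// When present, `source_col` is the index of that identifying column in the output
    /// schema (e.g. `logical_rewrite_for_stream` appends it as the last column).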
    pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
        let core = generic::Union {
            all,
            inputs,
            source_col,
        };
        let base = PlanBase::new_logical_with_core(&core);
        LogicalUnion { base, core }
    }

    pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
        LogicalUnion::new(all, inputs).into()
    }

    pub fn all(&self) -> bool {
        self.core.all
    }

    pub fn source_col(&self) -> Option<usize> {
        self.core.source_col
    }
}

impl PlanTreeNode for LogicalUnion {
    fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
        self.core.inputs.clone().into_iter().collect()
    }

    fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
        Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
    }
}

impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");

impl ColPrunable for LogicalUnion {
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
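        // All inputs share the union's schema, so pruning each input with the same
        // `required_cols` keeps their schemas aligned for the rebuilt union.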
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.prune_col(required_cols, ctx))
            .collect_vec();
        self.clone_with_inputs(&new_inputs)
    }
}

impl ExprRewritable for LogicalUnion {}

impl ExprVisitable for LogicalUnion {}

impl PredicatePushdown for LogicalUnion {
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
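        // Filtering a union is equivalent to unioning the filtered inputs, and every input
        // shares the union's schema, so the predicate can be pushed to each input unchanged.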
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
            .collect_vec();
        self.clone_with_inputs(&new_inputs)
    }
}

impl ToBatch for LogicalUnion {
    fn to_batch(&self) -> Result<PlanRef> {
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.to_batch())
            .try_collect()?;
        let new_logical = generic::Union {
            all: true,
            inputs: new_inputs,
            source_col: None,
        };
        // We still need to handle `!all` even though `UnionToDistinctRule` exists, because a
        // non-`all` union can be generated by index selection, an optimization that runs
        // during `to_batch`. Convert the union to union all + agg.
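        // e.g. `A UNION B` becomes a `BatchHashAgg` grouping on all columns on top of
        // `A UNION ALL B`, i.e. `SELECT DISTINCT * FROM (A UNION ALL B)`.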
        if !self.all() {
            let batch_union = BatchUnion::new(new_logical).into();
            Ok(BatchHashAgg::new(
                generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
                    .with_enable_two_phase(false),
            )
            .into())
        } else {
            Ok(BatchUnion::new(new_logical).into())
        }
    }
}

impl ToStream for LogicalUnion {
    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
        // TODO: use round-robin distribution instead of hash distribution for all inputs.
        let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
            panic!(
                "stream plan should always have a stream key, but it is missing; sub plan: {}",
                PlanRef::from(self.clone()).explain_to_string()
            )
        }));
        let new_inputs: Result<Vec<_>> = self
            .inputs()
            .iter()
            .map(|input| input.to_stream_with_dist_required(&dist, ctx))
            .collect();
        let new_logical = generic::Union {
            all: true,
            inputs: new_inputs?,
            ..self.core
        };
        assert!(
            self.all(),
            "After UnionToDistinctRule, union should become union all"
        );
        Ok(StreamUnion::new(new_logical).into())
    }

    fn logical_rewrite_for_stream(
        &self,
        ctx: &mut RewriteStreamContext,
    ) -> Result<(PlanRef, ColIndexMapping)> {
        let original_schema = self.base.schema().clone();
        let original_schema_len = original_schema.len();
        let mut rewrites = vec![];
        for input in &self.core.inputs {
            rewrites.push(input.logical_rewrite_for_stream(ctx)?);
        }

        let original_schema_contain_all_input_stream_keys =
            rewrites.iter().all(|(new_input, col_index_mapping)| {
                let original_schema_new_pos = (0..original_schema_len)
                    .map(|x| col_index_mapping.map(x))
                    .collect_vec();
                new_input
                    .expect_stream_key()
                    .iter()
                    .all(|x| original_schema_new_pos.contains(x))
            });

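        // If every input's stream key is already visible through the original schema, appending
        // a single column identifying the source input is enough to form a stream key for the
        // union; otherwise, the inputs' stream key columns must be carried along as well.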
        if original_schema_contain_all_input_stream_keys {
            // Append one more column at the end of the original schema to identify which
            // input each record came from: [original_schema + source_col]
            let new_inputs = rewrites
                .into_iter()
                .enumerate()
                .map(|(i, (new_input, col_index_mapping))| {
                    // original_schema
                    let mut exprs = (0..original_schema_len)
                        .map(|x| {
                            ExprImpl::InputRef(
                                InputRef::new(
                                    col_index_mapping.map(x),
                                    original_schema.fields[x].data_type.clone(),
                                )
                                .into(),
                            )
                        })
                        .collect_vec();
                    // source_col
                    exprs.push(ExprImpl::Literal(
                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
                    ));
                    LogicalProject::create(new_input, exprs)
                })
                .collect_vec();
            let new_union = LogicalUnion::new_with_source_col(
                self.all(),
                new_inputs,
                Some(original_schema_len),
            );
            // We have already used a project to map each rewritten input back to the original
            // schema, so we can use an identity mapping with the new schema length.
            let out_col_change =
                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
            Ok((new_union.into(), out_col_change))
        } else {
            // To ensure all inputs share the same schema for the new union, we construct the new
            // schema as [original_schema + merged_stream_key + source_col], where
            // merged_stream_key is merged from the types of each input's stream key.
            // If all inputs have the same stream key column types, merged_stream_key stays small;
            // otherwise it grows larger.
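            // For example, if one input's stream key types are [Int32, Int32] and another's are
            // [Int32, Varchar], merged_stream_key has types [Int32, Int32, Varchar]; each input
            // projects its own stream key columns into the slots of matching type and fills the
            // remaining slots with NULL literals of that type.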
            let (merged_stream_key_types, types_offset) = {
                let mut max_types_counter = BTreeMap::default();
                for (new_input, _) in &rewrites {
                    let mut types_counter = BTreeMap::default();
                    for x in new_input.expect_stream_key() {
                        types_counter
                            .entry(new_input.schema().fields[*x].data_type())
                            .and_modify(|x| *x += 1)
                            .or_insert(1);
                    }
                    for (key, val) in types_counter {
                        max_types_counter
                            .entry(key)
                            .and_modify(|x| *x = max(*x, val))
                            .or_insert(val);
                    }
                }

                let mut merged_stream_key_types = vec![];
                let mut types_offset = BTreeMap::default();
                let mut offset = 0;
                for (key, val) in max_types_counter {
                    let _ = types_offset.insert(key.clone(), offset);
                    offset += val;
                    merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
                }

                (merged_stream_key_types, types_offset)
            };

            let input_stream_key_nulls = merged_stream_key_types
                .iter()
                .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
                .collect_vec();

            let new_inputs = rewrites
                .into_iter()
                .enumerate()
                .map(|(i, (new_input, col_index_mapping))| {
                    // original_schema
                    let mut exprs = (0..original_schema_len)
                        .map(|x| {
                            ExprImpl::InputRef(
                                InputRef::new(
                                    col_index_mapping.map(x),
                                    original_schema.fields[x].data_type.clone(),
                                )
                                .into(),
                            )
                        })
                        .collect_vec();
                    // merged_stream_key
                    let mut input_stream_keys = input_stream_key_nulls.clone();
                    let mut types_counter = BTreeMap::default();
                    for stream_key_idx in new_input.expect_stream_key() {
                        let data_type =
                            new_input.schema().fields[*stream_key_idx].data_type.clone();
                        let count = *types_counter
                            .entry(data_type.clone())
                            .and_modify(|x| *x += 1)
                            .or_insert(1);
                        let type_start_offset = *types_offset.get(&data_type).unwrap();

                        input_stream_keys[type_start_offset + count - 1] =
                            ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
                    }
                    exprs.extend(input_stream_keys);
                    // source_col
                    exprs.push(ExprImpl::Literal(
                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
                    ));
                    LogicalProject::create(new_input, exprs)
                })
                .collect_vec();

            let new_union = LogicalUnion::new_with_source_col(
                self.all(),
                new_inputs,
                Some(original_schema_len + merged_stream_key_types.len()),
            );
            // We have already used a project to map each rewritten input back to the original
            // schema, so we can use an identity mapping with the new schema length.
            let out_col_change =
                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
            Ok((new_union.into(), out_col_change))
        }
    }
}

#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    #[tokio::test]
    async fn test_prune_union() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        let values2 = values1.clone();

        let union: PlanRef = LogicalUnion::new(false, vec![values1.into(), values2.into()]).into();

        // Perform the prune
        let required_cols = vec![1, 2];
        let plan = union.prune_col(
            &required_cols,
            &mut ColumnPruningContext::new(union.clone()),
        );

        // Check the result
        let union = plan.as_logical_union().unwrap();
        assert_eq!(union.base.schema().len(), 2);
    }

    #[tokio::test]
    async fn test_union_to_batch() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        let values2 = values1.clone();

        let union = LogicalUnion::new(false, vec![values1.into(), values2.into()]);

        let plan = union.to_batch().unwrap();
        let agg: &BatchHashAgg = plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let union = agg_input.as_batch_union().unwrap();

        assert_eq!(union.inputs().len(), 2);
    }
}