risingwave_frontend/optimizer/plan_node/logical_union.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::max;

use itertools::Itertools;
use risingwave_common::catalog::Schema;
use risingwave_common::types::{DataType, Scalar};

use super::utils::impl_distill_by_unit;
use super::{
    ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase, PredicatePushdown,
    ToBatch, ToStream,
};
use crate::Explain;
use crate::error::Result;
use crate::expr::{ExprImpl, InputRef, Literal};
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::stream_union::StreamUnion;
use crate::optimizer::plan_node::{
    BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
    PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
};
use crate::optimizer::property::RequiredDist;
use crate::utils::{ColIndexMapping, Condition};
/// `LogicalUnion` returns the union of the rows of its inputs.
/// If `all` is false, it needs to eliminate duplicates.
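///
/// For example, `SELECT a FROM t1 UNION ALL SELECT a FROM t2` corresponds to
/// `all = true`, while plain `UNION` (which deduplicates) corresponds to
/// `all = false`.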
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalUnion {
    pub base: PlanBase<Logical>,
    core: generic::Union<PlanRef>,
}

impl LogicalUnion {
    pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
        assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
        Self::new_with_source_col(all, inputs, None)
    }

    /// Used by stream processing: `source_col` identifies which input a record
    /// came from.
    pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
        let core = generic::Union {
            all,
            inputs,
            source_col,
        };
        let base = PlanBase::new_logical_with_core(&core);
        LogicalUnion { base, core }
    }

    pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
        LogicalUnion::new(all, inputs).into()
    }

    pub fn all(&self) -> bool {
        self.core.all
    }

    pub fn source_col(&self) -> Option<usize> {
        self.core.source_col
    }
}

impl PlanTreeNode<Logical> for LogicalUnion {
    fn inputs(&self) -> smallvec::SmallVec<[PlanRef; 2]> {
        self.core.inputs.clone().into_iter().collect()
    }

    fn clone_with_inputs(&self, inputs: &[PlanRef]) -> PlanRef {
        Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
    }
}

impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");

impl ColPrunable for LogicalUnion {
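    // A union requires every input to share one schema, so pruning the same
    // columns from each input keeps the inputs aligned.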
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.prune_col(required_cols, ctx))
            .collect_vec();
        self.clone_with_inputs(&new_inputs)
    }
}

impl ExprRewritable<Logical> for LogicalUnion {}

impl ExprVisitable for LogicalUnion {}

impl PredicatePushdown for LogicalUnion {
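    // Union output columns line up one-to-one with each input's columns, so the
    // predicate can be pushed to every input unchanged.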
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
            .collect_vec();
        self.clone_with_inputs(&new_inputs)
    }
}

impl ToBatch for LogicalUnion {
    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
        let new_inputs = self
            .inputs()
            .iter()
            .map(|input| input.to_batch())
            .try_collect()?;
        let new_logical = generic::Union {
            all: true,
            inputs: new_inputs,
            source_col: None,
        };
        // We still need to handle the `!all` case even though `UnionToDistinctRule`
        // exists, because a distinct union can be re-introduced by index selection,
        // an optimization applied during `to_batch`.
        // Convert the union to union all + agg.
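        // Informally, a distinct union lowers to:
        //   BatchHashAgg (group key = all output columns, no aggregate calls)
        //     -> BatchUnion (all = true)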
        if !self.all() {
            let batch_union = BatchUnion::new(new_logical).into();
            Ok(BatchHashAgg::new(
                generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
                    .with_enable_two_phase(false),
            )
            .into())
        } else {
            Ok(BatchUnion::new(new_logical).into())
        }
    }
}

impl ToStream for LogicalUnion {
    fn to_stream(
        &self,
        ctx: &mut ToStreamContext,
    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
        // TODO: use round-robin distribution instead of hash distribution over all inputs.
        let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
            panic!(
                "stream plan should always have a stream key, but none was found; sub plan: {}",
                PlanRef::from(self.clone()).explain_to_string()
            )
        }));
        let new_inputs: Result<Vec<_>> = self
            .inputs()
            .iter()
            .map(|input| input.to_stream_with_dist_required(&dist, ctx))
            .collect();
        let core = self.core.clone_with_inputs(new_inputs?);
        assert!(
            self.all(),
            "After UnionToDistinctRule, union should become union all"
        );
        Ok(StreamUnion::new(core).into())
    }

    fn logical_rewrite_for_stream(
        &self,
        ctx: &mut RewriteStreamContext,
    ) -> Result<(PlanRef, ColIndexMapping)> {
        type FixedState = std::hash::BuildHasherDefault<std::hash::DefaultHasher>;
        type TypeMap<T> = std::collections::HashMap<DataType, T, FixedState>;

        let original_schema = self.base.schema().clone();
        let original_schema_len = original_schema.len();
        let mut rewrites = vec![];
        for input in &self.core.inputs {
            rewrites.push(input.logical_rewrite_for_stream(ctx)?);
        }

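        // Two rewrite strategies follow: if every input's stream key is already
        // covered by the original schema (after index remapping), appending a
        // source column alone keeps union rows unique; otherwise each input's
        // stream key must also be carried in a merged, NULL-padded key.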
        let original_schema_contain_all_input_stream_keys =
            rewrites.iter().all(|(new_input, col_index_mapping)| {
                let original_schema_new_pos = (0..original_schema_len)
                    .map(|x| col_index_mapping.map(x))
                    .collect_vec();
                new_input
                    .expect_stream_key()
                    .iter()
                    .all(|x| original_schema_new_pos.contains(x))
            });

        if original_schema_contain_all_input_stream_keys {
            // Append one extra column to the original schema to identify which
            // input each record came from: [original_schema + source_col].
            let new_inputs = rewrites
                .into_iter()
                .enumerate()
                .map(|(i, (new_input, col_index_mapping))| {
                    // original_schema
                    let mut exprs = (0..original_schema_len)
                        .map(|x| {
                            ExprImpl::InputRef(
                                InputRef::new(
                                    col_index_mapping.map(x),
                                    original_schema.fields[x].data_type.clone(),
                                )
                                .into(),
                            )
                        })
                        .collect_vec();
                    // source_col
                    exprs.push(ExprImpl::Literal(
                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
                    ));
                    LogicalProject::create(new_input, exprs)
                })
                .collect_vec();
            let new_union = LogicalUnion::new_with_source_col(
                self.all(),
                new_inputs,
                Some(original_schema_len),
            );
            // We have already used a project to map each rewritten input back to the
            // original schema, so we can use an identity mapping with the new schema length.
            let out_col_change =
                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
            Ok((new_union.into(), out_col_change))
        } else {
            // To ensure all inputs of the new union share one schema, we construct it as
            // [original_schema + merged_stream_key + source_col], where merged_stream_key
            // is merged from the stream key column types of every input. If all inputs
            // share the same stream key column types, the merged_stream_key stays small;
            // otherwise it must widen to cover them all.
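            // For example (hypothetical inputs): if input 0 has stream key types
            // [Int32, Int32] and input 1 has [Int32, Varchar], the merged stream key
            // types are [Int32, Int32, Varchar]: two Int32 slots (the maximum count
            // over all inputs) plus one Varchar slot. Each input later projects its
            // key columns into matching slots and fills the rest with NULL literals.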

            let (merged_stream_key_types, types_offset) = {
                // Count, per input, how many stream key columns it has of each
                // type, then keep the per-type maximum across all inputs.
                let mut max_types_counter = TypeMap::default();
                for (new_input, _) in &rewrites {
                    let mut types_counter = TypeMap::default();
                    for x in new_input.expect_stream_key() {
                        types_counter
                            .entry(new_input.schema().fields[*x].data_type())
                            .and_modify(|x| *x += 1)
                            .or_insert(1);
                    }
                    for (key, val) in types_counter {
                        max_types_counter
                            .entry(key)
                            .and_modify(|x| *x = max(*x, val))
                            .or_insert(val);
                    }
                }

                // Lay the merged key out as one contiguous slot range per type,
                // recording each type's starting offset.
                let mut merged_stream_key_types = vec![];
                let mut types_offset = TypeMap::default();
                let mut offset = 0;
                for (key, val) in max_types_counter {
                    let _ = types_offset.insert(key.clone(), offset);
                    offset += val;
                    merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
                }

                (merged_stream_key_types, types_offset)
            };

            // Start from all-NULL placeholders, one per merged key slot; each
            // input patches in its own stream key columns below.
            let input_stream_key_nulls = merged_stream_key_types
                .iter()
                .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
                .collect_vec();

            let new_inputs = rewrites
                .into_iter()
                .enumerate()
                .map(|(i, (new_input, col_index_mapping))| {
                    // original_schema
                    let mut exprs = (0..original_schema_len)
                        .map(|x| {
                            ExprImpl::InputRef(
                                InputRef::new(
                                    col_index_mapping.map(x),
                                    original_schema.fields[x].data_type.clone(),
                                )
                                .into(),
                            )
                        })
                        .collect_vec();
                    // merged_stream_key: place the n-th key column of a given type
                    // into the n-th slot of that type's range.
                    let mut input_stream_keys = input_stream_key_nulls.clone();
                    let mut types_counter = TypeMap::default();
                    for stream_key_idx in new_input.expect_stream_key() {
                        let data_type =
                            new_input.schema().fields[*stream_key_idx].data_type.clone();
                        let count = *types_counter
                            .entry(data_type.clone())
                            .and_modify(|x| *x += 1)
                            .or_insert(1);
                        let type_start_offset = *types_offset.get(&data_type).unwrap();

                        input_stream_keys[type_start_offset + count - 1] =
                            ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
                    }
                    exprs.extend(input_stream_keys);
                    // source_col
                    exprs.push(ExprImpl::Literal(
                        Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
                    ));
                    LogicalProject::create(new_input, exprs)
                })
                .collect_vec();

            let new_union = LogicalUnion::new_with_source_col(
                self.all(),
                new_inputs,
                Some(original_schema_len + merged_stream_key_types.len()),
            );
            // We have already used a project to map each rewritten input back to the
            // original schema, so we can use an identity mapping with the new schema length.
            let out_col_change =
                ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
            Ok((new_union.into(), out_col_change))
        }
    }
}

#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    #[tokio::test]
    async fn test_prune_union() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        let values2 = values1.clone();

        let union: PlanRef = LogicalUnion::new(false, vec![values1.into(), values2.into()]).into();

        // Perform the prune
        let required_cols = vec![1, 2];
        let plan = union.prune_col(
            &required_cols,
            &mut ColumnPruningContext::new(union.clone()),
        );

        // Check the result
        let union = plan.as_logical_union().unwrap();
        assert_eq!(union.base.schema().len(), 2);
    }

    #[tokio::test]
    async fn test_union_to_batch() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        let values2 = values1.clone();

        let union = LogicalUnion::new(false, vec![values1.into(), values2.into()]);

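        // A distinct union lowers to a `BatchHashAgg` (grouping by all columns)
        // over a `BatchUnion`; see `to_batch` above.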
        let plan = union.to_batch().unwrap();
        let agg: &BatchHashAgg = plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let union = agg_input.as_batch_union().unwrap();

        assert_eq!(union.inputs().len(), 2);
    }
}