risingwave_frontend/optimizer/plan_node/
logical_expand.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16
17use super::generic::GenericPlanRef;
18use super::utils::impl_distill_by_unit;
19use super::{
20    BatchExpand, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
21    PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamExpand, StreamPlanRef, ToBatch, ToStream,
22    gen_filter_and_pushdown, generic,
23};
24use crate::error::Result;
25use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
26use crate::optimizer::plan_node::{
27    ColumnPruningContext, LogicalProject, PredicatePushdownContext, RewriteStreamContext,
28    ToStreamContext,
29};
30use crate::utils::{ColIndexMapping, Condition};
31
/// [`LogicalExpand`] expands one row multiple times according to `column_subsets` and also keeps
/// original columns of input. It can be used to implement distinct aggregation and group set.
///
/// This is the schema of `LogicalExpand`:
/// | expanded columns(i.e. some columns are set to null) | original columns of input | flag |.
///
/// Aggregates use expanded columns as their arguments and original columns for their filter. `flag`
/// is used to distinguish between different `subset`s in `column_subsets`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalExpand {
    // Common logical-plan metadata (schema, functional dependencies, ...) derived from `core`.
    pub base: PlanBase<Logical>,
    // Core state shared with the batch/stream `Expand` variants: the input plan and the
    // column subsets to expand by.
    core: generic::Expand<PlanRef>,
}
45
46impl LogicalExpand {
47    pub fn new(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> Self {
48        for key in column_subsets.iter().flatten() {
49            assert!(*key < input.schema().len());
50        }
51
52        let core = generic::Expand {
53            column_subsets,
54            input,
55        };
56        let base = PlanBase::new_logical_with_core(&core);
57
58        LogicalExpand { base, core }
59    }
60
61    pub fn create(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> PlanRef {
62        Self::new(input, column_subsets).into()
63    }
64
65    pub fn column_subsets(&self) -> &Vec<Vec<usize>> {
66        &self.core.column_subsets
67    }
68
69    pub fn decompose(self) -> (PlanRef, Vec<Vec<usize>>) {
70        self.core.decompose()
71    }
72}
73
74impl PlanTreeNodeUnary<Logical> for LogicalExpand {
75    fn input(&self) -> PlanRef {
76        self.core.input.clone()
77    }
78
79    fn clone_with_input(&self, input: PlanRef) -> Self {
80        Self::new(input, self.column_subsets().clone())
81    }
82
83    fn rewrite_with_input(
84        &self,
85        input: PlanRef,
86        input_col_change: ColIndexMapping,
87    ) -> (Self, ColIndexMapping) {
88        let column_subsets = self
89            .column_subsets()
90            .iter()
91            .map(|subset| {
92                subset
93                    .iter()
94                    .filter_map(|i| input_col_change.try_map(*i))
95                    .collect_vec()
96            })
97            .collect_vec();
98
99        let old_out_len = self.schema().len();
100        let old_in_len = self.input().schema().len();
101        let new_in_len = input.schema().len();
102        assert_eq!(
103            old_out_len,
104            old_in_len * 2 + 1 // expanded input cols + real input cols + flag
105        );
106
107        let mut mapping = Vec::with_capacity(old_out_len);
108        // map the expanded input columns
109        for i in 0..old_in_len {
110            mapping.push(input_col_change.try_map(i));
111        }
112        // map the real input columns
113        for i in 0..old_in_len {
114            mapping.push(
115                input_col_change
116                    .try_map(i)
117                    .map(|x| x + new_in_len /* # of new expanded input cols */),
118            );
119        }
120        // map the flag column
121        mapping.push(Some(2 * new_in_len));
122
123        let expand = Self::new(input, column_subsets);
124        let output_col_num = expand.schema().len();
125        (expand, ColIndexMapping::new(mapping, output_col_num))
126    }
127}
128
// Boilerplate impls: unary plan-tree-node plumbing, and `Distill` delegated to `core`.
impl_plan_tree_node_for_unary! { Logical, LogicalExpand}
impl_distill_by_unit!(LogicalExpand, core, "LogicalExpand");
131
132impl ColPrunable for LogicalExpand {
133    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
134        // No pruning.
135        let input_required_cols = (0..self.input().schema().len()).collect_vec();
136        LogicalProject::with_out_col_idx(
137            self.clone_with_input(self.input().prune_col(&input_required_cols, ctx))
138                .into(),
139            required_cols.iter().cloned(),
140        )
141        .into()
142    }
143}
144
// `LogicalExpand` holds no expressions (only column indices), so the default
// no-op expression rewrite and visit are sufficient.
impl ExprRewritable<Logical> for LogicalExpand {}

impl ExprVisitable for LogicalExpand {}
148
149impl PredicatePushdown for LogicalExpand {
150    fn predicate_pushdown(
151        &self,
152        predicate: Condition,
153        ctx: &mut PredicatePushdownContext,
154    ) -> PlanRef {
155        // No pushdown.
156        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
157    }
158}
159
160impl ToBatch for LogicalExpand {
161    fn to_batch(&self) -> Result<BatchPlanRef> {
162        let new_input = self.input().to_batch()?;
163        let expand = self.core.clone_with_input(new_input);
164        Ok(BatchExpand::new(expand).into())
165    }
166}
167
168impl ToStream for LogicalExpand {
169    fn logical_rewrite_for_stream(
170        &self,
171        ctx: &mut RewriteStreamContext,
172    ) -> Result<(PlanRef, ColIndexMapping)> {
173        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
174        let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
175        Ok((expand.into(), out_col_change))
176    }
177
178    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
179        let new_input = self.input().to_stream(ctx)?;
180        let expand = self.core.clone_with_input(new_input);
181        Ok(StreamExpand::new(expand).into())
182    }
183}
184
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalExpand, LogicalValues};

    // TODO(Wenzhuo): change this test according to expand's new definition.
    #[tokio::test]
    async fn fd_derivation_expand() {
        // Input: [v1, v2, v3] with FD v1 --> { v2, v3 }.
        // Expand output: [expanded v1..v3, original v1..v3, flag]; the derived
        // FD is expected to be { v1, flag } --> { v2, v3 }, i.e. {0, 6} --> {1, 2}.
        let ctx = OptimizerContext::mock().await;
        let schema = Schema {
            fields: ["v1", "v2", "v3"]
                .into_iter()
                .map(|name| Field::with_name(DataType::Int32, name))
                .collect(),
        };
        let mut values = LogicalValues::new(vec![], schema, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[0], &[1, 2]);

        let expand = LogicalExpand::create(values.into(), vec![vec![0, 1], vec![2]]);
        let deps = expand.functional_dependency().as_dependencies();
        assert_eq!(deps.len(), 1);
        assert_eq!(deps[0].from().ones().collect_vec(), &[0, 6]);
        assert_eq!(deps[0].to().ones().collect_vec(), &[1, 2]);
    }
}