risingwave_frontend/optimizer/plan_node/
logical_expand.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16
17use super::generic::GenericPlanRef;
18use super::utils::impl_distill_by_unit;
19use super::{
20    BatchExpand, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
21    PredicatePushdown, StreamExpand, ToBatch, ToStream, gen_filter_and_pushdown, generic,
22};
23use crate::error::Result;
24use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
25use crate::optimizer::plan_node::{
26    ColumnPruningContext, LogicalProject, PredicatePushdownContext, RewriteStreamContext,
27    ToStreamContext,
28};
29use crate::utils::{ColIndexMapping, Condition};
30
/// [`LogicalExpand`] expands one row multiple times according to `column_subsets` and also keeps
/// original columns of input. It can be used to implement distinct aggregation and grouping sets.
///
/// This is the schema of `LogicalExpand`:
/// | expanded columns(i.e. some columns are set to null) | original columns of input | flag |.
///
/// Aggregates use expanded columns as their arguments and original columns for their filter. `flag`
/// is used to distinguish between different `subset`s in `column_subsets`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalExpand {
    /// Logical plan base of this node; [`LogicalExpand::new`] derives it from `core`.
    pub base: PlanBase<Logical>,
    /// Generic expand description holding the input plan and the `column_subsets`.
    core: generic::Expand<PlanRef>,
}
44
45impl LogicalExpand {
46    pub fn new(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> Self {
47        for key in column_subsets.iter().flatten() {
48            assert!(*key < input.schema().len());
49        }
50
51        let core = generic::Expand {
52            column_subsets,
53            input,
54        };
55        let base = PlanBase::new_logical_with_core(&core);
56
57        LogicalExpand { base, core }
58    }
59
60    pub fn create(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> PlanRef {
61        Self::new(input, column_subsets).into()
62    }
63
64    pub fn column_subsets(&self) -> &Vec<Vec<usize>> {
65        &self.core.column_subsets
66    }
67
68    pub fn decompose(self) -> (PlanRef, Vec<Vec<usize>>) {
69        self.core.decompose()
70    }
71}
72
73impl PlanTreeNodeUnary for LogicalExpand {
74    fn input(&self) -> PlanRef {
75        self.core.input.clone()
76    }
77
78    fn clone_with_input(&self, input: PlanRef) -> Self {
79        Self::new(input, self.column_subsets().clone())
80    }
81
82    fn rewrite_with_input(
83        &self,
84        input: PlanRef,
85        input_col_change: ColIndexMapping,
86    ) -> (Self, ColIndexMapping) {
87        let column_subsets = self
88            .column_subsets()
89            .iter()
90            .map(|subset| {
91                subset
92                    .iter()
93                    .filter_map(|i| input_col_change.try_map(*i))
94                    .collect_vec()
95            })
96            .collect_vec();
97
98        let old_out_len = self.schema().len();
99        let old_in_len = self.input().schema().len();
100        let new_in_len = input.schema().len();
101        assert_eq!(
102            old_out_len,
103            old_in_len * 2 + 1 // expanded input cols + real input cols + flag
104        );
105
106        let mut mapping = Vec::with_capacity(old_out_len);
107        // map the expanded input columns
108        for i in 0..old_in_len {
109            mapping.push(input_col_change.try_map(i));
110        }
111        // map the real input columns
112        for i in 0..old_in_len {
113            mapping.push(
114                input_col_change
115                    .try_map(i)
116                    .map(|x| x + new_in_len /* # of new expanded input cols */),
117            );
118        }
119        // map the flag column
120        mapping.push(Some(2 * new_in_len));
121
122        let expand = Self::new(input, column_subsets);
123        let output_col_num = expand.schema().len();
124        (expand, ColIndexMapping::new(mapping, output_col_num))
125    }
126}
127
// The two macros below are defined elsewhere in the project and generate further impls for
// `LogicalExpand`.
// NOTE(review): presumably this one derives the generic plan-tree-node impl from the
// `PlanTreeNodeUnary` impl — confirm against the macro definition.
impl_plan_tree_node_for_unary! {LogicalExpand}
// NOTE(review): presumably this one implements the explain ("distill") output by delegating to
// the `core` field under the name "LogicalExpand" — confirm against the macro definition.
impl_distill_by_unit!(LogicalExpand, core, "LogicalExpand");
130
131impl ColPrunable for LogicalExpand {
132    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
133        // No pruning.
134        let input_required_cols = (0..self.input().schema().len()).collect_vec();
135        LogicalProject::with_out_col_idx(
136            self.clone_with_input(self.input().prune_col(&input_required_cols, ctx))
137                .into(),
138            required_cols.iter().cloned(),
139        )
140        .into()
141    }
142}
143
// Empty impl: `LogicalExpand` takes every `ExprRewritable` method from the trait's defaults.
impl ExprRewritable for LogicalExpand {}

// Empty impl: `LogicalExpand` takes every `ExprVisitable` method from the trait's defaults.
impl ExprVisitable for LogicalExpand {}
147
148impl PredicatePushdown for LogicalExpand {
149    fn predicate_pushdown(
150        &self,
151        predicate: Condition,
152        ctx: &mut PredicatePushdownContext,
153    ) -> PlanRef {
154        // No pushdown.
155        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
156    }
157}
158
159impl ToBatch for LogicalExpand {
160    fn to_batch(&self) -> Result<PlanRef> {
161        let new_input = self.input().to_batch()?;
162        let mut new_logical = self.core.clone();
163        new_logical.input = new_input;
164        Ok(BatchExpand::new(new_logical).into())
165    }
166}
167
168impl ToStream for LogicalExpand {
169    fn logical_rewrite_for_stream(
170        &self,
171        ctx: &mut RewriteStreamContext,
172    ) -> Result<(PlanRef, ColIndexMapping)> {
173        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
174        let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
175        Ok((expand.into(), out_col_change))
176    }
177
178    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
179        let new_input = self.input().to_stream(ctx)?;
180        let mut new_logical = self.core.clone();
181        new_logical.input = new_input;
182        Ok(StreamExpand::new(new_logical).into())
183    }
184}
185
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalExpand, LogicalValues};

    // TODO(Wenzhuo): change this test according to expand's new definition.
    #[tokio::test]
    async fn fd_derivation_expand() {
        // input: [v1, v2, v3]
        // FD: v1 --> { v2, v3 }
        //
        // The assertions follow the schema layout documented on `LogicalExpand`:
        // | expanded v1, v2, v3 | original v1, v2, v3 | flag |, which puts the flag at index 6.
        // Expected FD: { expanded v1 (0), flag (6) } --> { expanded v2 (1), expanded v3 (2) }
        let ctx = OptimizerContext::mock().await;
        let fields = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect_vec();
        let schema = Schema { fields };

        let mut values = LogicalValues::new(vec![], schema, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[0], &[1, 2]);

        let expand = LogicalExpand::create(values.into(), vec![vec![0, 1], vec![2]]);

        let dependencies = expand.functional_dependency().as_dependencies();
        assert_eq!(dependencies.len(), 1);
        let dependency = &dependencies[0];
        assert_eq!(dependency.from().ones().collect_vec(), &[0, 6]);
        assert_eq!(dependency.to().ones().collect_vec(), &[1, 2]);
    }
}
221}