risingwave_frontend/optimizer/plan_node/
logical_expand.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17
18use super::generic::GenericPlanRef;
19use super::utils::impl_distill_by_unit;
20use super::{
21    BatchExpand, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
22    PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamExpand, StreamPlanRef, ToBatch, ToStream,
23    gen_filter_and_pushdown, generic,
24};
25use crate::error::Result;
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::{
28    ColumnPruningContext, LogicalProject, PredicatePushdownContext, RewriteStreamContext,
29    ToStreamContext,
30};
31use crate::utils::{ColIndexMapping, Condition};
32
/// [`LogicalExpand`] expands one row multiple times according to `column_subsets` and also keeps
/// the original columns of the input. It can be used to implement distinct aggregation and
/// grouping sets.
///
/// This is the schema of `LogicalExpand`:
/// | expanded columns (i.e. some columns are set to null) | original columns of input | flag |.
///
/// For an input with `n` columns the output therefore has `2 * n + 1` columns: indices `0..n`
/// are the expanded copies, `n..2n` the untouched originals, and `2n` the flag. This layout is
/// asserted when the node is rewritten over a new input and is relied upon by column pruning.
///
/// Aggregates use expanded columns as their arguments and original columns for their filter. `flag`
/// is used to distinguish between different `subset`s in `column_subsets`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalExpand {
    /// Logical plan base, derived from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// Generic expand description: the input plan together with the column subsets.
    core: generic::Expand<PlanRef>,
}
46
47impl LogicalExpand {
48    pub fn new(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> Self {
49        for key in column_subsets.iter().flatten() {
50            assert!(*key < input.schema().len());
51        }
52
53        let core = generic::Expand {
54            column_subsets,
55            input,
56        };
57        let base = PlanBase::new_logical_with_core(&core);
58
59        LogicalExpand { base, core }
60    }
61
62    pub fn create(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> PlanRef {
63        Self::new(input, column_subsets).into()
64    }
65
66    pub fn column_subsets(&self) -> &Vec<Vec<usize>> {
67        &self.core.column_subsets
68    }
69
70    pub fn decompose(self) -> (PlanRef, Vec<Vec<usize>>) {
71        self.core.decompose()
72    }
73}
74
75impl PlanTreeNodeUnary<Logical> for LogicalExpand {
76    fn input(&self) -> PlanRef {
77        self.core.input.clone()
78    }
79
80    fn clone_with_input(&self, input: PlanRef) -> Self {
81        Self::new(input, self.column_subsets().clone())
82    }
83
84    fn rewrite_with_input(
85        &self,
86        input: PlanRef,
87        input_col_change: ColIndexMapping,
88    ) -> (Self, ColIndexMapping) {
89        let column_subsets = self
90            .column_subsets()
91            .iter()
92            .map(|subset| {
93                subset
94                    .iter()
95                    .filter_map(|i| input_col_change.try_map(*i))
96                    .collect_vec()
97            })
98            .collect_vec();
99
100        let old_out_len = self.schema().len();
101        let old_in_len = self.input().schema().len();
102        let new_in_len = input.schema().len();
103        assert_eq!(
104            old_out_len,
105            old_in_len * 2 + 1 // expanded input cols + real input cols + flag
106        );
107
108        let mut mapping = Vec::with_capacity(old_out_len);
109        // map the expanded input columns
110        for i in 0..old_in_len {
111            mapping.push(input_col_change.try_map(i));
112        }
113        // map the real input columns
114        for i in 0..old_in_len {
115            mapping.push(
116                input_col_change
117                    .try_map(i)
118                    .map(|x| x + new_in_len /* # of new expanded input cols */),
119            );
120        }
121        // map the flag column
122        mapping.push(Some(2 * new_in_len));
123
124        let expand = Self::new(input, column_subsets);
125        let output_col_num = expand.schema().len();
126        (expand, ColIndexMapping::new(mapping, output_col_num))
127    }
128}
129
// Generates the unary plan-tree node boilerplate for `LogicalExpand` in the `Logical` convention.
impl_plan_tree_node_for_unary! { Logical, LogicalExpand}
// Hooks `core` into the `impl_distill_by_unit` helper under the label "LogicalExpand"
// (presumably the text shown when the plan is printed/explained — confirm against the macro).
impl_distill_by_unit!(LogicalExpand, core, "LogicalExpand");
132
133impl ColPrunable for LogicalExpand {
134    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
135        let input_len = self.input().schema().len();
136        let mut input_required_cols = FixedBitSet::with_capacity(input_len);
137        for &idx in required_cols {
138            if idx < input_len {
139                input_required_cols.insert(idx);
140            } else if idx < input_len * 2 {
141                input_required_cols.insert(idx - input_len);
142            } else {
143                assert_eq!(idx, input_len * 2);
144            }
145        }
146        let input_required_cols = input_required_cols.ones().collect_vec();
147        let input_col_change =
148            ColIndexMapping::with_remaining_columns(&input_required_cols, input_len);
149        let input = self.input().prune_col(&input_required_cols, ctx);
150        let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
151        let output_required_cols = required_cols
152            .iter()
153            .map(|idx| {
154                out_col_change
155                    .try_map(*idx)
156                    .expect("required column should be kept")
157            })
158            .collect_vec();
159        LogicalProject::with_out_col_idx(expand.into(), output_required_cols.into_iter()).into()
160    }
161}
162
// The node holds only an input plan and index-based column subsets (no expressions of its own),
// so it relies on the default methods of both traits.
impl ExprRewritable<Logical> for LogicalExpand {}

impl ExprVisitable for LogicalExpand {}
166
167impl PredicatePushdown for LogicalExpand {
168    fn predicate_pushdown(
169        &self,
170        predicate: Condition,
171        ctx: &mut PredicatePushdownContext,
172    ) -> PlanRef {
173        // No pushdown.
174        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
175    }
176}
177
178impl ToBatch for LogicalExpand {
179    fn to_batch(&self) -> Result<BatchPlanRef> {
180        let new_input = self.input().to_batch()?;
181        let expand = self.core.clone_with_input(new_input);
182        Ok(BatchExpand::new(expand).into())
183    }
184}
185
186impl ToStream for LogicalExpand {
187    fn logical_rewrite_for_stream(
188        &self,
189        ctx: &mut RewriteStreamContext,
190    ) -> Result<(PlanRef, ColIndexMapping)> {
191        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
192        let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
193        Ok((expand.into(), out_col_change))
194    }
195
196    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
197        let new_input = self.input().to_stream(ctx)?;
198        let expand = self.core.clone_with_input(new_input);
199        Ok(StreamExpand::new(expand).into())
200    }
201}
202
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalExpand, LogicalValues};

    /// Checks how an input functional dependency is carried through `LogicalExpand`.
    ///
    /// The input is `[v1, v2, v3]` with FD `v1 --> { v2, v3 }`. The output layout is
    /// `[v1_expanded, v2_expanded, v3_expanded, v1, v2, v3, flag]`. The FD is expected to hold
    /// only on the preserved original columns (`{3} --> {4, 5}`), not on the expanded columns,
    /// because each expand lane can set columns outside its subset to NULL.
    #[tokio::test]
    async fn fd_derivation_expand() {
        let ctx = OptimizerContext::mock();
        let columns: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let mut values = LogicalValues::new(vec![], Schema { fields: columns }, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[0], &[1, 2]);

        let expand = LogicalExpand::create(values.into(), vec![vec![0, 1], vec![2]]);
        let deps = expand.functional_dependency().as_dependencies();
        assert_eq!(deps.len(), 1);
        let only_dep = &deps[0];
        assert_eq!(only_dep.from().ones().collect_vec(), &[3]);
        assert_eq!(only_dep.to().ones().collect_vec(), &[4, 5]);
    }
}