risingwave_frontend/optimizer/plan_node/
logical_expand.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17
18use super::generic::GenericPlanRef;
19use super::utils::impl_distill_by_unit;
20use super::{
21 BatchExpand, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
22 PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamExpand, StreamPlanRef, ToBatch, ToStream,
23 gen_filter_and_pushdown, generic,
24};
25use crate::error::Result;
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::{
28 ColumnPruningContext, LogicalProject, PredicatePushdownContext, RewriteStreamContext,
29 ToStreamContext,
30};
31use crate::utils::{ColIndexMapping, Condition};
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct LogicalExpand {
43 pub base: PlanBase<Logical>,
44 core: generic::Expand<PlanRef>,
45}
46
47impl LogicalExpand {
48 pub fn new(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> Self {
49 for key in column_subsets.iter().flatten() {
50 assert!(*key < input.schema().len());
51 }
52
53 let core = generic::Expand {
54 column_subsets,
55 input,
56 };
57 let base = PlanBase::new_logical_with_core(&core);
58
59 LogicalExpand { base, core }
60 }
61
62 pub fn create(input: PlanRef, column_subsets: Vec<Vec<usize>>) -> PlanRef {
63 Self::new(input, column_subsets).into()
64 }
65
66 pub fn column_subsets(&self) -> &Vec<Vec<usize>> {
67 &self.core.column_subsets
68 }
69
70 pub fn decompose(self) -> (PlanRef, Vec<Vec<usize>>) {
71 self.core.decompose()
72 }
73}
74
75impl PlanTreeNodeUnary<Logical> for LogicalExpand {
76 fn input(&self) -> PlanRef {
77 self.core.input.clone()
78 }
79
80 fn clone_with_input(&self, input: PlanRef) -> Self {
81 Self::new(input, self.column_subsets().clone())
82 }
83
84 fn rewrite_with_input(
85 &self,
86 input: PlanRef,
87 input_col_change: ColIndexMapping,
88 ) -> (Self, ColIndexMapping) {
89 let column_subsets = self
90 .column_subsets()
91 .iter()
92 .map(|subset| {
93 subset
94 .iter()
95 .filter_map(|i| input_col_change.try_map(*i))
96 .collect_vec()
97 })
98 .collect_vec();
99
100 let old_out_len = self.schema().len();
101 let old_in_len = self.input().schema().len();
102 let new_in_len = input.schema().len();
103 assert_eq!(
104 old_out_len,
105 old_in_len * 2 + 1 );
107
108 let mut mapping = Vec::with_capacity(old_out_len);
109 for i in 0..old_in_len {
111 mapping.push(input_col_change.try_map(i));
112 }
113 for i in 0..old_in_len {
115 mapping.push(
116 input_col_change
117 .try_map(i)
118 .map(|x| x + new_in_len ),
119 );
120 }
121 mapping.push(Some(2 * new_in_len));
123
124 let expand = Self::new(input, column_subsets);
125 let output_col_num = expand.schema().len();
126 (expand, ColIndexMapping::new(mapping, output_col_num))
127 }
128}
129
130impl_plan_tree_node_for_unary! { Logical, LogicalExpand}
131impl_distill_by_unit!(LogicalExpand, core, "LogicalExpand");
132
133impl ColPrunable for LogicalExpand {
134 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
135 let input_len = self.input().schema().len();
136 let mut input_required_cols = FixedBitSet::with_capacity(input_len);
137 for &idx in required_cols {
138 if idx < input_len {
139 input_required_cols.insert(idx);
140 } else if idx < input_len * 2 {
141 input_required_cols.insert(idx - input_len);
142 } else {
143 assert_eq!(idx, input_len * 2);
144 }
145 }
146 let input_required_cols = input_required_cols.ones().collect_vec();
147 let input_col_change =
148 ColIndexMapping::with_remaining_columns(&input_required_cols, input_len);
149 let input = self.input().prune_col(&input_required_cols, ctx);
150 let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
151 let output_required_cols = required_cols
152 .iter()
153 .map(|idx| {
154 out_col_change
155 .try_map(*idx)
156 .expect("required column should be kept")
157 })
158 .collect_vec();
159 LogicalProject::with_out_col_idx(expand.into(), output_required_cols.into_iter()).into()
160 }
161}
162
163impl ExprRewritable<Logical> for LogicalExpand {}
164
165impl ExprVisitable for LogicalExpand {}
166
167impl PredicatePushdown for LogicalExpand {
168 fn predicate_pushdown(
169 &self,
170 predicate: Condition,
171 ctx: &mut PredicatePushdownContext,
172 ) -> PlanRef {
173 gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
175 }
176}
177
178impl ToBatch for LogicalExpand {
179 fn to_batch(&self) -> Result<BatchPlanRef> {
180 let new_input = self.input().to_batch()?;
181 let expand = self.core.clone_with_input(new_input);
182 Ok(BatchExpand::new(expand).into())
183 }
184}
185
186impl ToStream for LogicalExpand {
187 fn logical_rewrite_for_stream(
188 &self,
189 ctx: &mut RewriteStreamContext,
190 ) -> Result<(PlanRef, ColIndexMapping)> {
191 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
192 let (expand, out_col_change) = self.rewrite_with_input(input, input_col_change);
193 Ok((expand.into(), out_col_change))
194 }
195
196 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
197 let new_input = self.input().to_stream(ctx)?;
198 let expand = self.core.clone_with_input(new_input);
199 Ok(StreamExpand::new(expand).into())
200 }
201}
202
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalExpand, LogicalValues};

    /// An input FD `{v1} -> {v2, v3}` must reappear on the second copy of the input columns.
    #[tokio::test]
    async fn fd_derivation_expand() {
        let ctx = OptimizerContext::mock();
        // Input columns:  0: v1, 1: v2, 2: v3
        // Expand output:  0..=2 first copy, 3: v1, 4: v2, 5: v3 (second copy), 6: trailing column
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[0], &[1, 2]);

        let column_subsets = vec![vec![0, 1], vec![2]];
        let expand = LogicalExpand::create(values.into(), column_subsets);

        let deps = expand.functional_dependency().as_dependencies();
        assert_eq!(deps.len(), 1);
        let dep = &deps[0];
        assert_eq!(dep.from().ones().collect_vec(), vec![3]);
        assert_eq!(dep.to().ones().collect_vec(), vec![4, 5]);
    }
}