risingwave_frontend/optimizer/plan_node/
logical_project_set.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18
19use super::utils::impl_distill_by_unit;
20use super::{
21 BatchProjectSet, ColPrunable, ExprRewritable, Logical, LogicalProject, PlanBase, PlanRef,
22 PlanTreeNodeUnary, PredicatePushdown, StreamProjectSet, ToBatch, ToStream,
23 gen_filter_and_pushdown, generic,
24};
25use crate::error::{ErrorCode, Result};
26use crate::expr::{
27 Expr, ExprImpl, ExprRewriter, ExprVisitor, FunctionCall, InputRef, TableFunction,
28 collect_input_refs,
29};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::generic::GenericPlanRef;
32use crate::optimizer::plan_node::{
33 ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::utils::{ColIndexMapping, Condition, Substitute};
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46pub struct LogicalProjectSet {
47 pub base: PlanBase<Logical>,
48 core: generic::ProjectSet<PlanRef>,
49}
50
51impl LogicalProjectSet {
52 pub fn new(input: PlanRef, select_list: Vec<ExprImpl>) -> Self {
53 assert!(
54 select_list.iter().any(|e| e.has_table_function()),
55 "ProjectSet should have at least one table function."
56 );
57
58 let core = generic::ProjectSet { select_list, input };
59 let base = PlanBase::new_logical_with_core(&core);
60
61 LogicalProjectSet { base, core }
62 }
63
64 pub fn create(input: PlanRef, select_list: Vec<ExprImpl>) -> PlanRef {
79 if select_list
80 .iter()
81 .all(|e: &ExprImpl| !e.has_table_function())
82 {
83 return LogicalProject::create(input, select_list);
84 }
85
86 struct Rewriter {
89 collected: Vec<TableFunction>,
90 level: usize,
94 input_schema_len: usize,
95 }
96
97 impl ExprRewriter for Rewriter {
98 fn rewrite_table_function(&mut self, table_func: TableFunction) -> ExprImpl {
99 if self.level == 0 {
100 self.level += 1;
102
103 let TableFunction {
104 args,
105 return_type,
106 function_type,
107 user_defined,
108 } = table_func;
109 let args = args
110 .into_iter()
111 .map(|expr| self.rewrite_expr(expr))
112 .collect();
113
114 self.level -= 1;
115 TableFunction {
116 args,
117 return_type,
118 function_type,
119 user_defined,
120 }
121 .into()
122 } else {
123 let input_ref = InputRef::new(
124 self.input_schema_len + self.collected.len(),
125 table_func.return_type(),
126 );
127 self.collected.push(table_func);
128 input_ref.into()
129 }
130 }
131
132 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
133 self.level += 1;
134 let (func_type, inputs, return_type) = func_call.decompose();
135 let inputs = inputs
136 .into_iter()
137 .map(|expr| self.rewrite_expr(expr))
138 .collect();
139 self.level -= 1;
140 FunctionCall::new_unchecked(func_type, inputs, return_type).into()
141 }
142 }
143
144 let mut rewriter = Rewriter {
145 collected: vec![],
146 level: 0,
147 input_schema_len: input.schema().len(),
148 };
149 let select_list: Vec<_> = select_list
150 .into_iter()
151 .map(|e| rewriter.rewrite_expr(e))
152 .collect();
153
154 if rewriter.collected.is_empty() {
155 LogicalProjectSet::new(input, select_list).into()
156 } else {
157 let mut inner_select_list: Vec<_> = input
158 .schema()
159 .data_types()
160 .into_iter()
161 .enumerate()
162 .map(|(i, ty)| InputRef::new(i, ty).into())
163 .collect();
164 inner_select_list.extend(rewriter.collected.into_iter().map(|tf| tf.into()));
165 let inner = LogicalProjectSet::create(input, inner_select_list);
166
167 struct IncInputRef {}
170 impl ExprRewriter for IncInputRef {
171 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
172 InputRef::new(input_ref.index + 1, input_ref.data_type).into()
173 }
174 }
175 let mut rewriter = IncInputRef {};
176 let select_list: Vec<_> = select_list
177 .into_iter()
178 .map(|e| rewriter.rewrite_expr(e))
179 .collect();
180
181 if select_list.iter().any(|e| e.has_table_function()) {
182 LogicalProjectSet::new(inner, select_list).into()
183 } else {
184 LogicalProject::new(inner, select_list).into()
185 }
186 }
187 }
188
189 pub fn select_list(&self) -> &Vec<ExprImpl> {
190 &self.core.select_list
191 }
192
193 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
194 self.core.decompose()
195 }
196}
197
198impl PlanTreeNodeUnary for LogicalProjectSet {
199 fn input(&self) -> PlanRef {
200 self.core.input.clone()
201 }
202
203 fn clone_with_input(&self, input: PlanRef) -> Self {
204 Self::new(input, self.select_list().clone())
205 }
206
207 fn rewrite_with_input(
208 &self,
209 input: PlanRef,
210 mut input_col_change: ColIndexMapping,
211 ) -> (Self, ColIndexMapping) {
212 let select_list = self
213 .select_list()
214 .clone()
215 .into_iter()
216 .map(|item| input_col_change.rewrite_expr(item))
217 .collect();
218 let project_set = Self::new(input, select_list);
219 let out_col_change = ColIndexMapping::identity(self.schema().len());
221 (project_set, out_col_change)
222 }
223}
224
225impl_plan_tree_node_for_unary! {LogicalProjectSet}
226impl_distill_by_unit!(LogicalProjectSet, core, "LogicalProjectSet");
227impl ColPrunable for LogicalProjectSet {
230 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
231 let output_required_cols = required_cols;
232 let required_cols = {
233 let mut required_cols_set = FixedBitSet::from_iter(required_cols.iter().copied());
234 required_cols_set.grow(self.select_list().len() + 1);
235 let mut cols = required_cols.to_vec();
236 for (i, e) in self.select_list().iter().enumerate() {
239 if e.has_table_function() && !required_cols_set.contains(i + 1) {
240 cols.push(i + 1);
241 required_cols_set.set(i + 1, true);
242 }
243 }
244 cols
245 };
246
247 let input_col_num = self.input().schema().len();
248
249 let input_required_cols = collect_input_refs(
250 input_col_num,
251 required_cols
252 .iter()
253 .filter(|&&i| i > 0)
254 .map(|i| &self.select_list()[*i - 1]),
255 )
256 .ones()
257 .collect_vec();
258 let new_input = self.input().prune_col(&input_required_cols, ctx);
259 let mut mapping = ColIndexMapping::with_remaining_columns(
260 &input_required_cols,
261 self.input().schema().len(),
262 );
263 let select_list = required_cols
265 .iter()
266 .filter(|&&id| id > 0)
267 .map(|&id| mapping.rewrite_expr(self.select_list()[id - 1].clone()))
268 .collect();
269
270 let new_node: PlanRef = LogicalProjectSet::create(new_input, select_list);
272 if new_node.schema().len() == output_required_cols.len() {
273 new_node
275 } else {
276 let mut new_output_cols = required_cols.to_vec();
278 if !required_cols.contains(&0) {
279 new_output_cols.insert(0, 0);
280 }
281 let mapping =
282 &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
283 let output_required_cols = output_required_cols
284 .iter()
285 .map(|&idx| mapping.map(idx))
286 .collect_vec();
287 let src_size = new_node.schema().len();
288 LogicalProject::with_mapping(
289 new_node,
290 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
291 )
292 .into()
293 }
294 }
295}
296
297impl ExprRewritable for LogicalProjectSet {
298 fn has_rewritable_expr(&self) -> bool {
299 true
300 }
301
302 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
303 let mut core = self.core.clone();
304 core.rewrite_exprs(r);
305 Self {
306 base: self.base.clone_with_new_plan_id(),
307 core,
308 }
309 .into()
310 }
311}
312
313impl ExprVisitable for LogicalProjectSet {
314 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
315 self.core.visit_exprs(v);
316 }
317}
318
319impl PredicatePushdown for LogicalProjectSet {
320 fn predicate_pushdown(
321 &self,
322 predicate: Condition,
323 ctx: &mut PredicatePushdownContext,
324 ) -> PlanRef {
325 let mut subst = Substitute {
327 mapping: {
328 let mut output_list = self.select_list().clone();
329 output_list.insert(
330 0,
331 ExprImpl::InputRef(Box::new(InputRef {
332 index: 0,
333 data_type: DataType::Int64,
334 })),
335 );
336 output_list
337 },
338 };
339
340 let remain_mask = {
341 let mut remain_mask = FixedBitSet::with_capacity(self.select_list().len() + 1);
342 remain_mask.set(0, true);
343 self.select_list()
344 .iter()
345 .enumerate()
346 .for_each(|(i, e)| remain_mask.set(i + 1, e.is_impure() || e.has_table_function()));
347 remain_mask
348 };
349 let (remained_cond, pushed_cond) = predicate.split_disjoint(&remain_mask);
350 let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
351
352 gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
353 }
354}
355
356impl ToBatch for LogicalProjectSet {
357 fn to_batch(&self) -> Result<PlanRef> {
358 let mut new_logical = self.core.clone();
359 new_logical.input = self.input().to_batch()?;
360 Ok(BatchProjectSet::new(new_logical).into())
361 }
362}
363
364impl ToStream for LogicalProjectSet {
365 fn logical_rewrite_for_stream(
366 &self,
367 ctx: &mut RewriteStreamContext,
368 ) -> Result<(PlanRef, ColIndexMapping)> {
369 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
370 let (project_set, out_col_change) =
371 self.rewrite_with_input(input.clone(), input_col_change);
372
373 let input_pk = input.expect_stream_key();
375 let i2o = self.core.i2o_col_mapping();
376 let col_need_to_add = input_pk
377 .iter()
378 .cloned()
379 .filter(|i| i2o.try_map(*i).is_none());
380 let input_schema = input.schema();
381 let select_list =
382 project_set
383 .select_list()
384 .iter()
385 .cloned()
386 .chain(col_need_to_add.map(|idx| {
387 InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
388 }))
389 .collect();
390 let project_set = Self::new(input, select_list);
391 let (map, _) = out_col_change.into_parts();
395 let out_col_change = ColIndexMapping::new(map, project_set.schema().len());
396 Ok((project_set.into(), out_col_change))
397 }
398
399 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
402 if self.select_list().iter().any(|item| item.has_now()) {
403 return Err(ErrorCode::NotSupported(
405 "General `now()` function in streaming queries".to_owned(),
406 "Streaming `now()` is currently only supported in GenerateSeries and TemporalFilter patterns.".to_owned(),
407 )
408 .into());
409 }
410
411 let new_input = self.input().to_stream(ctx)?;
412 let mut new_logical = self.core.clone();
413 new_logical.input = new_input;
414 Ok(StreamProjectSet::new(new_logical).into())
415 }
416}
417
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds an `Int32` column reference to input column `index`.
    fn int32_col(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    #[tokio::test]
    async fn fd_derivation_project_set() {
        // Input [v1, v2, v3] carries the dependency {v2} --> {v3}.
        // Select list [v3, v2, generate_series(v1, v2, v3)] gives 4 output columns
        // (hidden column 0, then v3, v2, generate_series), so the dependency is expected
        // to surface as {2} --> {1}.
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[1], &[2]);

        let generate_series = TableFunction::new(
            crate::expr::TableFunctionType::GenerateSeries,
            (0..3).map(int32_col).collect(),
        )
        .unwrap();
        let project_set = LogicalProjectSet::new(
            values.into(),
            vec![
                int32_col(2),
                int32_col(1),
                ExprImpl::TableFunction(Box::new(generate_series)),
            ],
        );

        let actual: HashSet<FunctionalDependency> = project_set
            .base
            .functional_dependency()
            .as_dependencies()
            .iter()
            .cloned()
            .collect();
        let expected: HashSet<FunctionalDependency> =
            HashSet::from([FunctionalDependency::with_indices(4, &[2], &[1])]);
        assert_eq!(actual, expected);
    }
}