risingwave_frontend/optimizer/plan_node/
logical_project.rs1use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::utils::{Distill, childless_record};
22use super::{
23 BatchPlanRef, BatchProject, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
24 LogicalPlanRef, PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamMaterializedExprs,
25 StreamPlanRef, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
26};
27use crate::error::Result;
28use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::GenericPlanRef;
31use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata;
32use crate::optimizer::plan_node::{
33 ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::{Distribution, Order, RequiredDist, StreamKind};
36use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct LogicalProject {
41 pub base: PlanBase<Logical>,
42 core: generic::Project<PlanRef>,
43}
44
45impl LogicalProject {
46 pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
47 Self::new(input, exprs).into()
48 }
49
50 pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
52 let core = generic::Project::new(exprs, input);
53 Self::with_core(core)
54 }
55
56 pub fn with_core(core: generic::Project<PlanRef>) -> Self {
57 let base = PlanBase::new_logical_with_core(&core);
58 LogicalProject { base, core }
59 }
60
61 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
62 self.core.o2i_col_mapping()
63 }
64
65 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
66 self.core.i2o_col_mapping()
67 }
68
69 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
76 Self::with_core(generic::Project::with_mapping(input, mapping))
77 }
78
79 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
81 Self::with_core(generic::Project::with_out_fields(input, out_fields))
82 }
83
84 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
86 Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
87 }
88
89 pub fn exprs(&self) -> &Vec<ExprImpl> {
90 &self.core.exprs
91 }
92
93 pub fn is_identity(&self) -> bool {
94 self.core.is_identity()
95 }
96
97 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
98 self.core.try_as_projection()
99 }
100
101 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
102 self.core.decompose()
103 }
104
105 pub fn is_all_inputref(&self) -> bool {
106 self.core.is_all_inputref()
107 }
108}
109
110impl PlanTreeNodeUnary<Logical> for LogicalProject {
111 fn input(&self) -> LogicalPlanRef {
112 self.core.input.clone()
113 }
114
115 fn clone_with_input(&self, input: LogicalPlanRef) -> Self {
116 Self::new(input, self.exprs().clone())
117 }
118
119 fn rewrite_with_input(
120 &self,
121 input: PlanRef,
122 mut input_col_change: ColIndexMapping,
123 ) -> (Self, ColIndexMapping) {
124 let exprs = self
125 .exprs()
126 .clone()
127 .into_iter()
128 .map(|expr| input_col_change.rewrite_expr(expr))
129 .collect();
130 let proj = Self::new(input, exprs);
131 let out_col_change = ColIndexMapping::identity(self.schema().len());
133 (proj, out_col_change)
134 }
135}
136
137impl_plan_tree_node_for_unary! { Logical, LogicalProject}
138
139impl Distill for LogicalProject {
140 fn distill<'a>(&self) -> XmlNode<'a> {
141 childless_record(
142 "LogicalProject",
143 self.core.fields_pretty(self.base.schema()),
144 )
145 }
146}
147
148impl ColPrunable for LogicalProject {
149 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
150 let input_col_num: usize = self.input().schema().len();
151 let input_required_cols = collect_input_refs(
152 input_col_num,
153 required_cols.iter().map(|i| &self.exprs()[*i]),
154 )
155 .ones()
156 .collect_vec();
157 let new_input = self.input().prune_col(&input_required_cols, ctx);
158 let mut mapping = ColIndexMapping::with_remaining_columns(
159 &input_required_cols,
160 self.input().schema().len(),
161 );
162 let exprs = required_cols
164 .iter()
165 .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
166 .collect();
167
168 LogicalProject::new(new_input, exprs).into()
170 }
171}
172
173impl ExprRewritable<Logical> for LogicalProject {
174 fn has_rewritable_expr(&self) -> bool {
175 true
176 }
177
178 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
179 let mut core = self.core.clone();
180 core.rewrite_exprs(r);
181 Self {
182 base: self.base.clone_with_new_plan_id(),
183 core,
184 }
185 .into()
186 }
187}
188
189impl ExprVisitable for LogicalProject {
190 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
191 self.core.visit_exprs(v);
192 }
193}
194
195impl PredicatePushdown for LogicalProject {
196 fn predicate_pushdown(
197 &self,
198 predicate: Condition,
199 ctx: &mut PredicatePushdownContext,
200 ) -> PlanRef {
201 let mut subst = Substitute {
203 mapping: self.exprs().clone(),
204 };
205
206 let impure_mask = {
207 let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
208 for (i, e) in self.exprs().iter().enumerate() {
209 impure_mask.set(i, e.is_impure())
210 }
211 impure_mask
212 };
213 let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
215 let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
216
217 gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
218 }
219}
220
221impl ToBatch for LogicalProject {
222 fn to_batch(&self) -> Result<BatchPlanRef> {
223 self.to_batch_with_order_required(&Order::any())
224 }
225
226 fn to_batch_with_order_required(&self, required_order: &Order) -> Result<BatchPlanRef> {
227 let input_order = self
228 .o2i_col_mapping()
229 .rewrite_provided_order(required_order);
230 let new_input = self.input().to_batch_with_order_required(&input_order)?;
231 let project = self.core.clone_with_input(new_input);
232 let batch_project = BatchProject::new(project);
233 required_order.enforce_if_not_satisfies(batch_project.into())
234 }
235}
236
237impl ToStream for LogicalProject {
238 fn to_stream_with_dist_required(
239 &self,
240 required_dist: &RequiredDist,
241 ctx: &mut ToStreamContext,
242 ) -> Result<StreamPlanRef> {
243 let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
244 RequiredDist::Any
245 } else {
246 let input_required = self
247 .o2i_col_mapping()
248 .rewrite_required_distribution(required_dist);
249 match input_required {
250 RequiredDist::PhysicalDist(dist) => match dist {
251 Distribution::Single => RequiredDist::Any,
252 _ => RequiredDist::PhysicalDist(dist),
253 },
254 _ => input_required,
255 }
256 };
257 let new_input = self
258 .input()
259 .to_stream_with_dist_required(&input_required, ctx)?;
260
261 let should_materialize_expr = match new_input.stream_kind() {
262 StreamKind::AppendOnly => None,
263 kind @ (StreamKind::Retract | StreamKind::Upsert) => {
264 let mut impure_field_names = BTreeMap::new();
266 let mut impure_expr_indices = HashSet::new();
267 let impure_exprs: Vec<_> = self
268 .exprs()
269 .iter()
270 .enumerate()
271 .filter_map(|(idx, expr)| {
272 if expr.is_impure() {
274 impure_expr_indices.insert(idx);
275 if let Some(name) = self.core.field_names.get(&idx) {
276 impure_field_names.insert(idx, name.clone());
277 }
278 Some(expr.clone())
279 } else {
280 None
281 }
282 })
283 .collect();
284 if impure_exprs.is_empty() {
285 None
286 } else if kind == StreamKind::Upsert
287 && new_input
288 .stream_key()
289 .into_iter()
290 .flatten()
291 .all(|stream_key_idx| !impure_expr_indices.contains(stream_key_idx))
292 {
293 None
295 } else {
296 Some((impure_field_names, impure_expr_indices, impure_exprs))
297 }
298 }
299 };
300
301 let stream_plan = if let Some((impure_field_names, impure_expr_indices, impure_exprs)) =
302 should_materialize_expr
303 {
304 {
305 let new_input = new_input.enforce_concrete_distribution();
306
307 let mat_exprs_plan: StreamPlanRef = StreamMaterializedExprs::new(
309 new_input.clone(),
310 impure_exprs,
311 impure_field_names,
312 )?
313 .into();
314
315 let input_len = new_input.schema().len();
316 let mut materialized_pos = 0;
317
318 let final_exprs = self
320 .exprs()
321 .iter()
322 .enumerate()
323 .map(|(idx, expr)| {
324 if impure_expr_indices.contains(&idx) {
325 let output_idx = input_len + materialized_pos;
326 materialized_pos += 1;
327 InputRef::new(output_idx, expr.return_type()).into()
328 } else {
329 expr.clone()
330 }
331 })
332 .collect();
333
334 let core = generic::Project::new(final_exprs, mat_exprs_plan);
335 StreamProject::new(core).into()
336 }
337 } else {
338 let core = generic::Project::new(self.exprs().clone(), new_input);
340 StreamProject::new(core).into()
341 };
342
343 required_dist.streaming_enforce_if_not_satisfies(stream_plan)
344 }
345
346 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
347 self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
348 }
349
350 fn logical_rewrite_for_stream(
351 &self,
352 ctx: &mut RewriteStreamContext,
353 ) -> Result<(PlanRef, ColIndexMapping)> {
354 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
355 let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
356
357 let input_pk = input.expect_stream_key();
359 let i2o = proj.i2o_col_mapping();
360 let col_need_to_add = input_pk
361 .iter()
362 .cloned()
363 .filter(|i| i2o.try_map(*i).is_none());
364 let input_schema = input.schema();
365 let exprs =
366 proj.exprs()
367 .iter()
368 .cloned()
369 .chain(col_need_to_add.map(|idx| {
370 InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
371 }))
372 .collect();
373 let proj = Self::new(input, exprs);
374 let (map, _) = out_col_change.into_parts();
378 let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
379 Ok((proj.into(), out_col_change))
380 }
381
382 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
383 if columns.is_empty() {
384 return None;
385 }
386
387 let input_columns = columns
388 .iter()
389 .map(|&col| {
390 if let Some(input_col) = self.o2i_col_mapping().try_map(col) {
392 return Some(input_col);
393 }
394
395 let expr = &self.exprs()[col];
397 if expr.is_pure() {
398 let input_refs = expr.collect_input_refs(self.input().schema().len());
399 if input_refs.count_ones(..) == 1 {
401 return input_refs.ones().next();
402 }
403 }
404
405 None
406 })
407 .collect::<Option<Vec<usize>>>()?;
408
409 let new_input = self.input().try_better_locality(&input_columns)?;
410 Some(self.clone_with_input(new_input).into())
411 }
412}
413
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning `Project(null, $2, $0 < null)` over `Values(v1, v2, v3)` down to output
    /// columns `[1, 2]` should give `Project($1, $0 < null)` over `Values(v1, v3)`.
    #[tokio::test]
    async fn test_prune_project() {
        let int = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(int.clone(), "v1"),
            Field::with_name(int.clone(), "v2"),
            Field::with_name(int.clone(), "v3"),
        ];
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        let null_literal = ExprImpl::Literal(Box::new(Literal::new(None, int.clone())));
        let v3_ref: ExprImpl = InputRef::new(2, int.clone()).into();
        let v1_less_than_null = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::LessThan,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(0, int.clone()))),
                    ExprImpl::Literal(Box::new(Literal::new(None, int))),
                ],
            )
            .unwrap(),
        ));
        let project: PlanRef =
            LogicalProject::new(values.into(), vec![null_literal, v3_ref, v1_less_than_null])
                .into();

        let required_cols = vec![1, 2];
        let mut pruning_ctx = ColumnPruningContext::new(project.clone());
        let plan = project.prune_col(&required_cols, &mut pruning_ctx);

        let pruned = plan.as_logical_project().unwrap();
        assert_eq!(pruned.exprs().len(), 2);
        assert_eq_input_ref!(&pruned.exprs()[0], 1);

        let second = pruned.exprs()[1].clone();
        let call = second.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let child = pruned.input();
        let child_values = child.as_logical_values().unwrap();
        assert_eq!(child_values.schema().fields().len(), 2);
        assert_eq!(child_values.schema().fields()[0], fields[0]);
        assert_eq!(child_values.schema().fields()[1], fields[2]);
    }
}