risingwave_frontend/optimizer/plan_node/
logical_project.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::generic::GenericPlanNode;
22use super::utils::{Distill, childless_record};
23use super::{
24    BatchPlanRef, BatchProject, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
25    LogicalPlanRef, PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamMaterializedExprs,
26    StreamPlanRef, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
27};
28use crate::error::Result;
29use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::generic::GenericPlanRef;
32use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata;
33use crate::optimizer::plan_node::{
34    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
35};
36use crate::optimizer::property::{Distribution, Order, RequiredDist, StreamKind};
37use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
38
39/// `LogicalProject` computes a set of expressions from its input relation.
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct LogicalProject {
42    pub base: PlanBase<Logical>,
43    core: generic::Project<PlanRef>,
44}
45
46impl LogicalProject {
47    pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
48        Self::new(input, exprs).into()
49    }
50
51    // TODO(kwannoel): We only need create/new don't keep both.
52    pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
53        let core = generic::Project::new(exprs, input);
54        Self::with_core(core)
55    }
56
57    pub fn with_core(core: generic::Project<PlanRef>) -> Self {
58        let base = PlanBase::new_logical_with_core(&core);
59        LogicalProject { base, core }
60    }
61
62    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
63        self.core.o2i_col_mapping()
64    }
65
66    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
67        self.core.i2o_col_mapping()
68    }
69
70    /// Creates a `LogicalProject` which select some columns from the input.
71    ///
72    /// `mapping` should maps from `(0..input_fields.len())` to a consecutive range starting from 0.
73    ///
74    /// This is useful in column pruning when we want to add a project to ensure the output schema
75    /// is correct.
76    pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
77        Self::with_core(generic::Project::with_mapping(input, mapping))
78    }
79
80    /// Creates a `LogicalProject` which select some columns from the input.
81    pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
82        Self::with_core(generic::Project::with_out_fields(input, out_fields))
83    }
84
85    /// Creates a `LogicalProject` which select some columns from the input.
86    pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
87        Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
88    }
89
90    pub fn exprs(&self) -> &Vec<ExprImpl> {
91        &self.core.exprs
92    }
93
94    pub fn is_identity(&self) -> bool {
95        self.core.is_identity()
96    }
97
98    pub fn try_as_projection(&self) -> Option<Vec<usize>> {
99        self.core.try_as_projection()
100    }
101
102    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
103        self.core.decompose()
104    }
105
106    pub fn is_all_inputref(&self) -> bool {
107        self.core.is_all_inputref()
108    }
109}
110
111impl PlanTreeNodeUnary<Logical> for LogicalProject {
112    fn input(&self) -> LogicalPlanRef {
113        self.core.input.clone()
114    }
115
116    fn clone_with_input(&self, input: LogicalPlanRef) -> Self {
117        Self::new(input, self.exprs().clone())
118    }
119
120    fn rewrite_with_input(
121        &self,
122        input: PlanRef,
123        mut input_col_change: ColIndexMapping,
124    ) -> (Self, ColIndexMapping) {
125        let exprs = self
126            .exprs()
127            .clone()
128            .into_iter()
129            .map(|expr| input_col_change.rewrite_expr(expr))
130            .collect();
131        let proj = Self::new(input, exprs);
132        // change the input columns index will not change the output column index
133        let out_col_change = ColIndexMapping::identity(self.schema().len());
134        (proj, out_col_change)
135    }
136}
137
138impl_plan_tree_node_for_unary! { Logical, LogicalProject}
139
140impl Distill for LogicalProject {
141    fn distill<'a>(&self) -> XmlNode<'a> {
142        childless_record(
143            "LogicalProject",
144            self.core.fields_pretty(self.base.schema()),
145        )
146    }
147}
148
149impl ColPrunable for LogicalProject {
150    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
151        let input_col_num: usize = self.input().schema().len();
152        let input_required_cols = collect_input_refs(
153            input_col_num,
154            required_cols.iter().map(|i| &self.exprs()[*i]),
155        )
156        .ones()
157        .collect_vec();
158        let new_input = self.input().prune_col(&input_required_cols, ctx);
159        let mut mapping = ColIndexMapping::with_remaining_columns(
160            &input_required_cols,
161            self.input().schema().len(),
162        );
163        // Rewrite each InputRef with new index.
164        let exprs = required_cols
165            .iter()
166            .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
167            .collect();
168
169        // Reconstruct the LogicalProject.
170        LogicalProject::new(new_input, exprs).into()
171    }
172}
173
174impl ExprRewritable<Logical> for LogicalProject {
175    fn has_rewritable_expr(&self) -> bool {
176        true
177    }
178
179    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
180        let mut core = self.core.clone();
181        core.rewrite_exprs(r);
182        Self {
183            base: self.base.clone_with_new_plan_id(),
184            core,
185        }
186        .into()
187    }
188}
189
190impl ExprVisitable for LogicalProject {
191    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
192        self.core.visit_exprs(v);
193    }
194}
195
196impl PredicatePushdown for LogicalProject {
197    fn predicate_pushdown(
198        &self,
199        predicate: Condition,
200        ctx: &mut PredicatePushdownContext,
201    ) -> PlanRef {
202        // convert the predicate to one that references the child of the project
203        let mut subst = Substitute {
204            mapping: self.exprs().clone(),
205        };
206
207        let impure_mask = {
208            let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
209            for (i, e) in self.exprs().iter().enumerate() {
210                impure_mask.set(i, e.is_impure())
211            }
212            impure_mask
213        };
214        // (with impure input, with pure input)
215        let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
216        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
217
218        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
219    }
220}
221
222impl ToBatch for LogicalProject {
223    fn to_batch(&self) -> Result<BatchPlanRef> {
224        self.to_batch_with_order_required(&Order::any())
225    }
226
227    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<BatchPlanRef> {
228        let input_order = self
229            .o2i_col_mapping()
230            .rewrite_provided_order(required_order);
231        let new_input = self.input().to_batch_with_order_required(&input_order)?;
232        let project = self.core.clone_with_input(new_input);
233        let batch_project = BatchProject::new(project);
234        required_order.enforce_if_not_satisfies(batch_project.into())
235    }
236}
237
238impl ToStream for LogicalProject {
239    fn to_stream_with_dist_required(
240        &self,
241        required_dist: &RequiredDist,
242        ctx: &mut ToStreamContext,
243    ) -> Result<StreamPlanRef> {
244        let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
245            RequiredDist::Any
246        } else {
247            let input_required = self
248                .o2i_col_mapping()
249                .rewrite_required_distribution(required_dist);
250            match input_required {
251                RequiredDist::PhysicalDist(dist) => match dist {
252                    Distribution::Single => RequiredDist::Any,
253                    _ => RequiredDist::PhysicalDist(dist),
254                },
255                _ => input_required,
256            }
257        };
258        let new_input = self
259            .input()
260            .to_stream_with_dist_required(&input_required, ctx)?;
261
262        let enable_materialized_exprs = self
263            .core
264            .ctx()
265            .session_ctx()
266            .config()
267            .streaming_enable_materialized_expressions();
268
269        let should_materialize_expr = match new_input.stream_kind() {
270            StreamKind::AppendOnly => None,
271            kind @ (StreamKind::Retract | StreamKind::Upsert) => {
272                if enable_materialized_exprs {
273                    // Extract impure functions to `MaterializedExprs` operator
274                    let mut impure_field_names = BTreeMap::new();
275                    let mut impure_expr_indices = HashSet::new();
276                    let impure_exprs: Vec<_> = self
277                        .exprs()
278                        .iter()
279                        .enumerate()
280                        .filter_map(|(idx, expr)| {
281                            // Extract impure expressions
282                            if expr.is_impure() {
283                                impure_expr_indices.insert(idx);
284                                if let Some(name) = self.core.field_names.get(&idx) {
285                                    impure_field_names.insert(idx, name.clone());
286                                }
287                                Some(expr.clone())
288                            } else {
289                                None
290                            }
291                        })
292                        .collect();
293                    if impure_exprs.is_empty() {
294                        None
295                    } else if kind == StreamKind::Upsert
296                        && new_input
297                            .stream_key()
298                            .into_iter()
299                            .flatten()
300                            .all(|stream_key_idx| !impure_expr_indices.contains(stream_key_idx))
301                    {
302                        // We're operating on non-stream-key columns of upsert stream, no need to materialize.
303                        None
304                    } else {
305                        Some((impure_field_names, impure_expr_indices, impure_exprs))
306                    }
307                } else {
308                    None
309                }
310            }
311        };
312
313        let stream_plan = if let Some((impure_field_names, impure_expr_indices, impure_exprs)) =
314            should_materialize_expr
315        {
316            {
317                let new_input = new_input.enforce_concrete_distribution();
318
319                // Create `MaterializedExprs` for impure expressions
320                let mat_exprs_plan: StreamPlanRef = StreamMaterializedExprs::new(
321                    new_input.clone(),
322                    impure_exprs,
323                    impure_field_names,
324                )?
325                .into();
326
327                let input_len = new_input.schema().len();
328                let mut materialized_pos = 0;
329
330                // Create final expressions list with impure expressions replaced by `InputRef`s
331                let final_exprs = self
332                    .exprs()
333                    .iter()
334                    .enumerate()
335                    .map(|(idx, expr)| {
336                        if impure_expr_indices.contains(&idx) {
337                            let output_idx = input_len + materialized_pos;
338                            materialized_pos += 1;
339                            InputRef::new(output_idx, expr.return_type()).into()
340                        } else {
341                            expr.clone()
342                        }
343                    })
344                    .collect();
345
346                let core = generic::Project::new(final_exprs, mat_exprs_plan);
347                StreamProject::new(core).into()
348            }
349        } else {
350            // No expressions to materialize or the feature is not enabled, create a regular `StreamProject`
351            let core = generic::Project::new(self.exprs().clone(), new_input);
352            StreamProject::new(core).into()
353        };
354
355        required_dist.streaming_enforce_if_not_satisfies(stream_plan)
356    }
357
358    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
359        self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
360    }
361
362    fn logical_rewrite_for_stream(
363        &self,
364        ctx: &mut RewriteStreamContext,
365    ) -> Result<(PlanRef, ColIndexMapping)> {
366        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
367        let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
368
369        // Add missing columns of `input_pk` into the select list.
370        let input_pk = input.expect_stream_key();
371        let i2o = proj.i2o_col_mapping();
372        let col_need_to_add = input_pk
373            .iter()
374            .cloned()
375            .filter(|i| i2o.try_map(*i).is_none());
376        let input_schema = input.schema();
377        let exprs =
378            proj.exprs()
379                .iter()
380                .cloned()
381                .chain(col_need_to_add.map(|idx| {
382                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
383                }))
384                .collect();
385        let proj = Self::new(input, exprs);
386        // The added columns is at the end, so it will not change existing column indices.
387        // But the target size of `out_col_change` should be the same as the length of the new
388        // schema.
389        let (map, _) = out_col_change.into_parts();
390        let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
391        Ok((proj.into(), out_col_change))
392    }
393
394    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
395        if columns.is_empty() {
396            return None;
397        }
398
399        let input_columns = columns
400            .iter()
401            .map(|&col| {
402                // First try the original o2i mapping for direct InputRef
403                if let Some(input_col) = self.o2i_col_mapping().try_map(col) {
404                    return Some(input_col);
405                }
406
407                // If not a direct InputRef, check if it's a pure function with single InputRef
408                let expr = &self.exprs()[col];
409                if expr.is_pure() {
410                    let input_refs = expr.collect_input_refs(self.input().schema().len());
411                    // Check if expression references exactly one input column
412                    if input_refs.count_ones(..) == 1 {
413                        return input_refs.ones().next();
414                    }
415                }
416
417                None
418            })
419            .collect::<Option<Vec<usize>>>()?;
420
421        let new_input = self.input().try_better_locality(&input_columns)?;
422        Some(self.clone_with_input(new_input).into())
423    }
424}
425
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning
    /// ```text
    /// Project(1, input_ref(2), input_ref(0)<5)
    ///   TableScan(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` will result in
    /// ```text
    /// Project(input_ref(1), input_ref(0)<5)
    ///   TableScan(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Input relation with three columns of the same type.
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .iter()
            .map(|&name| Field::with_name(ty.clone(), name))
            .collect();
        let schema = Schema {
            fields: fields.clone(),
        };
        let values = LogicalValues::new(vec![], schema, ctx);

        // The three projected expressions: a literal, `input_ref(2)` and `input_ref(0) < literal`.
        let literal_expr = ExprImpl::Literal(Box::new(Literal::new(None, ty.clone())));
        let input_ref_expr: ExprImpl = InputRef::new(2, ty.clone()).into();
        let less_than = FunctionCall::new(
            Type::LessThan,
            vec![
                ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                ExprImpl::Literal(Box::new(Literal::new(None, ty))),
            ],
        )
        .unwrap();
        let less_than_expr = ExprImpl::FunctionCall(Box::new(less_than));

        let project: PlanRef = LogicalProject::new(
            values.into(),
            vec![literal_expr, input_ref_expr, less_than_expr],
        )
        .into();

        // Perform the prune.
        let required_cols: [usize; 2] = [1, 2];
        let mut prune_ctx = ColumnPruningContext::new(project.clone());
        let plan = project.prune_col(&required_cols, &mut prune_ctx);

        // The pruned project keeps the two required expressions with remapped input refs.
        let pruned = plan.as_logical_project().unwrap();
        assert_eq!(pruned.exprs().len(), 2);
        assert_eq_input_ref!(&pruned.exprs()[0], 1);

        let comparison = pruned.exprs()[1].clone();
        let call = comparison.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        // The input only keeps `v1` and `v3`.
        let pruned_input = pruned.input();
        let pruned_values = pruned_input.as_logical_values().unwrap();
        assert_eq!(pruned_values.schema().fields().len(), 2);
        assert_eq!(pruned_values.schema().fields()[0], fields[0]);
        assert_eq!(pruned_values.schema().fields()[1], fields[2]);
    }
}