risingwave_frontend/optimizer/plan_node/
logical_project.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::generic::GenericPlanNode;
22use super::utils::{Distill, childless_record};
23use super::{
24    BatchProject, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
25    PredicatePushdown, StreamMaterializedExprs, StreamProject, ToBatch, ToStream,
26    gen_filter_and_pushdown, generic,
27};
28use crate::error::Result;
29use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::generic::GenericPlanRef;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::{Distribution, Order, RequiredDist};
36use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
37
38/// `LogicalProject` computes a set of expressions from its input relation.
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct LogicalProject {
41    pub base: PlanBase<Logical>,
42    core: generic::Project<PlanRef>,
43}
44
45impl LogicalProject {
46    pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
47        Self::new(input, exprs).into()
48    }
49
50    pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
51        let core = generic::Project::new(exprs, input);
52        Self::with_core(core)
53    }
54
55    pub fn with_core(core: generic::Project<PlanRef>) -> Self {
56        let base = PlanBase::new_logical_with_core(&core);
57        LogicalProject { base, core }
58    }
59
60    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
61        self.core.o2i_col_mapping()
62    }
63
64    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
65        self.core.i2o_col_mapping()
66    }
67
68    /// Creates a `LogicalProject` which select some columns from the input.
69    ///
70    /// `mapping` should maps from `(0..input_fields.len())` to a consecutive range starting from 0.
71    ///
72    /// This is useful in column pruning when we want to add a project to ensure the output schema
73    /// is correct.
74    pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
75        Self::with_core(generic::Project::with_mapping(input, mapping))
76    }
77
78    /// Creates a `LogicalProject` which select some columns from the input.
79    pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
80        Self::with_core(generic::Project::with_out_fields(input, out_fields))
81    }
82
83    /// Creates a `LogicalProject` which select some columns from the input.
84    pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
85        Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
86    }
87
88    pub fn exprs(&self) -> &Vec<ExprImpl> {
89        &self.core.exprs
90    }
91
92    pub fn is_identity(&self) -> bool {
93        self.core.is_identity()
94    }
95
96    pub fn try_as_projection(&self) -> Option<Vec<usize>> {
97        self.core.try_as_projection()
98    }
99
100    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
101        self.core.decompose()
102    }
103
104    pub fn is_all_inputref(&self) -> bool {
105        self.core.is_all_inputref()
106    }
107}
108
109impl PlanTreeNodeUnary for LogicalProject {
110    fn input(&self) -> PlanRef {
111        self.core.input.clone()
112    }
113
114    fn clone_with_input(&self, input: PlanRef) -> Self {
115        Self::new(input, self.exprs().clone())
116    }
117
118    fn rewrite_with_input(
119        &self,
120        input: PlanRef,
121        mut input_col_change: ColIndexMapping,
122    ) -> (Self, ColIndexMapping) {
123        let exprs = self
124            .exprs()
125            .clone()
126            .into_iter()
127            .map(|expr| input_col_change.rewrite_expr(expr))
128            .collect();
129        let proj = Self::new(input, exprs);
130        // change the input columns index will not change the output column index
131        let out_col_change = ColIndexMapping::identity(self.schema().len());
132        (proj, out_col_change)
133    }
134}
135
136impl_plan_tree_node_for_unary! {LogicalProject}
137
138impl Distill for LogicalProject {
139    fn distill<'a>(&self) -> XmlNode<'a> {
140        childless_record(
141            "LogicalProject",
142            self.core.fields_pretty(self.base.schema()),
143        )
144    }
145}
146
147impl ColPrunable for LogicalProject {
148    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
149        let input_col_num: usize = self.input().schema().len();
150        let input_required_cols = collect_input_refs(
151            input_col_num,
152            required_cols.iter().map(|i| &self.exprs()[*i]),
153        )
154        .ones()
155        .collect_vec();
156        let new_input = self.input().prune_col(&input_required_cols, ctx);
157        let mut mapping = ColIndexMapping::with_remaining_columns(
158            &input_required_cols,
159            self.input().schema().len(),
160        );
161        // Rewrite each InputRef with new index.
162        let exprs = required_cols
163            .iter()
164            .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
165            .collect();
166
167        // Reconstruct the LogicalProject.
168        LogicalProject::new(new_input, exprs).into()
169    }
170}
171
172impl ExprRewritable for LogicalProject {
173    fn has_rewritable_expr(&self) -> bool {
174        true
175    }
176
177    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
178        let mut core = self.core.clone();
179        core.rewrite_exprs(r);
180        Self {
181            base: self.base.clone_with_new_plan_id(),
182            core,
183        }
184        .into()
185    }
186}
187
188impl ExprVisitable for LogicalProject {
189    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
190        self.core.visit_exprs(v);
191    }
192}
193
194impl PredicatePushdown for LogicalProject {
195    fn predicate_pushdown(
196        &self,
197        predicate: Condition,
198        ctx: &mut PredicatePushdownContext,
199    ) -> PlanRef {
200        // convert the predicate to one that references the child of the project
201        let mut subst = Substitute {
202            mapping: self.exprs().clone(),
203        };
204
205        let impure_mask = {
206            let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
207            for (i, e) in self.exprs().iter().enumerate() {
208                impure_mask.set(i, e.is_impure())
209            }
210            impure_mask
211        };
212        // (with impure input, with pure input)
213        let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
214        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
215
216        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
217    }
218}
219
220impl ToBatch for LogicalProject {
221    fn to_batch(&self) -> Result<PlanRef> {
222        self.to_batch_with_order_required(&Order::any())
223    }
224
225    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
226        let input_order = self
227            .o2i_col_mapping()
228            .rewrite_provided_order(required_order);
229        let new_input = self.input().to_batch_with_order_required(&input_order)?;
230        let mut new_logical = self.core.clone();
231        new_logical.input = new_input;
232        let batch_project = BatchProject::new(new_logical);
233        required_order.enforce_if_not_satisfies(batch_project.into())
234    }
235}
236
237impl ToStream for LogicalProject {
238    fn to_stream_with_dist_required(
239        &self,
240        required_dist: &RequiredDist,
241        ctx: &mut ToStreamContext,
242    ) -> Result<PlanRef> {
243        let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
244            RequiredDist::Any
245        } else {
246            let input_required = self
247                .o2i_col_mapping()
248                .rewrite_required_distribution(required_dist);
249            match input_required {
250                RequiredDist::PhysicalDist(dist) => match dist {
251                    Distribution::Single => RequiredDist::Any,
252                    _ => RequiredDist::PhysicalDist(dist),
253                },
254                _ => input_required,
255            }
256        };
257        let new_input = self
258            .input()
259            .to_stream_with_dist_required(&input_required, ctx)?;
260
261        let enable_materialized_exprs = self
262            .core
263            .ctx()
264            .session_ctx()
265            .config()
266            .streaming_enable_materialized_expressions();
267
268        let stream_plan = if enable_materialized_exprs {
269            // Extract UDFs to `MaterializedExprs` operator
270            let mut udf_field_names = BTreeMap::new();
271            let mut udf_expr_indices = HashSet::new();
272            let udf_exprs: Vec<_> = self
273                .exprs()
274                .iter()
275                .enumerate()
276                .filter_map(|(idx, expr)| {
277                    if expr.has_user_defined_function() {
278                        udf_expr_indices.insert(idx);
279                        if let Some(name) = self.core.field_names.get(&idx) {
280                            udf_field_names.insert(idx, name.clone());
281                        }
282                        Some(expr.clone())
283                    } else {
284                        None
285                    }
286                })
287                .collect();
288
289            if !udf_exprs.is_empty() {
290                // Create `MaterializedExprs` for UDFs
291                let mat_exprs_plan: PlanRef =
292                    StreamMaterializedExprs::new(new_input.clone(), udf_exprs, udf_field_names)
293                        .into();
294
295                let input_len = new_input.schema().len();
296                let mut udf_pos = 0;
297
298                // Create final expressions list with UDFs replaced by `InputRef`s
299                let final_exprs = self
300                    .exprs()
301                    .iter()
302                    .enumerate()
303                    .map(|(idx, expr)| {
304                        if udf_expr_indices.contains(&idx) {
305                            let output_idx = input_len + udf_pos;
306                            udf_pos += 1;
307                            InputRef::new(output_idx, expr.return_type()).into()
308                        } else {
309                            expr.clone()
310                        }
311                    })
312                    .collect();
313
314                let core = generic::Project::new(final_exprs, mat_exprs_plan);
315                StreamProject::new(core).into()
316            } else {
317                // No UDFs, create a regular `StreamProject`
318                let core = generic::Project::new(self.exprs().clone(), new_input);
319                StreamProject::new(core).into()
320            }
321        } else {
322            // Materialized expressions feature is not enabled, create a regular `StreamProject`
323            let core = generic::Project::new(self.exprs().clone(), new_input);
324            StreamProject::new(core).into()
325        };
326
327        required_dist.enforce_if_not_satisfies(stream_plan, &Order::any())
328    }
329
330    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
331        self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
332    }
333
334    fn logical_rewrite_for_stream(
335        &self,
336        ctx: &mut RewriteStreamContext,
337    ) -> Result<(PlanRef, ColIndexMapping)> {
338        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
339        let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
340
341        // Add missing columns of `input_pk` into the select list.
342        let input_pk = input.expect_stream_key();
343        let i2o = proj.i2o_col_mapping();
344        let col_need_to_add = input_pk
345            .iter()
346            .cloned()
347            .filter(|i| i2o.try_map(*i).is_none());
348        let input_schema = input.schema();
349        let exprs =
350            proj.exprs()
351                .iter()
352                .cloned()
353                .chain(col_need_to_add.map(|idx| {
354                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
355                }))
356                .collect();
357        let proj = Self::new(input, exprs);
358        // The added columns is at the end, so it will not change existing column indices.
359        // But the target size of `out_col_change` should be the same as the length of the new
360        // schema.
361        let (map, _) = out_col_change.into_parts();
362        let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
363        Ok((proj.into(), out_col_change))
364    }
365}
366
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning
    /// ```text
    /// Project(1, input_ref(2), input_ref(0)<5)
    ///   TableScan(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` will result in
    /// ```text
    /// Project(input_ref(1), input_ref(0)<5)
    ///   TableScan(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let int32 = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Input relation with three `Int32` columns.
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(int32.clone(), name))
            .collect();
        let schema = Schema {
            fields: fields.clone(),
        };
        let values = LogicalValues::new(vec![], schema, ctx);

        // Projection: a literal, a plain column reference, and a comparison.
        let literal = || ExprImpl::Literal(Box::new(Literal::new(None, int32.clone())));
        let less_than = FunctionCall::new(
            Type::LessThan,
            vec![InputRef::new(0, int32.clone()).into(), literal()],
        )
        .unwrap();
        let select_list = vec![
            literal(),
            ExprImpl::InputRef(Box::new(InputRef::new(2, int32.clone()))),
            ExprImpl::FunctionCall(Box::new(less_than)),
        ];
        let project: PlanRef = LogicalProject::new(values.into(), select_list).into();

        // Perform the prune
        let required_cols = [1, 2];
        let mut prune_ctx = ColumnPruningContext::new(project.clone());
        let pruned_plan = project.prune_col(&required_cols, &mut prune_ctx);

        // Check the result
        let pruned_project = pruned_plan.as_logical_project().unwrap();
        let pruned_exprs = pruned_project.exprs();
        assert_eq!(pruned_exprs.len(), 2);
        assert_eq_input_ref!(&pruned_exprs[0], 1);

        let comparison = pruned_exprs[1].clone();
        let call = comparison.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let pruned_input = pruned_project.input();
        let pruned_values = pruned_input.as_logical_values().unwrap();
        assert_eq!(pruned_values.schema().fields().len(), 2);
        // Pruned column `i` must be the original column `orig`.
        for (i, orig) in [(0usize, 0usize), (1, 2)] {
            assert_eq!(pruned_values.schema().fields()[i], fields[orig]);
        }
    }
}