Skip to main content

risingwave_frontend/optimizer/plan_node/
logical_project.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::utils::{Distill, childless_record};
22use super::{
23    BatchPlanRef, BatchProject, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
24    LogicalPlanRef, PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamMaterializedExprs,
25    StreamPlanRef, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
26};
27use crate::error::Result;
28use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::GenericPlanRef;
31use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::{Distribution, Order, RequiredDist, StreamKind};
36use crate::session::current;
37use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
38
39/// `LogicalProject` computes a set of expressions from its input relation.
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct LogicalProject {
42    pub base: PlanBase<Logical>,
43    core: generic::Project<PlanRef>,
44}
45
46impl LogicalProject {
47    pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
48        Self::new(input, exprs).into()
49    }
50
51    // TODO(kwannoel): We only need create/new don't keep both.
52    pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
53        let core = generic::Project::new(exprs, input);
54        Self::with_core(core)
55    }
56
57    pub fn with_core(core: generic::Project<PlanRef>) -> Self {
58        let base = PlanBase::new_logical_with_core(&core);
59        LogicalProject { base, core }
60    }
61
62    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
63        self.core.o2i_col_mapping()
64    }
65
66    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
67        self.core.i2o_col_mapping()
68    }
69
70    /// Creates a `LogicalProject` which select some columns from the input.
71    ///
72    /// `mapping` should maps from `(0..input_fields.len())` to a consecutive range starting from 0.
73    ///
74    /// This is useful in column pruning when we want to add a project to ensure the output schema
75    /// is correct.
76    pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
77        Self::with_core(generic::Project::with_mapping(input, mapping))
78    }
79
80    /// Creates a `LogicalProject` which select some columns from the input.
81    pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
82        Self::with_core(generic::Project::with_out_fields(input, out_fields))
83    }
84
85    /// Creates a `LogicalProject` which select some columns from the input.
86    pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
87        Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
88    }
89
90    pub fn exprs(&self) -> &Vec<ExprImpl> {
91        &self.core.exprs
92    }
93
94    pub fn is_identity(&self) -> bool {
95        self.core.is_identity()
96    }
97
98    pub fn try_as_projection(&self) -> Option<Vec<usize>> {
99        self.core.try_as_projection()
100    }
101
102    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
103        self.core.decompose()
104    }
105
106    pub fn is_all_inputref(&self) -> bool {
107        self.core.is_all_inputref()
108    }
109}
110
111impl PlanTreeNodeUnary<Logical> for LogicalProject {
112    fn input(&self) -> LogicalPlanRef {
113        self.core.input.clone()
114    }
115
116    fn clone_with_input(&self, input: LogicalPlanRef) -> Self {
117        Self::new(input, self.exprs().clone())
118    }
119
120    fn rewrite_with_input(
121        &self,
122        input: PlanRef,
123        mut input_col_change: ColIndexMapping,
124    ) -> (Self, ColIndexMapping) {
125        let exprs = self
126            .exprs()
127            .clone()
128            .into_iter()
129            .map(|expr| input_col_change.rewrite_expr(expr))
130            .collect();
131        let proj = Self::new(input, exprs);
132        // change the input columns index will not change the output column index
133        let out_col_change = ColIndexMapping::identity(self.schema().len());
134        (proj, out_col_change)
135    }
136}
137
138impl_plan_tree_node_for_unary! { Logical, LogicalProject}
139
140impl Distill for LogicalProject {
141    fn distill<'a>(&self) -> XmlNode<'a> {
142        childless_record(
143            "LogicalProject",
144            self.core.fields_pretty(self.base.schema()),
145        )
146    }
147}
148
149impl ColPrunable for LogicalProject {
150    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
151        let input_col_num: usize = self.input().schema().len();
152        let input_required_cols = collect_input_refs(
153            input_col_num,
154            required_cols.iter().map(|i| &self.exprs()[*i]),
155        )
156        .ones()
157        .collect_vec();
158        let new_input = self.input().prune_col(&input_required_cols, ctx);
159        let mut mapping = ColIndexMapping::with_remaining_columns(
160            &input_required_cols,
161            self.input().schema().len(),
162        );
163        // Rewrite each InputRef with new index.
164        let exprs = required_cols
165            .iter()
166            .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
167            .collect();
168
169        // Reconstruct the LogicalProject.
170        LogicalProject::new(new_input, exprs).into()
171    }
172}
173
174impl ExprRewritable<Logical> for LogicalProject {
175    fn has_rewritable_expr(&self) -> bool {
176        true
177    }
178
179    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
180        let mut core = self.core.clone();
181        core.rewrite_exprs(r);
182        Self {
183            base: self.base.clone_with_new_plan_id(),
184            core,
185        }
186        .into()
187    }
188}
189
190impl ExprVisitable for LogicalProject {
191    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
192        self.core.visit_exprs(v);
193    }
194}
195
196impl PredicatePushdown for LogicalProject {
197    fn predicate_pushdown(
198        &self,
199        predicate: Condition,
200        ctx: &mut PredicatePushdownContext,
201    ) -> PlanRef {
202        // convert the predicate to one that references the child of the project
203        let mut subst = Substitute {
204            mapping: self.exprs().clone(),
205        };
206
207        let impure_mask = {
208            let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
209            for (i, e) in self.exprs().iter().enumerate() {
210                impure_mask.set(i, e.is_impure())
211            }
212            impure_mask
213        };
214        // (with impure input, with pure input)
215        let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
216        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
217
218        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
219    }
220}
221
222impl ToBatch for LogicalProject {
223    fn to_batch(&self) -> Result<BatchPlanRef> {
224        self.to_batch_with_order_required(&Order::any())
225    }
226
227    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<BatchPlanRef> {
228        let input_order = self
229            .o2i_col_mapping()
230            .rewrite_provided_order(required_order);
231        let new_input = self.input().to_batch_with_order_required(&input_order)?;
232        let project = self.core.clone_with_input(new_input);
233        let batch_project = BatchProject::new(project);
234        required_order.enforce_if_not_satisfies(batch_project.into())
235    }
236}
237
238impl ToStream for LogicalProject {
239    fn to_stream_with_dist_required(
240        &self,
241        required_dist: &RequiredDist,
242        ctx: &mut ToStreamContext,
243    ) -> Result<StreamPlanRef> {
244        let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
245            RequiredDist::Any
246        } else {
247            let input_required = self
248                .o2i_col_mapping()
249                .rewrite_required_distribution(required_dist);
250            match input_required {
251                RequiredDist::PhysicalDist(dist) => match dist {
252                    Distribution::Single => RequiredDist::Any,
253                    _ => RequiredDist::PhysicalDist(dist),
254                },
255                _ => input_required,
256            }
257        };
258        let new_input = self
259            .input()
260            .to_stream_with_dist_required(&input_required, ctx)?;
261
262        let unsafe_allow_unmaterialized_impure_expr = current::config()
263            .map(|c| c.read().streaming_unsafe_allow_unmaterialized_impure_expr())
264            .unwrap_or_else(|| {
265                self.base
266                    .ctx()
267                    .session_ctx()
268                    .config()
269                    .streaming_unsafe_allow_unmaterialized_impure_expr()
270            });
271        let should_materialize_expr = match new_input.stream_kind() {
272            StreamKind::AppendOnly => None,
273            StreamKind::Retract | StreamKind::Upsert if unsafe_allow_unmaterialized_impure_expr => {
274                // This deliberately leaves impure expressions in `StreamProject` on retract/upsert
275                // streams. The behavior is unsafe and only enabled by the explicit session option.
276                None
277            }
278            kind @ (StreamKind::Retract | StreamKind::Upsert) => {
279                // Extract impure functions to `MaterializedExprs` operator
280                let mut impure_field_names = BTreeMap::new();
281                let mut impure_expr_indices = HashSet::new();
282                let impure_exprs: Vec<_> = self
283                    .exprs()
284                    .iter()
285                    .enumerate()
286                    .filter_map(|(idx, expr)| {
287                        // Extract impure expressions
288                        if expr.is_impure() {
289                            impure_expr_indices.insert(idx);
290                            if let Some(name) = self.core.field_names.get(&idx) {
291                                impure_field_names.insert(idx, name.clone());
292                            }
293                            Some(expr.clone())
294                        } else {
295                            None
296                        }
297                    })
298                    .collect();
299                if impure_exprs.is_empty() {
300                    None
301                } else if kind == StreamKind::Upsert
302                    && new_input
303                        .stream_key()
304                        .into_iter()
305                        .flatten()
306                        .all(|stream_key_idx| !impure_expr_indices.contains(stream_key_idx))
307                {
308                    // We're operating on non-stream-key columns of upsert stream, no need to materialize.
309                    None
310                } else {
311                    Some((impure_field_names, impure_expr_indices, impure_exprs))
312                }
313            }
314        };
315
316        let stream_plan = if let Some((impure_field_names, impure_expr_indices, impure_exprs)) =
317            should_materialize_expr
318        {
319            {
320                let new_input = new_input.enforce_concrete_distribution();
321
322                // Create `MaterializedExprs` for impure expressions
323                let mat_exprs_plan: StreamPlanRef = StreamMaterializedExprs::new(
324                    new_input.clone(),
325                    impure_exprs,
326                    impure_field_names,
327                )?
328                .into();
329
330                let input_len = new_input.schema().len();
331                let mut materialized_pos = 0;
332
333                // Create final expressions list with impure expressions replaced by `InputRef`s
334                let final_exprs = self
335                    .exprs()
336                    .iter()
337                    .enumerate()
338                    .map(|(idx, expr)| {
339                        if impure_expr_indices.contains(&idx) {
340                            let output_idx = input_len + materialized_pos;
341                            materialized_pos += 1;
342                            InputRef::new(output_idx, expr.return_type()).into()
343                        } else {
344                            expr.clone()
345                        }
346                    })
347                    .collect();
348
349                let core = generic::Project::new(final_exprs, mat_exprs_plan);
350                StreamProject::new(core).into()
351            }
352        } else {
353            // No expressions to materialize or the feature is not enabled, create a regular `StreamProject`
354            let core = generic::Project::new(self.exprs().clone(), new_input);
355            StreamProject::new(core).into()
356        };
357
358        required_dist.streaming_enforce_if_not_satisfies(stream_plan)
359    }
360
361    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
362        self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
363    }
364
365    fn logical_rewrite_for_stream(
366        &self,
367        ctx: &mut RewriteStreamContext,
368    ) -> Result<(PlanRef, ColIndexMapping)> {
369        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
370        let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
371
372        // Add missing columns of `input_pk` into the select list.
373        let input_pk = input.expect_stream_key();
374        let i2o = proj.i2o_col_mapping();
375        let col_need_to_add = input_pk
376            .iter()
377            .cloned()
378            .filter(|i| i2o.try_map(*i).is_none());
379        let input_schema = input.schema();
380        let exprs =
381            proj.exprs()
382                .iter()
383                .cloned()
384                .chain(col_need_to_add.map(|idx| {
385                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
386                }))
387                .collect();
388        let proj = Self::new(input, exprs);
389        // The added columns is at the end, so it will not change existing column indices.
390        // But the target size of `out_col_change` should be the same as the length of the new
391        // schema.
392        let (map, _) = out_col_change.into_parts();
393        let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
394        Ok((proj.into(), out_col_change))
395    }
396
397    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
398        if columns.is_empty() {
399            return None;
400        }
401
402        let input_columns = columns
403            .iter()
404            .map(|&col| {
405                // First try the original o2i mapping for direct InputRef
406                if let Some(input_col) = self.o2i_col_mapping().try_map(col) {
407                    return Some(input_col);
408                }
409
410                // If not a direct InputRef, check if it's a pure function with single InputRef
411                let expr = &self.exprs()[col];
412                if expr.is_pure() {
413                    let input_refs = expr.collect_input_refs(self.input().schema().len());
414                    // Check if expression references exactly one input column
415                    if input_refs.count_ones(..) == 1 {
416                        return input_refs.ones().next();
417                    }
418                }
419
420                None
421            })
422            .collect::<Option<Vec<usize>>>()?;
423
424        let new_input = self.input().try_better_locality(&input_columns)?;
425        Some(self.clone_with_input(new_input).into())
426    }
427}
428
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning
    /// ```text
    /// Project(literal, input_ref(2), input_ref(0) < literal)
    ///   Values(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` will result in
    /// ```text
    /// Project(input_ref(1), input_ref(0) < literal)
    ///   Values(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock();
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(ty.clone(), name))
            .collect();
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        // Both literals in the fixture are typed NULLs.
        let null_literal = || ExprImpl::Literal(Box::new(Literal::new(None, ty.clone())));
        let less_than = FunctionCall::new(
            Type::LessThan,
            vec![
                ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                null_literal(),
            ],
        )
        .unwrap();
        let project: PlanRef = LogicalProject::new(
            values.into(),
            vec![
                null_literal(),
                InputRef::new(2, ty.clone()).into(),
                ExprImpl::FunctionCall(Box::new(less_than)),
            ],
        )
        .into();

        // Perform the prune
        let required_cols = vec![1, 2];
        let mut prune_ctx = ColumnPruningContext::new(project.clone());
        let pruned = project.prune_col(&required_cols, &mut prune_ctx);

        // Check the result
        let pruned_project = pruned.as_logical_project().unwrap();
        assert_eq!(pruned_project.exprs().len(), 2);
        assert_eq_input_ref!(&pruned_project.exprs()[0], 1);

        let comparison = pruned_project.exprs()[1].clone();
        let call = comparison.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let pruned_input = pruned_project.input();
        let pruned_values = pruned_input.as_logical_values().unwrap();
        assert_eq!(pruned_values.schema().fields().len(), 2);
        assert_eq!(pruned_values.schema().fields()[0], fields[0]);
        assert_eq!(pruned_values.schema().fields()[1], fields[2]);
    }
}