risingwave_frontend/optimizer/plan_node/
logical_project.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::utils::{Distill, childless_record};
22use super::{
23    BatchPlanRef, BatchProject, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
24    LogicalPlanRef, PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamMaterializedExprs,
25    StreamPlanRef, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
26};
27use crate::error::Result;
28use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::GenericPlanRef;
31use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::{Distribution, Order, RequiredDist, StreamKind};
36use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
37
38/// `LogicalProject` computes a set of expressions from its input relation.
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct LogicalProject {
41    pub base: PlanBase<Logical>,
42    core: generic::Project<PlanRef>,
43}
44
45impl LogicalProject {
46    pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
47        Self::new(input, exprs).into()
48    }
49
50    // TODO(kwannoel): We only need create/new don't keep both.
51    pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
52        let core = generic::Project::new(exprs, input);
53        Self::with_core(core)
54    }
55
56    pub fn with_core(core: generic::Project<PlanRef>) -> Self {
57        let base = PlanBase::new_logical_with_core(&core);
58        LogicalProject { base, core }
59    }
60
61    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
62        self.core.o2i_col_mapping()
63    }
64
65    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
66        self.core.i2o_col_mapping()
67    }
68
69    /// Creates a `LogicalProject` which select some columns from the input.
70    ///
71    /// `mapping` should maps from `(0..input_fields.len())` to a consecutive range starting from 0.
72    ///
73    /// This is useful in column pruning when we want to add a project to ensure the output schema
74    /// is correct.
75    pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
76        Self::with_core(generic::Project::with_mapping(input, mapping))
77    }
78
79    /// Creates a `LogicalProject` which select some columns from the input.
80    pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
81        Self::with_core(generic::Project::with_out_fields(input, out_fields))
82    }
83
84    /// Creates a `LogicalProject` which select some columns from the input.
85    pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
86        Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
87    }
88
89    pub fn exprs(&self) -> &Vec<ExprImpl> {
90        &self.core.exprs
91    }
92
93    pub fn is_identity(&self) -> bool {
94        self.core.is_identity()
95    }
96
97    pub fn try_as_projection(&self) -> Option<Vec<usize>> {
98        self.core.try_as_projection()
99    }
100
101    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
102        self.core.decompose()
103    }
104
105    pub fn is_all_inputref(&self) -> bool {
106        self.core.is_all_inputref()
107    }
108}
109
110impl PlanTreeNodeUnary<Logical> for LogicalProject {
111    fn input(&self) -> LogicalPlanRef {
112        self.core.input.clone()
113    }
114
115    fn clone_with_input(&self, input: LogicalPlanRef) -> Self {
116        Self::new(input, self.exprs().clone())
117    }
118
119    fn rewrite_with_input(
120        &self,
121        input: PlanRef,
122        mut input_col_change: ColIndexMapping,
123    ) -> (Self, ColIndexMapping) {
124        let exprs = self
125            .exprs()
126            .clone()
127            .into_iter()
128            .map(|expr| input_col_change.rewrite_expr(expr))
129            .collect();
130        let proj = Self::new(input, exprs);
131        // change the input columns index will not change the output column index
132        let out_col_change = ColIndexMapping::identity(self.schema().len());
133        (proj, out_col_change)
134    }
135}
136
137impl_plan_tree_node_for_unary! { Logical, LogicalProject}
138
139impl Distill for LogicalProject {
140    fn distill<'a>(&self) -> XmlNode<'a> {
141        childless_record(
142            "LogicalProject",
143            self.core.fields_pretty(self.base.schema()),
144        )
145    }
146}
147
148impl ColPrunable for LogicalProject {
149    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
150        let input_col_num: usize = self.input().schema().len();
151        let input_required_cols = collect_input_refs(
152            input_col_num,
153            required_cols.iter().map(|i| &self.exprs()[*i]),
154        )
155        .ones()
156        .collect_vec();
157        let new_input = self.input().prune_col(&input_required_cols, ctx);
158        let mut mapping = ColIndexMapping::with_remaining_columns(
159            &input_required_cols,
160            self.input().schema().len(),
161        );
162        // Rewrite each InputRef with new index.
163        let exprs = required_cols
164            .iter()
165            .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
166            .collect();
167
168        // Reconstruct the LogicalProject.
169        LogicalProject::new(new_input, exprs).into()
170    }
171}
172
173impl ExprRewritable<Logical> for LogicalProject {
174    fn has_rewritable_expr(&self) -> bool {
175        true
176    }
177
178    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
179        let mut core = self.core.clone();
180        core.rewrite_exprs(r);
181        Self {
182            base: self.base.clone_with_new_plan_id(),
183            core,
184        }
185        .into()
186    }
187}
188
189impl ExprVisitable for LogicalProject {
190    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
191        self.core.visit_exprs(v);
192    }
193}
194
195impl PredicatePushdown for LogicalProject {
196    fn predicate_pushdown(
197        &self,
198        predicate: Condition,
199        ctx: &mut PredicatePushdownContext,
200    ) -> PlanRef {
201        // convert the predicate to one that references the child of the project
202        let mut subst = Substitute {
203            mapping: self.exprs().clone(),
204        };
205
206        let impure_mask = {
207            let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
208            for (i, e) in self.exprs().iter().enumerate() {
209                impure_mask.set(i, e.is_impure())
210            }
211            impure_mask
212        };
213        // (with impure input, with pure input)
214        let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
215        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
216
217        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
218    }
219}
220
221impl ToBatch for LogicalProject {
222    fn to_batch(&self) -> Result<BatchPlanRef> {
223        self.to_batch_with_order_required(&Order::any())
224    }
225
226    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<BatchPlanRef> {
227        let input_order = self
228            .o2i_col_mapping()
229            .rewrite_provided_order(required_order);
230        let new_input = self.input().to_batch_with_order_required(&input_order)?;
231        let project = self.core.clone_with_input(new_input);
232        let batch_project = BatchProject::new(project);
233        required_order.enforce_if_not_satisfies(batch_project.into())
234    }
235}
236
237impl ToStream for LogicalProject {
238    fn to_stream_with_dist_required(
239        &self,
240        required_dist: &RequiredDist,
241        ctx: &mut ToStreamContext,
242    ) -> Result<StreamPlanRef> {
243        let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
244            RequiredDist::Any
245        } else {
246            let input_required = self
247                .o2i_col_mapping()
248                .rewrite_required_distribution(required_dist);
249            match input_required {
250                RequiredDist::PhysicalDist(dist) => match dist {
251                    Distribution::Single => RequiredDist::Any,
252                    _ => RequiredDist::PhysicalDist(dist),
253                },
254                _ => input_required,
255            }
256        };
257        let new_input = self
258            .input()
259            .to_stream_with_dist_required(&input_required, ctx)?;
260
261        let should_materialize_expr = match new_input.stream_kind() {
262            StreamKind::AppendOnly => None,
263            kind @ (StreamKind::Retract | StreamKind::Upsert) => {
264                // Extract impure functions to `MaterializedExprs` operator
265                let mut impure_field_names = BTreeMap::new();
266                let mut impure_expr_indices = HashSet::new();
267                let impure_exprs: Vec<_> = self
268                    .exprs()
269                    .iter()
270                    .enumerate()
271                    .filter_map(|(idx, expr)| {
272                        // Extract impure expressions
273                        if expr.is_impure() {
274                            impure_expr_indices.insert(idx);
275                            if let Some(name) = self.core.field_names.get(&idx) {
276                                impure_field_names.insert(idx, name.clone());
277                            }
278                            Some(expr.clone())
279                        } else {
280                            None
281                        }
282                    })
283                    .collect();
284                if impure_exprs.is_empty() {
285                    None
286                } else if kind == StreamKind::Upsert
287                    && new_input
288                        .stream_key()
289                        .into_iter()
290                        .flatten()
291                        .all(|stream_key_idx| !impure_expr_indices.contains(stream_key_idx))
292                {
293                    // We're operating on non-stream-key columns of upsert stream, no need to materialize.
294                    None
295                } else {
296                    Some((impure_field_names, impure_expr_indices, impure_exprs))
297                }
298            }
299        };
300
301        let stream_plan = if let Some((impure_field_names, impure_expr_indices, impure_exprs)) =
302            should_materialize_expr
303        {
304            {
305                let new_input = new_input.enforce_concrete_distribution();
306
307                // Create `MaterializedExprs` for impure expressions
308                let mat_exprs_plan: StreamPlanRef = StreamMaterializedExprs::new(
309                    new_input.clone(),
310                    impure_exprs,
311                    impure_field_names,
312                )?
313                .into();
314
315                let input_len = new_input.schema().len();
316                let mut materialized_pos = 0;
317
318                // Create final expressions list with impure expressions replaced by `InputRef`s
319                let final_exprs = self
320                    .exprs()
321                    .iter()
322                    .enumerate()
323                    .map(|(idx, expr)| {
324                        if impure_expr_indices.contains(&idx) {
325                            let output_idx = input_len + materialized_pos;
326                            materialized_pos += 1;
327                            InputRef::new(output_idx, expr.return_type()).into()
328                        } else {
329                            expr.clone()
330                        }
331                    })
332                    .collect();
333
334                let core = generic::Project::new(final_exprs, mat_exprs_plan);
335                StreamProject::new(core).into()
336            }
337        } else {
338            // No expressions to materialize or the feature is not enabled, create a regular `StreamProject`
339            let core = generic::Project::new(self.exprs().clone(), new_input);
340            StreamProject::new(core).into()
341        };
342
343        required_dist.streaming_enforce_if_not_satisfies(stream_plan)
344    }
345
346    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
347        self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
348    }
349
350    fn logical_rewrite_for_stream(
351        &self,
352        ctx: &mut RewriteStreamContext,
353    ) -> Result<(PlanRef, ColIndexMapping)> {
354        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
355        let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
356
357        // Add missing columns of `input_pk` into the select list.
358        let input_pk = input.expect_stream_key();
359        let i2o = proj.i2o_col_mapping();
360        let col_need_to_add = input_pk
361            .iter()
362            .cloned()
363            .filter(|i| i2o.try_map(*i).is_none());
364        let input_schema = input.schema();
365        let exprs =
366            proj.exprs()
367                .iter()
368                .cloned()
369                .chain(col_need_to_add.map(|idx| {
370                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
371                }))
372                .collect();
373        let proj = Self::new(input, exprs);
374        // The added columns is at the end, so it will not change existing column indices.
375        // But the target size of `out_col_change` should be the same as the length of the new
376        // schema.
377        let (map, _) = out_col_change.into_parts();
378        let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
379        Ok((proj.into(), out_col_change))
380    }
381
382    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
383        if columns.is_empty() {
384            return None;
385        }
386
387        let input_columns = columns
388            .iter()
389            .map(|&col| {
390                // First try the original o2i mapping for direct InputRef
391                if let Some(input_col) = self.o2i_col_mapping().try_map(col) {
392                    return Some(input_col);
393                }
394
395                // If not a direct InputRef, check if it's a pure function with single InputRef
396                let expr = &self.exprs()[col];
397                if expr.is_pure() {
398                    let input_refs = expr.collect_input_refs(self.input().schema().len());
399                    // Check if expression references exactly one input column
400                    if input_refs.count_ones(..) == 1 {
401                        return input_refs.ones().next();
402                    }
403                }
404
405                None
406            })
407            .collect::<Option<Vec<usize>>>()?;
408
409        let new_input = self.input().try_better_locality(&input_columns)?;
410        Some(self.clone_with_input(new_input).into())
411    }
412}
413
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning
    /// ```text
    /// Project(null, input_ref(2), input_ref(0) < null)
    ///   Values(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` will result in
    /// ```text
    /// Project(input_ref(1), input_ref(0) < null)
    ///   Values(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );
        let project: PlanRef = LogicalProject::new(
            values.into(),
            vec![
                ExprImpl::Literal(Box::new(Literal::new(None, ty.clone()))),
                InputRef::new(2, ty.clone()).into(),
                ExprImpl::FunctionCall(Box::new(
                    FunctionCall::new(
                        Type::LessThan,
                        vec![
                            ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                            ExprImpl::Literal(Box::new(Literal::new(None, ty))),
                        ],
                    )
                    .unwrap(),
                )),
            ],
        )
        .into();

        // Perform the prune
        let required_cols = vec![1, 2];
        let plan = project.prune_col(
            &required_cols,
            &mut ColumnPruningContext::new(project.clone()),
        );

        // Check the result
        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let expr = project.exprs()[1].clone();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let values = project.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields().len(), 2);
        assert_eq!(values.schema().fields()[0], fields[0]);
        assert_eq!(values.schema().fields()[1], fields[2]);
    }
}