risingwave_frontend/optimizer/plan_node/
logical_project.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use pretty_xmlish::XmlNode;
18
19use super::utils::{Distill, childless_record};
20use super::{
21    BatchProject, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
22    PredicatePushdown, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
23};
24use crate::error::Result;
25use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
30};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
33
34/// `LogicalProject` computes a set of expressions from its input relation.
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct LogicalProject {
37    pub base: PlanBase<Logical>,
38    core: generic::Project<PlanRef>,
39}
40
41impl LogicalProject {
42    pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
43        Self::new(input, exprs).into()
44    }
45
46    pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
47        let core = generic::Project::new(exprs, input);
48        Self::with_core(core)
49    }
50
51    pub fn with_core(core: generic::Project<PlanRef>) -> Self {
52        let base = PlanBase::new_logical_with_core(&core);
53        LogicalProject { base, core }
54    }
55
56    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
57        self.core.o2i_col_mapping()
58    }
59
60    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
61        self.core.i2o_col_mapping()
62    }
63
64    /// Creates a `LogicalProject` which select some columns from the input.
65    ///
66    /// `mapping` should maps from `(0..input_fields.len())` to a consecutive range starting from 0.
67    ///
68    /// This is useful in column pruning when we want to add a project to ensure the output schema
69    /// is correct.
70    pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
71        Self::with_core(generic::Project::with_mapping(input, mapping))
72    }
73
74    /// Creates a `LogicalProject` which select some columns from the input.
75    pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
76        Self::with_core(generic::Project::with_out_fields(input, out_fields))
77    }
78
79    /// Creates a `LogicalProject` which select some columns from the input.
80    pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
81        Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
82    }
83
84    pub fn exprs(&self) -> &Vec<ExprImpl> {
85        &self.core.exprs
86    }
87
88    pub fn is_identity(&self) -> bool {
89        self.core.is_identity()
90    }
91
92    pub fn try_as_projection(&self) -> Option<Vec<usize>> {
93        self.core.try_as_projection()
94    }
95
96    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
97        self.core.decompose()
98    }
99
100    pub fn is_all_inputref(&self) -> bool {
101        self.core.is_all_inputref()
102    }
103}
104
105impl PlanTreeNodeUnary for LogicalProject {
106    fn input(&self) -> PlanRef {
107        self.core.input.clone()
108    }
109
110    fn clone_with_input(&self, input: PlanRef) -> Self {
111        Self::new(input, self.exprs().clone())
112    }
113
114    fn rewrite_with_input(
115        &self,
116        input: PlanRef,
117        mut input_col_change: ColIndexMapping,
118    ) -> (Self, ColIndexMapping) {
119        let exprs = self
120            .exprs()
121            .clone()
122            .into_iter()
123            .map(|expr| input_col_change.rewrite_expr(expr))
124            .collect();
125        let proj = Self::new(input, exprs);
126        // change the input columns index will not change the output column index
127        let out_col_change = ColIndexMapping::identity(self.schema().len());
128        (proj, out_col_change)
129    }
130}
131
132impl_plan_tree_node_for_unary! {LogicalProject}
133
134impl Distill for LogicalProject {
135    fn distill<'a>(&self) -> XmlNode<'a> {
136        childless_record(
137            "LogicalProject",
138            self.core.fields_pretty(self.base.schema()),
139        )
140    }
141}
142
143impl ColPrunable for LogicalProject {
144    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
145        let input_col_num: usize = self.input().schema().len();
146        let input_required_cols = collect_input_refs(
147            input_col_num,
148            required_cols.iter().map(|i| &self.exprs()[*i]),
149        )
150        .ones()
151        .collect_vec();
152        let new_input = self.input().prune_col(&input_required_cols, ctx);
153        let mut mapping = ColIndexMapping::with_remaining_columns(
154            &input_required_cols,
155            self.input().schema().len(),
156        );
157        // Rewrite each InputRef with new index.
158        let exprs = required_cols
159            .iter()
160            .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
161            .collect();
162
163        // Reconstruct the LogicalProject.
164        LogicalProject::new(new_input, exprs).into()
165    }
166}
167
168impl ExprRewritable for LogicalProject {
169    fn has_rewritable_expr(&self) -> bool {
170        true
171    }
172
173    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
174        let mut core = self.core.clone();
175        core.rewrite_exprs(r);
176        Self {
177            base: self.base.clone_with_new_plan_id(),
178            core,
179        }
180        .into()
181    }
182}
183
184impl ExprVisitable for LogicalProject {
185    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
186        self.core.visit_exprs(v);
187    }
188}
189
190impl PredicatePushdown for LogicalProject {
191    fn predicate_pushdown(
192        &self,
193        predicate: Condition,
194        ctx: &mut PredicatePushdownContext,
195    ) -> PlanRef {
196        // convert the predicate to one that references the child of the project
197        let mut subst = Substitute {
198            mapping: self.exprs().clone(),
199        };
200
201        let impure_mask = {
202            let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
203            for (i, e) in self.exprs().iter().enumerate() {
204                impure_mask.set(i, e.is_impure())
205            }
206            impure_mask
207        };
208        // (with impure input, with pure input)
209        let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
210        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
211
212        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
213    }
214}
215
216impl ToBatch for LogicalProject {
217    fn to_batch(&self) -> Result<PlanRef> {
218        self.to_batch_with_order_required(&Order::any())
219    }
220
221    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
222        let input_order = self
223            .o2i_col_mapping()
224            .rewrite_provided_order(required_order);
225        let new_input = self.input().to_batch_with_order_required(&input_order)?;
226        let mut new_logical = self.core.clone();
227        new_logical.input = new_input;
228        let batch_project = BatchProject::new(new_logical);
229        required_order.enforce_if_not_satisfies(batch_project.into())
230    }
231}
232
233impl ToStream for LogicalProject {
234    fn to_stream_with_dist_required(
235        &self,
236        required_dist: &RequiredDist,
237        ctx: &mut ToStreamContext,
238    ) -> Result<PlanRef> {
239        let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
240            RequiredDist::Any
241        } else {
242            let input_required = self
243                .o2i_col_mapping()
244                .rewrite_required_distribution(required_dist);
245            match input_required {
246                RequiredDist::PhysicalDist(dist) => match dist {
247                    Distribution::Single => RequiredDist::Any,
248                    _ => RequiredDist::PhysicalDist(dist),
249                },
250                _ => input_required,
251            }
252        };
253        let new_input = self
254            .input()
255            .to_stream_with_dist_required(&input_required, ctx)?;
256        let mut new_logical = self.core.clone();
257        new_logical.input = new_input;
258        let stream_plan = StreamProject::new(new_logical);
259        required_dist.enforce_if_not_satisfies(stream_plan.into(), &Order::any())
260    }
261
262    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
263        self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
264    }
265
266    fn logical_rewrite_for_stream(
267        &self,
268        ctx: &mut RewriteStreamContext,
269    ) -> Result<(PlanRef, ColIndexMapping)> {
270        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
271        let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
272
273        // Add missing columns of input_pk into the select list.
274        let input_pk = input.expect_stream_key();
275        let i2o = proj.i2o_col_mapping();
276        let col_need_to_add = input_pk
277            .iter()
278            .cloned()
279            .filter(|i| i2o.try_map(*i).is_none());
280        let input_schema = input.schema();
281        let exprs =
282            proj.exprs()
283                .iter()
284                .cloned()
285                .chain(col_need_to_add.map(|idx| {
286                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
287                }))
288                .collect();
289        let proj = Self::new(input, exprs);
290        // The added columns is at the end, so it will not change existing column indices.
291        // But the target size of `out_col_change` should be the same as the length of the new
292        // schema.
293        let (map, _) = out_col_change.into_parts();
294        let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
295        Ok((proj.into(), out_col_change))
296    }
297}
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    #[tokio::test]
    /// Pruning
    /// ```text
    /// Project(1, input_ref(2), input_ref(0)<5)
    ///   TableScan(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` will result in
    /// ```text
    /// Project(input_ref(1), input_ref(0)<5)
    ///   TableScan(v1, v3)
    /// ```
    ///
    /// In the test body the scan is a `LogicalValues` node and the literals are built with a
    /// `None` value; only the column references are checked.
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Input relation with three Int32 columns.
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let input_schema = Schema {
            fields: fields.clone(),
        };
        let values = LogicalValues::new(vec![], input_schema, ctx);

        // The three projected expressions: a literal, `$2`, and `$0 < literal`.
        let literal_expr = ExprImpl::Literal(Box::new(Literal::new(None, ty.clone())));
        let third_col_expr: ExprImpl = InputRef::new(2, ty.clone()).into();
        let less_than_call = FunctionCall::new(
            Type::LessThan,
            vec![
                ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                ExprImpl::Literal(Box::new(Literal::new(None, ty))),
            ],
        )
        .unwrap();
        let less_than_expr = ExprImpl::FunctionCall(Box::new(less_than_call));

        let project: PlanRef = LogicalProject::new(
            values.into(),
            vec![literal_expr, third_col_expr, less_than_expr],
        )
        .into();

        // Perform the prune
        let required_cols = vec![1, 2];
        let mut prune_ctx = ColumnPruningContext::new(project.clone());
        let plan = project.prune_col(&required_cols, &mut prune_ctx);

        // Check the result
        let pruned = plan.as_logical_project().unwrap();
        assert_eq!(pruned.exprs().len(), 2);
        assert_eq_input_ref!(&pruned.exprs()[0], 1);

        let second_expr = pruned.exprs()[1].clone();
        let call = second_expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let pruned_input = pruned.input();
        let pruned_values = pruned_input.as_logical_values().unwrap();
        assert_eq!(pruned_values.schema().fields().len(), 2);
        assert_eq!(pruned_values.schema().fields()[0], fields[0]);
        assert_eq!(pruned_values.schema().fields()[1], fields[2]);
    }
}