risingwave_frontend/optimizer/plan_node/
logical_apply.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::catalog::Schema;
18use risingwave_pb::plan_common::JoinType;
19
20use super::generic::{
21    self, GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
22};
23use super::utils::{Distill, childless_record};
24use super::{
25    ColPrunable, Logical, LogicalJoin, LogicalProject, PlanBase, PlanRef, PlanTreeNodeBinary,
26    PredicatePushdown, ToBatch, ToStream,
27};
28use crate::error::{ErrorCode, Result, RwError};
29use crate::expr::{CorrelatedId, Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::{
32    ColumnPruningContext, ExprRewritable, LogicalFilter, PredicatePushdownContext,
33    RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::FunctionalDependencySet;
36use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
37
38/// `LogicalApply` represents a correlated join, where the right side may refer to columns from the
39/// left side.
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct LogicalApply {
42    pub base: PlanBase<Logical>,
43    left: PlanRef,
44    right: PlanRef,
45    on: Condition,
46    join_type: JoinType,
47
48    /// Id of the Apply operator.
49    /// So `correlated_input_ref` can refer the Apply operator exactly by `correlated_id`.
50    correlated_id: CorrelatedId,
51    /// The indices of `CorrelatedInputRef`s in `right`.
52    correlated_indices: Vec<usize>,
53    /// Whether we require the subquery to produce at most one row. If `true`, we have to report an
54    /// error if the subquery produces more than one row.
55    max_one_row: bool,
56
57    /// An apply has been translated by `translate_apply()`, so we should not translate it in `translate_apply_rule` again.
58    /// This flag is used to avoid infinite loop in General Unnesting(Translate Apply), since we use a top-down apply order instead of bottom-up to improve the multi-scalar subqueries optimization time.
59    translated: bool,
60}
61
62impl Distill for LogicalApply {
63    fn distill<'a>(&self) -> XmlNode<'a> {
64        let mut vec = Vec::with_capacity(if self.max_one_row { 4 } else { 3 });
65        vec.push(("type", Pretty::debug(&self.join_type)));
66
67        let concat_schema = self.concat_schema();
68        let cond = Pretty::debug(&ConditionDisplay {
69            condition: &self.on,
70            input_schema: &concat_schema,
71        });
72        vec.push(("on", cond));
73
74        vec.push(("correlated_id", Pretty::debug(&self.correlated_id)));
75        if self.max_one_row {
76            vec.push(("max_one_row", Pretty::debug(&true)));
77        }
78
79        childless_record("LogicalApply", vec)
80    }
81}
82
83impl LogicalApply {
84    pub(crate) fn new(
85        left: PlanRef,
86        right: PlanRef,
87        join_type: JoinType,
88        on: Condition,
89        correlated_id: CorrelatedId,
90        correlated_indices: Vec<usize>,
91        max_one_row: bool,
92        translated: bool,
93    ) -> Self {
94        let ctx = left.ctx();
95        let join_core = generic::Join::with_full_output(left, right, join_type, on);
96        let schema = join_core.schema();
97        let stream_key = join_core.stream_key();
98        let functional_dependency = match &stream_key {
99            Some(stream_key) => FunctionalDependencySet::with_key(schema.len(), stream_key),
100            None => FunctionalDependencySet::new(schema.len()),
101        };
102        let (left, right, on, join_type, _output_indices) = join_core.decompose();
103        let base = PlanBase::new_logical(ctx, schema, stream_key, functional_dependency);
104        LogicalApply {
105            base,
106            left,
107            right,
108            on,
109            join_type,
110            correlated_id,
111            correlated_indices,
112            max_one_row,
113            translated,
114        }
115    }
116
117    pub fn create(
118        left: PlanRef,
119        right: PlanRef,
120        join_type: JoinType,
121        on: Condition,
122        correlated_id: CorrelatedId,
123        correlated_indices: Vec<usize>,
124        max_one_row: bool,
125    ) -> PlanRef {
126        Self::new(
127            left,
128            right,
129            join_type,
130            on,
131            correlated_id,
132            correlated_indices,
133            max_one_row,
134            false,
135        )
136        .into()
137    }
138
139    /// Get the join type of the logical apply.
140    pub fn join_type(&self) -> JoinType {
141        self.join_type
142    }
143
144    pub fn decompose(
145        self,
146    ) -> (
147        PlanRef,
148        PlanRef,
149        Condition,
150        JoinType,
151        CorrelatedId,
152        Vec<usize>,
153        bool,
154    ) {
155        (
156            self.left,
157            self.right,
158            self.on,
159            self.join_type,
160            self.correlated_id,
161            self.correlated_indices,
162            self.max_one_row,
163        )
164    }
165
166    pub fn correlated_id(&self) -> CorrelatedId {
167        self.correlated_id
168    }
169
170    pub fn correlated_indices(&self) -> Vec<usize> {
171        self.correlated_indices.to_owned()
172    }
173
174    pub fn translated(&self) -> bool {
175        self.translated
176    }
177
178    pub fn max_one_row(&self) -> bool {
179        self.max_one_row
180    }
181
182    /// Translate Apply.
183    ///
184    /// Used to convert other kinds of Apply to cross Apply.
185    ///
186    /// Before:
187    ///
188    /// ```text
189    ///     LogicalApply
190    ///    /            \
191    ///  LHS           RHS
192    /// ```
193    ///
194    /// After:
195    ///
196    /// ```text
197    ///      LogicalJoin
198    ///    /            \
199    ///  LHS        LogicalApply
200    ///             /           \
201    ///          Domain         RHS
202    /// ```
203    pub fn translate_apply(self, domain: PlanRef, eq_predicates: Vec<ExprImpl>) -> PlanRef {
204        let (
205            apply_left,
206            apply_right,
207            on,
208            apply_type,
209            correlated_id,
210            correlated_indices,
211            max_one_row,
212        ) = self.decompose();
213        let apply_left_len = apply_left.schema().len();
214        let correlated_indices_len = correlated_indices.len();
215
216        let new_apply = LogicalApply::new(
217            domain,
218            apply_right,
219            JoinType::Inner,
220            Condition::true_cond(),
221            correlated_id,
222            correlated_indices,
223            max_one_row,
224            true,
225        )
226        .into();
227
228        let on = Self::rewrite_on(on, correlated_indices_len, apply_left_len).and(Condition {
229            conjunctions: eq_predicates,
230        });
231        let new_join = LogicalJoin::new(apply_left, new_apply, apply_type, on);
232
233        if new_join.join_type() == JoinType::LeftSemi {
234            // Schema doesn't change, still LHS.
235            new_join.into()
236        } else {
237            // `new_join`'s schema is different from original apply's schema, so `LogicalProject` is
238            // used to ensure they are the same.
239            let mut exprs: Vec<ExprImpl> = vec![];
240            new_join
241                .schema()
242                .data_types()
243                .into_iter()
244                .enumerate()
245                .for_each(|(index, data_type)| {
246                    if index < apply_left_len || index >= apply_left_len + correlated_indices_len {
247                        exprs.push(InputRef::new(index, data_type).into());
248                    }
249                });
250            LogicalProject::create(new_join.into(), exprs)
251        }
252    }
253
254    fn rewrite_on(on: Condition, offset: usize, apply_left_len: usize) -> Condition {
255        struct Rewriter {
256            offset: usize,
257            apply_left_len: usize,
258        }
259        impl ExprRewriter for Rewriter {
260            fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
261                let index = input_ref.index();
262                if index >= self.apply_left_len {
263                    InputRef::new(index + self.offset, input_ref.return_type()).into()
264                } else {
265                    input_ref.into()
266                }
267            }
268        }
269        let mut rewriter = Rewriter {
270            offset,
271            apply_left_len,
272        };
273        on.rewrite_expr(&mut rewriter)
274    }
275
276    fn concat_schema(&self) -> Schema {
277        let mut concat_schema = self.left().schema().fields.clone();
278        concat_schema.extend(self.right().schema().fields.clone());
279        Schema::new(concat_schema)
280    }
281}
282
283impl PlanTreeNodeBinary for LogicalApply {
284    fn left(&self) -> PlanRef {
285        self.left.clone()
286    }
287
288    fn right(&self) -> PlanRef {
289        self.right.clone()
290    }
291
292    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
293        Self::new(
294            left,
295            right,
296            self.join_type,
297            self.on.clone(),
298            self.correlated_id,
299            self.correlated_indices.clone(),
300            self.max_one_row,
301            self.translated,
302        )
303    }
304}
305
306impl_plan_tree_node_for_binary! { LogicalApply }
307
308impl ColPrunable for LogicalApply {
309    fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
310        panic!("LogicalApply should be unnested")
311    }
312}
313
314impl ExprRewritable for LogicalApply {
315    fn has_rewritable_expr(&self) -> bool {
316        true
317    }
318
319    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
320        let mut new = self.clone();
321        new.on = new.on.rewrite_expr(r);
322        new.base = new.base.clone_with_new_plan_id();
323        new.into()
324    }
325}
326
327impl ExprVisitable for LogicalApply {
328    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
329        self.on.visit_expr(v)
330    }
331}
332
333impl PredicatePushdown for LogicalApply {
334    fn predicate_pushdown(
335        &self,
336        mut predicate: Condition,
337        ctx: &mut PredicatePushdownContext,
338    ) -> PlanRef {
339        let left_col_num = self.left().schema().len();
340        let right_col_num = self.right().schema().len();
341        let join_type = self.join_type();
342
343        let (left_from_filter, right_from_filter, on) =
344            push_down_into_join(&mut predicate, left_col_num, right_col_num, join_type, true);
345
346        let mut new_on = self.on.clone().and(on);
347        let (left_from_on, right_from_on) =
348            push_down_join_condition(&mut new_on, left_col_num, right_col_num, join_type, true);
349
350        let left_predicate = left_from_filter.and(left_from_on);
351        let right_predicate = right_from_filter.and(right_from_on);
352
353        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
354        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
355
356        let new_apply = LogicalApply::create(
357            new_left,
358            new_right,
359            join_type,
360            new_on,
361            self.correlated_id,
362            self.correlated_indices.clone(),
363            self.max_one_row,
364        );
365        LogicalFilter::create(new_apply, predicate)
366    }
367}
368
369impl ToBatch for LogicalApply {
370    fn to_batch(&self) -> Result<PlanRef> {
371        Err(RwError::from(ErrorCode::InternalError(
372            "LogicalApply should be unnested".to_owned(),
373        )))
374    }
375}
376
377impl ToStream for LogicalApply {
378    fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result<PlanRef> {
379        Err(RwError::from(ErrorCode::InternalError(
380            "LogicalApply should be unnested".to_owned(),
381        )))
382    }
383
384    fn logical_rewrite_for_stream(
385        &self,
386        _ctx: &mut RewriteStreamContext,
387    ) -> Result<(PlanRef, ColIndexMapping)> {
388        Err(RwError::from(ErrorCode::InternalError(
389            "LogicalApply should be unnested".to_owned(),
390        )))
391    }
392}