risingwave_frontend/optimizer/plan_node/
logical_share.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use pretty_xmlish::{Pretty, XmlNode};
16use risingwave_common::bail_not_implemented;
17
18use super::utils::{Distill, childless_record};
19use super::{
20    ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase, PlanTreeNodeUnary,
21    PredicatePushdown, ShareNode, StreamPlanRef, ToBatch, ToStream, generic,
22};
23use crate::error::Result;
24use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
25use crate::optimizer::plan_node::generic::{GenericPlanRef, Share};
26use crate::optimizer::plan_node::{
27    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, StreamShare,
28    ToStreamContext,
29};
30use crate::utils::{ColIndexMapping, Condition};
31
/// `LogicalShare` operator is used to represent the reuse of existing operators.
/// It is the key operator for DAG plans.
/// It can have multiple parents, which makes it different from other operators.
/// Currently, it is used in the following scenarios:
/// 1. Share source.
/// 2. Subquery unnesting domain calculation.
///
/// A DAG plan example: A self join shares the same source.
/// ```text
///     LogicalJoin
///    /           \
///   |            |
///   \           /
///   LogicalShare
///        |
///   LogicalSource
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalShare {
    /// Logical plan base; the constructors in this file derive it from `core`.
    pub base: PlanBase<Logical>,
    /// Generic share core. Its `input` is read with `borrow()` and written with
    /// `borrow_mut()`, so the shared input can be swapped through `&self`
    /// (see `replace_input`).
    core: generic::Share<PlanRef>,
}
54
55impl LogicalShare {
56    pub fn new(input: PlanRef) -> Self {
57        let _ctx = input.ctx();
58        let _functional_dependency = input.functional_dependency().clone();
59        let core = generic::Share::new(input);
60        let base = PlanBase::new_logical_with_core(&core);
61        LogicalShare { base, core }
62    }
63
64    pub fn create(input: PlanRef) -> PlanRef {
65        LogicalShare::new(input).into()
66    }
67
68    pub(super) fn pretty_fields(base: impl GenericPlanRef, name: &str) -> XmlNode<'_> {
69        childless_record(name, vec![("id", Pretty::debug(&base.id().0))])
70    }
71}
72
73impl PlanTreeNodeUnary<Logical> for LogicalShare {
74    fn input(&self) -> PlanRef {
75        self.core.input.borrow().clone()
76    }
77
78    fn clone_with_input(&self, _input: PlanRef) -> Self {
79        unreachable!("shared node should be handled specially in PlanRef::clone_with_input")
80    }
81
82    fn rewrite_with_input(
83        &self,
84        input: PlanRef,
85        input_col_change: ColIndexMapping,
86    ) -> (Self, ColIndexMapping) {
87        (Self::new(input), input_col_change)
88    }
89}
90
// Invokes `impl_plan_tree_node_for_unary!` for `LogicalShare` under the `Logical` convention.
// NOTE(review): presumably this derives the generic plan-tree-node impl from the
// `PlanTreeNodeUnary` impl above — confirm against the macro definition.
impl_plan_tree_node_for_unary! { Logical, LogicalShare}
92
93impl ShareNode<Logical> for LogicalShare {
94    fn new_share(core: Share<PlanRef>) -> PlanRef {
95        let base = PlanBase::new_logical_with_core(&core);
96        LogicalShare { base, core }.into()
97    }
98
99    fn replace_input(&self, plan: PlanRef) {
100        *self.core.input.borrow_mut() = plan;
101    }
102}
103
104impl Distill for LogicalShare {
105    fn distill<'a>(&self) -> XmlNode<'a> {
106        Self::pretty_fields(&self.base, "LogicalShare")
107    }
108}
109
impl ColPrunable for LogicalShare {
    /// Not meant to be called directly on a `LogicalShare`.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`; the message directs callers to use `prune_col`
    /// of the `PlanRef` instead.
    fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
        unimplemented!("call prune_col of the PlanRef instead of calling directly on LogicalShare")
    }
}
115
// Empty impl: `LogicalShare` overrides nothing and relies on whatever the trait provides
// by default.
impl ExprRewritable<Logical> for LogicalShare {}
117
// Empty impl: `LogicalShare` overrides nothing and relies on whatever the trait provides
// by default.
impl ExprVisitable for LogicalShare {}
119
impl PredicatePushdown for LogicalShare {
    /// Not meant to be called directly on a `LogicalShare`.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`; the message directs callers to use
    /// `predicate_pushdown` of the `PlanRef` instead.
    fn predicate_pushdown(
        &self,
        _predicate: Condition,
        _ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        unimplemented!(
            "call predicate_pushdown of the PlanRef instead of calling directly on LogicalShare"
        )
    }
}
131
impl ToBatch for LogicalShare {
    /// Batch conversion is not supported for share nodes: the body unconditionally invokes
    /// `bail_not_implemented!` with the message below and produces no batch plan.
    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
        bail_not_implemented!("batch query doesn't support share operator for now");
    }
}
137
138impl ToStream for LogicalShare {
139    fn to_stream(
140        &self,
141        ctx: &mut ToStreamContext,
142    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
143        match ctx.get_to_stream_result(self.id()) {
144            None => {
145                let new_input = self.input().to_stream(ctx)?;
146                let core = generic::Share::new(new_input);
147                let stream_share_ref: StreamPlanRef = StreamShare::new(core).into();
148                ctx.add_to_stream_result(self.id(), stream_share_ref.clone());
149                Ok(stream_share_ref)
150            }
151            Some(cache) => Ok(cache.clone()),
152        }
153    }
154
155    fn logical_rewrite_for_stream(
156        &self,
157        ctx: &mut RewriteStreamContext,
158    ) -> Result<(PlanRef, ColIndexMapping)> {
159        match ctx.get_rewrite_result(self.id()) {
160            None => {
161                let (new_input, col_change) = self.input().logical_rewrite_for_stream(ctx)?;
162                let new_share: PlanRef = Self::new(new_input).into();
163                ctx.add_rewrite_result(self.id(), new_share.clone(), col_change.clone());
164                Ok((new_share, col_change))
165            }
166            Some(cache) => Ok(cache.clone()),
167        }
168    }
169}
170
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef, Literal};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{
        LogicalFilter, LogicalJoin, LogicalValues, PlanTreeNodeBinary,
    };

    /// Builds an `Int32` input reference to column `index`.
    fn int32_col(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// Builds an `Int32` literal holding `value`.
    fn int32_literal(value: i32) -> ExprImpl {
        ExprImpl::Literal(Box::new(Literal::new(
            Some(ScalarImpl::from(value)),
            DataType::Int32,
        )))
    }

    /// Builds the expression `lhs = rhs`.
    fn equal(lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        let call = FunctionCall::new(Type::Equal, vec![lhs, rhs]).unwrap();
        ExprImpl::FunctionCall(Box::new(call))
    }

    /// Two filters stacked on a self join over one share should end up as a single
    /// OR-ed filter below the share.
    #[tokio::test]
    async fn test_share_predicate_pushdown() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let values = LogicalValues::new(vec![], Schema { fields }, ctx);

        let share: PlanRef = LogicalShare::create(values.into());

        // Join condition: column 1 (left side) = column 3 (first column of the right side).
        let on = equal(int32_col(1), int32_col(3));
        // Column 0 belongs to the left side, column 4 to the right side of the join.
        let left_predicate = equal(int32_col(0), int32_literal(100));
        let right_predicate = equal(int32_col(4), int32_literal(200));

        let join: PlanRef = LogicalJoin::create(share.clone(), share, JoinType::Inner, on);
        let inner_filter: PlanRef = LogicalFilter::create_with_expr(join, left_predicate);
        let outer_filter: PlanRef = LogicalFilter::create_with_expr(inner_filter, right_predicate);

        let mut pushdown_ctx = PredicatePushdownContext::new(outer_filter.clone());
        let result = outer_filter.predicate_pushdown(Condition::true_cond(), &mut pushdown_ctx);

        // LogicalJoin { type: Inner, on: (v2 = v1) }
        // ├─LogicalFilter { predicate: (v1 = 100:Int32) }
        // | └─LogicalShare { id = 2 }
        // |   └─LogicalFilter { predicate: ((v1 = 100:Int32) OR (v2 = 200:Int32)) }
        // |     └─LogicalValues { schema: Schema { fields: [v1:Int32, v2:Int32, v3:Int32] } }
        // └─LogicalFilter { predicate: (v2 = 200:Int32) }
        //   └─LogicalShare { id = 2 }
        //     └─LogicalFilter { predicate: ((v1 = 100:Int32) OR (v2 = 200:Int32)) }
        //       └─LogicalValues { schema: Schema { fields: [v1:Int32, v2:Int32, v3:Int32] } }

        // Walk down the left branch: join -> filter -> share -> filter.
        let join_node: &LogicalJoin = result.as_logical_join().unwrap();
        let left_child = join_node.left();
        let filter_above_share: &LogicalFilter = left_child.as_logical_filter().unwrap();
        let share_plan = filter_above_share.input();
        let share_node: &LogicalShare = share_plan.as_logical_share().unwrap();
        let below_share = share_node.input();
        let filter_below_share: &LogicalFilter = below_share.as_logical_filter().unwrap();

        let disjunctions = filter_below_share.predicate().conjunctions[0]
            .as_or_disjunctions()
            .unwrap();
        assert_eq!(disjunctions.len(), 2);

        // The two disjuncts may come in either order, but together they must reference
        // exactly columns 0 and 1.
        let (first_ref, _first_const) = disjunctions[0].as_eq_const().unwrap();
        let (second_ref, _second_const) = disjunctions[1].as_eq_const().unwrap();
        let mut referenced = [first_ref.index(), second_ref.index()];
        referenced.sort_unstable();
        assert_eq!(referenced, [0, 1]);
    }
}