risingwave_frontend/optimizer/plan_node/
logical_share.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::bail_not_implemented;
19
20use super::utils::{Distill, childless_record};
21use super::{
22    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown,
23    ToBatch, ToStream, generic,
24};
25use crate::error::Result;
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, StreamShare,
30    ToStreamContext,
31};
32use crate::utils::{ColIndexMapping, Condition};
33
/// `LogicalShare` represents the reuse of an existing operator subtree.
///
/// It is the key operator for DAG plans: unlike other operators, a single
/// `LogicalShare` node can have multiple parents.
/// It is currently used in the following scenarios:
/// 1. Sharing a source.
/// 2. Domain calculation during subquery unnesting.
///
/// A DAG plan example, where both sides of a self join share the same source:
/// ```text
///     LogicalJoin
///    /           \
///   |            |
///   \           /
///   LogicalShare
///        |
///   LogicalSource
/// ```
///
/// Because the node can be reached through several parents, column pruning and
/// predicate pushdown must be invoked through the `PlanRef` entry points; calling
/// `prune_col` or `predicate_pushdown` directly on this type panics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalShare {
    /// Plan properties (including the node id) computed once in `LogicalShare::new`.
    pub base: PlanBase<Logical>,
    /// Holds the single input behind a `RefCell` so it can be swapped in place by
    /// `LogicalShare::replace_input`.
    core: generic::Share<PlanRef>,
}
56
57impl LogicalShare {
58    pub fn new(input: PlanRef) -> Self {
59        let _ctx = input.ctx();
60        let _functional_dependency = input.functional_dependency().clone();
61        let core = generic::Share {
62            input: RefCell::new(input),
63        };
64        let base = PlanBase::new_logical_with_core(&core);
65        LogicalShare { base, core }
66    }
67
68    pub fn create(input: PlanRef) -> PlanRef {
69        LogicalShare::new(input).into()
70    }
71
72    pub(super) fn pretty_fields(base: impl GenericPlanRef, name: &str) -> XmlNode<'_> {
73        childless_record(name, vec![("id", Pretty::debug(&base.id().0))])
74    }
75}
76
77impl PlanTreeNodeUnary for LogicalShare {
78    fn input(&self) -> PlanRef {
79        self.core.input.borrow().clone()
80    }
81
82    fn clone_with_input(&self, input: PlanRef) -> Self {
83        Self::new(input)
84    }
85
86    fn rewrite_with_input(
87        &self,
88        input: PlanRef,
89        input_col_change: ColIndexMapping,
90    ) -> (Self, ColIndexMapping) {
91        (Self::new(input), input_col_change)
92    }
93}
94
95impl_plan_tree_node_for_unary! {LogicalShare}
96
97impl LogicalShare {
98    pub fn replace_input(&self, plan: PlanRef) {
99        *self.core.input.borrow_mut() = plan;
100    }
101}
102
103impl Distill for LogicalShare {
104    fn distill<'a>(&self) -> XmlNode<'a> {
105        Self::pretty_fields(&self.base, "LogicalShare")
106    }
107}
108
109impl ColPrunable for LogicalShare {
110    fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
111        unimplemented!("call prune_col of the PlanRef instead of calling directly on LogicalShare")
112    }
113}
114
// `LogicalShare` holds no expressions of its own (its core only stores the input),
// so both expression traits rely entirely on their default methods.
impl ExprRewritable for LogicalShare {}

impl ExprVisitable for LogicalShare {}
118
119impl PredicatePushdown for LogicalShare {
120    fn predicate_pushdown(
121        &self,
122        _predicate: Condition,
123        _ctx: &mut PredicatePushdownContext,
124    ) -> PlanRef {
125        unimplemented!(
126            "call predicate_pushdown of the PlanRef instead of calling directly on LogicalShare"
127        )
128    }
129}
130
// Batch plans have no share operator, so this conversion is rejected.
impl ToBatch for LogicalShare {
    /// Always bails via `bail_not_implemented!`; a share node is never converted to batch.
    fn to_batch(&self) -> Result<PlanRef> {
        bail_not_implemented!("batch query doesn't support share operator for now");
    }
}
136
137impl ToStream for LogicalShare {
138    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
139        match ctx.get_to_stream_result(self.id()) {
140            None => {
141                let new_input = self.input().to_stream(ctx)?;
142                let new_logical = self.core.clone();
143                new_logical.replace_input(new_input);
144                let stream_share_ref: PlanRef = StreamShare::new(new_logical).into();
145                ctx.add_to_stream_result(self.id(), stream_share_ref.clone());
146                Ok(stream_share_ref)
147            }
148            Some(cache) => Ok(cache.clone()),
149        }
150    }
151
152    fn logical_rewrite_for_stream(
153        &self,
154        ctx: &mut RewriteStreamContext,
155    ) -> Result<(PlanRef, ColIndexMapping)> {
156        match ctx.get_rewrite_result(self.id()) {
157            None => {
158                let (new_input, col_change) = self.input().logical_rewrite_for_stream(ctx)?;
159                let new_share: PlanRef = self.clone_with_input(new_input).into();
160                ctx.add_rewrite_result(self.id(), new_share.clone(), col_change.clone());
161                Ok((new_share, col_change))
162            }
163            Some(cache) => Ok(cache.clone()),
164        }
165    }
166}
167
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef, Literal};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{
        LogicalFilter, LogicalJoin, LogicalValues, PlanTreeNodeBinary,
    };

    /// Predicate pushdown through a self join over a single `LogicalShare`.
    ///
    /// The filters above the join reference only the left side (`v1 = 100`) and only
    /// the right side (`v2 = 200`) respectively. The expected plan keeps each predicate
    /// above its own side of the join. It also places their disjunction below the
    /// shared node, which both join inputs reference.
    #[tokio::test]
    async fn test_share_predicate_pushdown() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        // Input relation with three Int32 columns: v1, v2, v3.
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        // The same share node is used as both join inputs below (a self join).
        let share: PlanRef = LogicalShare::create(values1.into());

        // Join condition over the join's output columns:
        // left.v2 (index 1) = right.v1 (index 3 = 3 left columns + 0).
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty.clone()))),
                ],
            )
            .unwrap(),
        ));

        // left.v1 (index 0) = 100
        let predicate1: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(0, DataType::Int32))),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Some(ScalarImpl::from(100)),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));

        // right.v2 (index 4 = 3 left columns + 1) = 200
        let predicate2: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(4, DataType::Int32))),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Some(ScalarImpl::from(200)),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));

        let join: PlanRef = LogicalJoin::create(share.clone(), share, JoinType::Inner, on);

        let filter1: PlanRef = LogicalFilter::create_with_expr(join, predicate1);

        let filter2: PlanRef = LogicalFilter::create_with_expr(filter1, predicate2);

        // Pushdown is driven through the `PlanRef` entry point, not through
        // `LogicalShare::predicate_pushdown` (which panics).
        let result = filter2.predicate_pushdown(
            Condition::true_cond(),
            &mut PredicatePushdownContext::new(filter2.clone()),
        );

        // Expected plan (both branches reference the same share, id = 2):
        // LogicalJoin { type: Inner, on: (v2 = v1) }
        // ├─LogicalFilter { predicate: (v1 = 100:Int32) }
        // | └─LogicalShare { id = 2 }
        // |   └─LogicalFilter { predicate: ((v1 = 100:Int32) OR (v2 = 200:Int32)) }
        // |     └─LogicalValues { schema: Schema { fields: [v1:Int32, v2:Int32, v3:Int32] } }
        // └─LogicalFilter { predicate: (v2 = 200:Int32) }
        //   └─LogicalShare { id = 2 }
        //     └─LogicalFilter { predicate: ((v1 = 100:Int32) OR (v2 = 200:Int32)) }
        //       └─LogicalValues { schema: Schema { fields: [v1:Int32, v2:Int32, v3:Int32] } }

        // Walk the left branch down to the filter beneath the share.
        let logical_join: &LogicalJoin = result.as_logical_join().unwrap();
        let left = logical_join.left();
        let left_filter: &LogicalFilter = left.as_logical_filter().unwrap();
        let left_filter_input = left_filter.input();
        let logical_share: &LogicalShare = left_filter_input.as_logical_share().unwrap();
        let share_input = logical_share.input();
        let share_input_filter: &LogicalFilter = share_input.as_logical_filter().unwrap();
        // The filter under the share should hold one OR conjunct with two
        // `column = constant` disjuncts, on columns 0 (v1) and 1 (v2).
        let disjunctions = share_input_filter.predicate().conjunctions[0]
            .as_or_disjunctions()
            .unwrap();
        assert_eq!(disjunctions.len(), 2);
        let (input_ref1, _const1) = disjunctions[0].as_eq_const().unwrap();
        let (input_ref2, _const2) = disjunctions[1].as_eq_const().unwrap();
        // Accept the two disjuncts in either order.
        if input_ref1.index() == 0 {
            assert_eq!(input_ref2.index(), 1);
        } else {
            assert_eq!(input_ref1.index(), 1);
            assert_eq!(input_ref2.index(), 0);
        }
    }
}