risingwave_frontend/optimizer/plan_node/
logical_share.rs1use std::cell::RefCell;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::bail_not_implemented;
19
20use super::utils::{Distill, childless_record};
21use super::{
22 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown,
23 ToBatch, ToStream, generic,
24};
25use crate::error::Result;
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29 ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, StreamShare,
30 ToStreamContext,
31};
32use crate::utils::{ColIndexMapping, Condition};
33
/// Logical plan node marking its input as a shared sub-plan.
///
/// The same `LogicalShare` may be referenced by several parents (e.g. both sides of a
/// self-join). Stream conversion and stream rewriting are memoized per plan id in their
/// respective contexts, so the shared input is converted only once.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalShare {
    /// Common logical plan properties, built from `core` at construction time.
    pub base: PlanBase<Logical>,
    /// Generic share core. Its input is held in a `RefCell` so it can be swapped in place
    /// via `replace_input` without building a new node.
    core: generic::Share<PlanRef>,
}
56
57impl LogicalShare {
58 pub fn new(input: PlanRef) -> Self {
59 let _ctx = input.ctx();
60 let _functional_dependency = input.functional_dependency().clone();
61 let core = generic::Share {
62 input: RefCell::new(input),
63 };
64 let base = PlanBase::new_logical_with_core(&core);
65 LogicalShare { base, core }
66 }
67
68 pub fn create(input: PlanRef) -> PlanRef {
69 LogicalShare::new(input).into()
70 }
71
72 pub(super) fn pretty_fields(base: impl GenericPlanRef, name: &str) -> XmlNode<'_> {
73 childless_record(name, vec![("id", Pretty::debug(&base.id().0))])
74 }
75}
76
77impl PlanTreeNodeUnary for LogicalShare {
78 fn input(&self) -> PlanRef {
79 self.core.input.borrow().clone()
80 }
81
82 fn clone_with_input(&self, input: PlanRef) -> Self {
83 Self::new(input)
84 }
85
86 fn rewrite_with_input(
87 &self,
88 input: PlanRef,
89 input_col_change: ColIndexMapping,
90 ) -> (Self, ColIndexMapping) {
91 (Self::new(input), input_col_change)
92 }
93}
94
// Generic plan-tree-node plumbing for `LogicalShare`; presumably built on top of its
// `PlanTreeNodeUnary` impl (macro expansion not visible here — confirm in the macro source).
impl_plan_tree_node_for_unary! {LogicalShare}
96
97impl LogicalShare {
98 pub fn replace_input(&self, plan: PlanRef) {
99 *self.core.input.borrow_mut() = plan;
100 }
101}
102
103impl Distill for LogicalShare {
104 fn distill<'a>(&self) -> XmlNode<'a> {
105 Self::pretty_fields(&self.base, "LogicalShare")
106 }
107}
108
109impl ColPrunable for LogicalShare {
110 fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
111 unimplemented!("call prune_col of the PlanRef instead of calling directly on LogicalShare")
112 }
113}
114
// The share node stores no expressions of its own (its core holds only the input),
// so the trait default implementations are used for both rewriting and visiting.
impl ExprRewritable for LogicalShare {}

impl ExprVisitable for LogicalShare {}
118
119impl PredicatePushdown for LogicalShare {
120 fn predicate_pushdown(
121 &self,
122 _predicate: Condition,
123 _ctx: &mut PredicatePushdownContext,
124 ) -> PlanRef {
125 unimplemented!(
126 "call predicate_pushdown of the PlanRef instead of calling directly on LogicalShare"
127 )
128 }
129}
130
impl ToBatch for LogicalShare {
    /// Batch plans have no share operator, so conversion always bails out with a
    /// not-implemented error.
    fn to_batch(&self) -> Result<PlanRef> {
        bail_not_implemented!("batch query doesn't support share operator for now");
    }
}
136
137impl ToStream for LogicalShare {
138 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
139 match ctx.get_to_stream_result(self.id()) {
140 None => {
141 let new_input = self.input().to_stream(ctx)?;
142 let new_logical = self.core.clone();
143 new_logical.replace_input(new_input);
144 let stream_share_ref: PlanRef = StreamShare::new(new_logical).into();
145 ctx.add_to_stream_result(self.id(), stream_share_ref.clone());
146 Ok(stream_share_ref)
147 }
148 Some(cache) => Ok(cache.clone()),
149 }
150 }
151
152 fn logical_rewrite_for_stream(
153 &self,
154 ctx: &mut RewriteStreamContext,
155 ) -> Result<(PlanRef, ColIndexMapping)> {
156 match ctx.get_rewrite_result(self.id()) {
157 None => {
158 let (new_input, col_change) = self.input().logical_rewrite_for_stream(ctx)?;
159 let new_share: PlanRef = self.clone_with_input(new_input).into();
160 ctx.add_rewrite_result(self.id(), new_share.clone(), col_change.clone());
161 Ok((new_share, col_change))
162 }
163 Some(cache) => Ok(cache.clone()),
164 }
165 }
166}
167
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef, Literal};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{
        LogicalFilter, LogicalJoin, LogicalValues, PlanTreeNodeBinary,
    };

    /// Predicates coming from both sides of a self-join over a single `LogicalShare`
    /// should end up merged into one OR-ed filter placed below the share.
    #[tokio::test]
    async fn test_share_predicate_pushdown() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        // Shared source with three Int32 columns: v1, v2, v3.
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values1 = LogicalValues::new(vec![], Schema { fields }, ctx);

        let share: PlanRef = LogicalShare::create(values1.into());

        // Join condition `$1 = $3`: left.v2 = right.v1 (right-side columns start at index 3).
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty.clone()))),
                ],
            )
            .unwrap(),
        ));

        // `$0 = 100`: constrains left.v1, i.e. share column 0.
        let predicate1: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(0, DataType::Int32))),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Some(ScalarImpl::from(100)),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));

        // `$4 = 200`: constrains right.v2, i.e. share column 1.
        let predicate2: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(4, DataType::Int32))),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Some(ScalarImpl::from(200)),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));

        // The very same `share` node feeds both sides of the join.
        let join: PlanRef = LogicalJoin::create(share.clone(), share, JoinType::Inner, on);

        let filter1: PlanRef = LogicalFilter::create_with_expr(join, predicate1);

        let filter2: PlanRef = LogicalFilter::create_with_expr(filter1, predicate2);

        let result = filter2.predicate_pushdown(
            Condition::true_cond(),
            &mut PredicatePushdownContext::new(filter2.clone()),
        );

        // Walk down the pushed-down plan: Join -> (left) Filter -> Share -> Filter.
        let logical_join: &LogicalJoin = result.as_logical_join().unwrap();
        let left = logical_join.left();
        let left_filter: &LogicalFilter = left.as_logical_filter().unwrap();
        let left_filter_input = left_filter.input();
        let logical_share: &LogicalShare = left_filter_input.as_logical_share().unwrap();
        let share_input = logical_share.input();
        let share_input_filter: &LogicalFilter = share_input.as_logical_filter().unwrap();
        // The filter below the share must be `v1 = 100 OR v2 = 200`: a two-way OR of
        // eq-const predicates on share columns 0 and 1, in either order.
        let disjunctions = share_input_filter.predicate().conjunctions[0]
            .as_or_disjunctions()
            .unwrap();
        assert_eq!(disjunctions.len(), 2);
        let (input_ref1, _const1) = disjunctions[0].as_eq_const().unwrap();
        let (input_ref2, _const2) = disjunctions[1].as_eq_const().unwrap();
        if input_ref1.index() == 0 {
            assert_eq!(input_ref2.index(), 1);
        } else {
            assert_eq!(input_ref1.index(), 1);
            assert_eq!(input_ref2.index(), 0);
        }
    }
}