risingwave_frontend/optimizer/plan_node/
logical_project.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use pretty_xmlish::XmlNode;
18
19use super::utils::{Distill, childless_record};
20use super::{
21 BatchProject, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
22 PredicatePushdown, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown, generic,
23};
24use crate::error::Result;
25use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29 ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
30};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct LogicalProject {
37 pub base: PlanBase<Logical>,
38 core: generic::Project<PlanRef>,
39}
40
41impl LogicalProject {
42 pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
43 Self::new(input, exprs).into()
44 }
45
46 pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
47 let core = generic::Project::new(exprs, input);
48 Self::with_core(core)
49 }
50
51 pub fn with_core(core: generic::Project<PlanRef>) -> Self {
52 let base = PlanBase::new_logical_with_core(&core);
53 LogicalProject { base, core }
54 }
55
56 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
57 self.core.o2i_col_mapping()
58 }
59
60 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
61 self.core.i2o_col_mapping()
62 }
63
64 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
71 Self::with_core(generic::Project::with_mapping(input, mapping))
72 }
73
74 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
76 Self::with_core(generic::Project::with_out_fields(input, out_fields))
77 }
78
79 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
81 Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
82 }
83
84 pub fn exprs(&self) -> &Vec<ExprImpl> {
85 &self.core.exprs
86 }
87
88 pub fn is_identity(&self) -> bool {
89 self.core.is_identity()
90 }
91
92 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
93 self.core.try_as_projection()
94 }
95
96 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
97 self.core.decompose()
98 }
99
100 pub fn is_all_inputref(&self) -> bool {
101 self.core.is_all_inputref()
102 }
103}
104
105impl PlanTreeNodeUnary for LogicalProject {
106 fn input(&self) -> PlanRef {
107 self.core.input.clone()
108 }
109
110 fn clone_with_input(&self, input: PlanRef) -> Self {
111 Self::new(input, self.exprs().clone())
112 }
113
114 fn rewrite_with_input(
115 &self,
116 input: PlanRef,
117 mut input_col_change: ColIndexMapping,
118 ) -> (Self, ColIndexMapping) {
119 let exprs = self
120 .exprs()
121 .clone()
122 .into_iter()
123 .map(|expr| input_col_change.rewrite_expr(expr))
124 .collect();
125 let proj = Self::new(input, exprs);
126 let out_col_change = ColIndexMapping::identity(self.schema().len());
128 (proj, out_col_change)
129 }
130}
131
132impl_plan_tree_node_for_unary! {LogicalProject}
133
134impl Distill for LogicalProject {
135 fn distill<'a>(&self) -> XmlNode<'a> {
136 childless_record(
137 "LogicalProject",
138 self.core.fields_pretty(self.base.schema()),
139 )
140 }
141}
142
143impl ColPrunable for LogicalProject {
144 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
145 let input_col_num: usize = self.input().schema().len();
146 let input_required_cols = collect_input_refs(
147 input_col_num,
148 required_cols.iter().map(|i| &self.exprs()[*i]),
149 )
150 .ones()
151 .collect_vec();
152 let new_input = self.input().prune_col(&input_required_cols, ctx);
153 let mut mapping = ColIndexMapping::with_remaining_columns(
154 &input_required_cols,
155 self.input().schema().len(),
156 );
157 let exprs = required_cols
159 .iter()
160 .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
161 .collect();
162
163 LogicalProject::new(new_input, exprs).into()
165 }
166}
167
168impl ExprRewritable for LogicalProject {
169 fn has_rewritable_expr(&self) -> bool {
170 true
171 }
172
173 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
174 let mut core = self.core.clone();
175 core.rewrite_exprs(r);
176 Self {
177 base: self.base.clone_with_new_plan_id(),
178 core,
179 }
180 .into()
181 }
182}
183
184impl ExprVisitable for LogicalProject {
185 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
186 self.core.visit_exprs(v);
187 }
188}
189
190impl PredicatePushdown for LogicalProject {
191 fn predicate_pushdown(
192 &self,
193 predicate: Condition,
194 ctx: &mut PredicatePushdownContext,
195 ) -> PlanRef {
196 let mut subst = Substitute {
198 mapping: self.exprs().clone(),
199 };
200
201 let impure_mask = {
202 let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
203 for (i, e) in self.exprs().iter().enumerate() {
204 impure_mask.set(i, e.is_impure())
205 }
206 impure_mask
207 };
208 let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
210 let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
211
212 gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
213 }
214}
215
216impl ToBatch for LogicalProject {
217 fn to_batch(&self) -> Result<PlanRef> {
218 self.to_batch_with_order_required(&Order::any())
219 }
220
221 fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
222 let input_order = self
223 .o2i_col_mapping()
224 .rewrite_provided_order(required_order);
225 let new_input = self.input().to_batch_with_order_required(&input_order)?;
226 let mut new_logical = self.core.clone();
227 new_logical.input = new_input;
228 let batch_project = BatchProject::new(new_logical);
229 required_order.enforce_if_not_satisfies(batch_project.into())
230 }
231}
232
233impl ToStream for LogicalProject {
234 fn to_stream_with_dist_required(
235 &self,
236 required_dist: &RequiredDist,
237 ctx: &mut ToStreamContext,
238 ) -> Result<PlanRef> {
239 let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
240 RequiredDist::Any
241 } else {
242 let input_required = self
243 .o2i_col_mapping()
244 .rewrite_required_distribution(required_dist);
245 match input_required {
246 RequiredDist::PhysicalDist(dist) => match dist {
247 Distribution::Single => RequiredDist::Any,
248 _ => RequiredDist::PhysicalDist(dist),
249 },
250 _ => input_required,
251 }
252 };
253 let new_input = self
254 .input()
255 .to_stream_with_dist_required(&input_required, ctx)?;
256 let mut new_logical = self.core.clone();
257 new_logical.input = new_input;
258 let stream_plan = StreamProject::new(new_logical);
259 required_dist.enforce_if_not_satisfies(stream_plan.into(), &Order::any())
260 }
261
262 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
263 self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
264 }
265
266 fn logical_rewrite_for_stream(
267 &self,
268 ctx: &mut RewriteStreamContext,
269 ) -> Result<(PlanRef, ColIndexMapping)> {
270 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
271 let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
272
273 let input_pk = input.expect_stream_key();
275 let i2o = proj.i2o_col_mapping();
276 let col_need_to_add = input_pk
277 .iter()
278 .cloned()
279 .filter(|i| i2o.try_map(*i).is_none());
280 let input_schema = input.schema();
281 let exprs =
282 proj.exprs()
283 .iter()
284 .cloned()
285 .chain(col_need_to_add.map(|idx| {
286 InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
287 }))
288 .collect();
289 let proj = Self::new(input, exprs);
290 let (map, _) = out_col_change.into_parts();
294 let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
295 Ok((proj.into(), out_col_change))
296 }
297}
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Prunes a three-expression project when only its last two outputs are required:
    ///
    /// ```text
    /// Project(null, $2, $0 < null)          Project($1, $0 < null)
    ///   Values(v1, v2, v3)            =>      Values(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(ty.clone(), name))
            .collect();
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        // Output 0: a constant; output 1: `$2`; output 2: `$0 < null`.
        let null_literal = ExprImpl::Literal(Box::new(Literal::new(None, ty.clone())));
        let select_v3: ExprImpl = InputRef::new(2, ty.clone()).into();
        let v1_lt_null = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::LessThan,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                    ExprImpl::Literal(Box::new(Literal::new(None, ty))),
                ],
            )
            .unwrap(),
        ));
        let project: PlanRef =
            LogicalProject::new(values.into(), vec![null_literal, select_v3, v1_lt_null]).into();

        let required_cols: Vec<usize> = vec![1, 2];
        let mut pruning_ctx = ColumnPruningContext::new(project.clone());
        let plan = project.prune_col(&required_cols, &mut pruning_ctx);

        let pruned = plan.as_logical_project().unwrap();
        let exprs = pruned.exprs();
        assert_eq!(exprs.len(), 2);
        assert_eq_input_ref!(&exprs[0], 1);

        let call = exprs[1].as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        let input = pruned.input();
        let values = input.as_logical_values().unwrap();
        let kept_fields = values.schema().fields();
        assert_eq!(kept_fields.len(), 2);
        assert_eq!(kept_fields[0], fields[0]);
        assert_eq!(kept_fields[1], fields[2]);
    }
}