risingwave_frontend/optimizer/plan_node/
logical_project.rs1use std::collections::{BTreeMap, HashSet};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use pretty_xmlish::XmlNode;
20
21use super::generic::GenericPlanNode;
22use super::utils::{Distill, childless_record};
23use super::{
24 BatchProject, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
25 PredicatePushdown, StreamMaterializedExprs, StreamProject, ToBatch, ToStream,
26 gen_filter_and_pushdown, generic,
27};
28use crate::error::Result;
29use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, collect_input_refs};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::generic::GenericPlanRef;
32use crate::optimizer::plan_node::{
33 ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::optimizer::property::{Distribution, Order, RequiredDist};
36use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, Substitute};
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct LogicalProject {
41 pub base: PlanBase<Logical>,
42 core: generic::Project<PlanRef>,
43}
44
45impl LogicalProject {
46 pub fn create(input: PlanRef, exprs: Vec<ExprImpl>) -> PlanRef {
47 Self::new(input, exprs).into()
48 }
49
50 pub fn new(input: PlanRef, exprs: Vec<ExprImpl>) -> Self {
51 let core = generic::Project::new(exprs, input);
52 Self::with_core(core)
53 }
54
55 pub fn with_core(core: generic::Project<PlanRef>) -> Self {
56 let base = PlanBase::new_logical_with_core(&core);
57 LogicalProject { base, core }
58 }
59
60 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
61 self.core.o2i_col_mapping()
62 }
63
64 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
65 self.core.i2o_col_mapping()
66 }
67
68 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
75 Self::with_core(generic::Project::with_mapping(input, mapping))
76 }
77
78 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
80 Self::with_core(generic::Project::with_out_fields(input, out_fields))
81 }
82
83 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
85 Self::with_core(generic::Project::with_out_col_idx(input, out_fields))
86 }
87
88 pub fn exprs(&self) -> &Vec<ExprImpl> {
89 &self.core.exprs
90 }
91
92 pub fn is_identity(&self) -> bool {
93 self.core.is_identity()
94 }
95
96 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
97 self.core.try_as_projection()
98 }
99
100 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
101 self.core.decompose()
102 }
103
104 pub fn is_all_inputref(&self) -> bool {
105 self.core.is_all_inputref()
106 }
107}
108
109impl PlanTreeNodeUnary for LogicalProject {
110 fn input(&self) -> PlanRef {
111 self.core.input.clone()
112 }
113
114 fn clone_with_input(&self, input: PlanRef) -> Self {
115 Self::new(input, self.exprs().clone())
116 }
117
118 fn rewrite_with_input(
119 &self,
120 input: PlanRef,
121 mut input_col_change: ColIndexMapping,
122 ) -> (Self, ColIndexMapping) {
123 let exprs = self
124 .exprs()
125 .clone()
126 .into_iter()
127 .map(|expr| input_col_change.rewrite_expr(expr))
128 .collect();
129 let proj = Self::new(input, exprs);
130 let out_col_change = ColIndexMapping::identity(self.schema().len());
132 (proj, out_col_change)
133 }
134}
135
136impl_plan_tree_node_for_unary! {LogicalProject}
137
138impl Distill for LogicalProject {
139 fn distill<'a>(&self) -> XmlNode<'a> {
140 childless_record(
141 "LogicalProject",
142 self.core.fields_pretty(self.base.schema()),
143 )
144 }
145}
146
147impl ColPrunable for LogicalProject {
148 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
149 let input_col_num: usize = self.input().schema().len();
150 let input_required_cols = collect_input_refs(
151 input_col_num,
152 required_cols.iter().map(|i| &self.exprs()[*i]),
153 )
154 .ones()
155 .collect_vec();
156 let new_input = self.input().prune_col(&input_required_cols, ctx);
157 let mut mapping = ColIndexMapping::with_remaining_columns(
158 &input_required_cols,
159 self.input().schema().len(),
160 );
161 let exprs = required_cols
163 .iter()
164 .map(|&id| mapping.rewrite_expr(self.exprs()[id].clone()))
165 .collect();
166
167 LogicalProject::new(new_input, exprs).into()
169 }
170}
171
172impl ExprRewritable for LogicalProject {
173 fn has_rewritable_expr(&self) -> bool {
174 true
175 }
176
177 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
178 let mut core = self.core.clone();
179 core.rewrite_exprs(r);
180 Self {
181 base: self.base.clone_with_new_plan_id(),
182 core,
183 }
184 .into()
185 }
186}
187
188impl ExprVisitable for LogicalProject {
189 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
190 self.core.visit_exprs(v);
191 }
192}
193
194impl PredicatePushdown for LogicalProject {
195 fn predicate_pushdown(
196 &self,
197 predicate: Condition,
198 ctx: &mut PredicatePushdownContext,
199 ) -> PlanRef {
200 let mut subst = Substitute {
202 mapping: self.exprs().clone(),
203 };
204
205 let impure_mask = {
206 let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
207 for (i, e) in self.exprs().iter().enumerate() {
208 impure_mask.set(i, e.is_impure())
209 }
210 impure_mask
211 };
212 let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
214 let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
215
216 gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
217 }
218}
219
220impl ToBatch for LogicalProject {
221 fn to_batch(&self) -> Result<PlanRef> {
222 self.to_batch_with_order_required(&Order::any())
223 }
224
225 fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
226 let input_order = self
227 .o2i_col_mapping()
228 .rewrite_provided_order(required_order);
229 let new_input = self.input().to_batch_with_order_required(&input_order)?;
230 let mut new_logical = self.core.clone();
231 new_logical.input = new_input;
232 let batch_project = BatchProject::new(new_logical);
233 required_order.enforce_if_not_satisfies(batch_project.into())
234 }
235}
236
237impl ToStream for LogicalProject {
238 fn to_stream_with_dist_required(
239 &self,
240 required_dist: &RequiredDist,
241 ctx: &mut ToStreamContext,
242 ) -> Result<PlanRef> {
243 let input_required = if required_dist.satisfies(&RequiredDist::AnyShard) {
244 RequiredDist::Any
245 } else {
246 let input_required = self
247 .o2i_col_mapping()
248 .rewrite_required_distribution(required_dist);
249 match input_required {
250 RequiredDist::PhysicalDist(dist) => match dist {
251 Distribution::Single => RequiredDist::Any,
252 _ => RequiredDist::PhysicalDist(dist),
253 },
254 _ => input_required,
255 }
256 };
257 let new_input = self
258 .input()
259 .to_stream_with_dist_required(&input_required, ctx)?;
260
261 let enable_materialized_exprs = self
262 .core
263 .ctx()
264 .session_ctx()
265 .config()
266 .streaming_enable_materialized_expressions();
267
268 let stream_plan = if enable_materialized_exprs {
269 let mut udf_field_names = BTreeMap::new();
271 let mut udf_expr_indices = HashSet::new();
272 let udf_exprs: Vec<_> = self
273 .exprs()
274 .iter()
275 .enumerate()
276 .filter_map(|(idx, expr)| {
277 if expr.has_user_defined_function() {
278 udf_expr_indices.insert(idx);
279 if let Some(name) = self.core.field_names.get(&idx) {
280 udf_field_names.insert(idx, name.clone());
281 }
282 Some(expr.clone())
283 } else {
284 None
285 }
286 })
287 .collect();
288
289 if !udf_exprs.is_empty() {
290 let mat_exprs_plan: PlanRef =
292 StreamMaterializedExprs::new(new_input.clone(), udf_exprs, udf_field_names)
293 .into();
294
295 let input_len = new_input.schema().len();
296 let mut udf_pos = 0;
297
298 let final_exprs = self
300 .exprs()
301 .iter()
302 .enumerate()
303 .map(|(idx, expr)| {
304 if udf_expr_indices.contains(&idx) {
305 let output_idx = input_len + udf_pos;
306 udf_pos += 1;
307 InputRef::new(output_idx, expr.return_type()).into()
308 } else {
309 expr.clone()
310 }
311 })
312 .collect();
313
314 let core = generic::Project::new(final_exprs, mat_exprs_plan);
315 StreamProject::new(core).into()
316 } else {
317 let core = generic::Project::new(self.exprs().clone(), new_input);
319 StreamProject::new(core).into()
320 }
321 } else {
322 let core = generic::Project::new(self.exprs().clone(), new_input);
324 StreamProject::new(core).into()
325 };
326
327 required_dist.enforce_if_not_satisfies(stream_plan, &Order::any())
328 }
329
330 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
331 self.to_stream_with_dist_required(&RequiredDist::Any, ctx)
332 }
333
334 fn logical_rewrite_for_stream(
335 &self,
336 ctx: &mut RewriteStreamContext,
337 ) -> Result<(PlanRef, ColIndexMapping)> {
338 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
339 let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change);
340
341 let input_pk = input.expect_stream_key();
343 let i2o = proj.i2o_col_mapping();
344 let col_need_to_add = input_pk
345 .iter()
346 .cloned()
347 .filter(|i| i2o.try_map(*i).is_none());
348 let input_schema = input.schema();
349 let exprs =
350 proj.exprs()
351 .iter()
352 .cloned()
353 .chain(col_need_to_add.map(|idx| {
354 InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
355 }))
356 .collect();
357 let proj = Self::new(input, exprs);
358 let (map, _) = out_col_change.into_parts();
362 let out_col_change = ColIndexMapping::new(map, proj.base.schema().len());
363 Ok((proj.into(), out_col_change))
364 }
365}
366
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Pruning
    /// ```text
    /// Project(null, input_ref(2), input_ref(0) < null)
    ///   Values(v1, v2, v3)
    /// ```
    /// with required columns `[1, 2]` is expected to give
    /// ```text
    /// Project(input_ref(1), input_ref(0) < null)
    ///   Values(v1, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Three Int32 input columns: v1, v2, v3.
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(ty.clone(), name))
            .collect();
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        // Output expressions: a null literal, `$2`, and `$0 < null`.
        let null_literal = ExprImpl::Literal(Box::new(Literal::new(None, ty.clone())));
        let third_col: ExprImpl = InputRef::new(2, ty.clone()).into();
        let less_than = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::LessThan,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(0, ty.clone()))),
                    ExprImpl::Literal(Box::new(Literal::new(None, ty.clone()))),
                ],
            )
            .unwrap(),
        ));
        let project: PlanRef =
            LogicalProject::new(values.into(), vec![null_literal, third_col, less_than]).into();

        // Keep only the last two output columns.
        let required_cols = vec![1, 2];
        let mut prune_ctx = ColumnPruningContext::new(project.clone());
        let plan = project.prune_col(&required_cols, &mut prune_ctx);

        // The projection keeps two expressions, re-indexed against the pruned input.
        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let expr = project.exprs()[1].clone();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);

        // The input keeps only v1 and v3.
        let values = project.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields().len(), 2);
        assert_eq!(values.schema().fields()[0], fields[0]);
        assert_eq!(values.schema().fields()[1], fields[2]);
    }
}