risingwave_frontend/optimizer/plan_node/generic/
project.rs1use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use fixedbitset::FixedBitSet;
19use itertools::Itertools;
20use pretty_xmlish::{Pretty, StrAssocArr};
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::iter_util::ZipEqFast;
23
24use super::{GenericPlanNode, GenericPlanRef};
25use crate::expr::{
26 Expr, ExprDisplay, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef,
27 assert_input_ref,
28};
29use crate::optimizer::optimizer_context::OptimizerContextRef;
30use crate::optimizer::plan_node::StreamPlanRef;
31use crate::optimizer::property::FunctionalDependencySet;
32use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt};
33
34fn check_expr_type(expr: &ExprImpl) -> std::result::Result<(), &'static str> {
35 if expr.has_subquery() {
36 return Err("subquery");
37 }
38 if expr.has_agg_call() {
39 return Err("aggregate function");
40 }
41 if expr.has_table_function() {
42 return Err("table function");
43 }
44 if expr.has_window_function() {
45 return Err("window function");
46 }
47 Ok(())
48}
49
/// Output column name assigned to the `vnode(...)` expression that `Project::with_vnode_col`
/// appends after all input columns.
pub const LOCAL_PHASE_VNODE_COLUMN_NAME: &str = "_vnode";
53
/// A generic projection: evaluates `exprs` against each row of `input`, producing one output
/// column per expression (in the order of `exprs`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// The private `_private` field prevents code outside this module from building a `Project` with
// a struct literal, so construction goes through constructors such as `Project::new`, which
// validate the expressions. Clippy would otherwise suggest `#[non_exhaustive]`, but that only
// restricts other crates, not other modules of this crate.
#[allow(clippy::manual_non_exhaustive)]
pub struct Project<PlanRef> {
    /// Output expressions, one per output column, in output order.
    pub exprs: Vec<ExprImpl>,
    /// Optional explicit names for output columns, keyed by output index. Consulted when the
    /// output schema is derived; literal outputs ignore it and are named by their display form.
    pub field_names: BTreeMap<usize, String>,
    /// The input plan whose columns the expressions reference.
    pub input: PlanRef,
    _private: (),
}
65
66impl<PlanRef> Project<PlanRef> {
67 pub fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Project<OtherPlanRef> {
68 Project {
69 exprs: self.exprs.clone(),
70 field_names: self.field_names.clone(),
71 input,
72 _private: (),
73 }
74 }
75
76 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
77 self.exprs = self
78 .exprs
79 .iter()
80 .map(|e| r.rewrite_expr(e.clone()))
81 .collect();
82 }
83
84 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
85 self.exprs.iter().for_each(|e| v.visit_expr(e));
86 }
87}
88
89impl<PlanRef: GenericPlanRef> GenericPlanNode for Project<PlanRef> {
90 fn schema(&self) -> Schema {
91 let o2i = self.o2i_col_mapping();
92 let exprs = &self.exprs;
93 let input_schema = self.input.schema();
94 let ctx = self.ctx();
95 let fields = exprs
96 .iter()
97 .enumerate()
98 .map(|(i, expr)| {
99 let name = match o2i.try_map(i) {
101 Some(input_idx) => {
102 if let Some(name) = self.field_names.get(&i) {
103 name.clone()
104 } else {
105 input_schema.fields()[input_idx].name.clone()
106 }
107 }
108 None => match expr {
109 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => {
110 format!("{:?}", ExprDisplay { expr, input_schema })
111 }
112 _ => {
113 if let Some(name) = self.field_names.get(&i) {
114 name.clone()
115 } else {
116 format!("$expr{}", ctx.next_expr_display_id())
117 }
118 }
119 },
120 };
121 Field::with_name(expr.return_type(), name)
122 })
123 .collect();
124 Schema { fields }
125 }
126
127 fn stream_key(&self) -> Option<Vec<usize>> {
128 let i2o = self.i2o_col_mapping();
129 self.input
130 .stream_key()?
131 .iter()
132 .map(|pk_col| i2o.try_map(*pk_col))
133 .collect::<Option<Vec<_>>>()
134 }
135
136 fn ctx(&self) -> OptimizerContextRef {
137 self.input.ctx()
138 }
139
140 fn functional_dependency(&self) -> FunctionalDependencySet {
141 let i2o = self.i2o_col_mapping();
142 i2o.rewrite_functional_dependency_set(self.input.functional_dependency().clone())
143 }
144}
145
146impl<PlanRef: GenericPlanRef> Project<PlanRef> {
147 pub fn new(exprs: Vec<ExprImpl>, input: PlanRef) -> Self {
148 for expr in &exprs {
149 assert_input_ref!(expr, input.schema().fields().len());
150 check_expr_type(expr)
151 .map_err(|expr| format!("{expr} should not in Project operator"))
152 .unwrap();
153 }
154 Project {
155 exprs,
156 field_names: Default::default(),
157 input,
158 _private: (),
159 }
160 }
161
162 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
169 if mapping.target_size() == 0 {
170 return Self::new(vec![], input);
173 };
174 let mut input_refs = vec![None; mapping.target_size()];
175 for (src, tar) in mapping.mapping_pairs() {
176 assert_eq!(input_refs[tar], None);
177 input_refs[tar] = Some(src);
178 }
179 let input_schema = input.schema();
180 let exprs: Vec<ExprImpl> = input_refs
181 .into_iter()
182 .map(|i| i.unwrap())
183 .map(|i| InputRef::new(i, input_schema.fields()[i].data_type()).into())
184 .collect();
185
186 Self::new(exprs, input)
187 }
188
189 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
191 Self::with_out_col_idx(input, out_fields.ones())
192 }
193
194 pub fn out_col_idx_exprs<'a>(
195 input: &'a PlanRef,
196 out_fields: impl Iterator<Item = usize> + 'a,
197 ) -> impl Iterator<Item = ExprImpl> + 'a {
198 let input_schema = input.schema();
199 out_fields.map(move |index| InputRef::new(index, input_schema[index].data_type()).into())
200 }
201
202 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
204 let exprs = Self::out_col_idx_exprs(&input, out_fields).collect();
205 Self::new(exprs, input)
206 }
207
208 pub fn with_vnode_col(input: PlanRef, dist_key: &[usize]) -> Self {
210 let input_fields = input.schema().fields();
211 let mut new_exprs: Vec<_> = input_fields
212 .iter()
213 .enumerate()
214 .map(|(idx, field)| InputRef::new(idx, field.data_type.clone()).into())
215 .collect();
216 new_exprs.push(
217 FunctionCall::new(
218 ExprType::Vnode,
219 dist_key
220 .iter()
221 .map(|idx| InputRef::new(*idx, input_fields[*idx].data_type()).into())
222 .collect(),
223 )
224 .expect("Vnode function call should be valid here")
225 .into(),
226 );
227 let vnode_expr_idx = new_exprs.len() - 1;
228
229 let mut new = Self::new(new_exprs, input);
230 new.field_names
231 .insert(vnode_expr_idx, LOCAL_PHASE_VNODE_COLUMN_NAME.to_owned());
232 new
233 }
234
235 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
236 (self.exprs, self.input)
237 }
238
239 pub fn fields_pretty<'a>(&self, schema: &Schema) -> StrAssocArr<'a> {
240 let f = |t| Pretty::debug(&t);
241 let e = Pretty::Array(self.exprs_for_display(schema).iter().map(f).collect());
242 vec![("exprs", e)]
243 }
244
245 fn exprs_for_display<'a>(&'a self, schema: &Schema) -> Vec<AliasedExpr<'a>> {
246 self.exprs
247 .iter()
248 .zip_eq_fast(schema.fields().iter())
249 .map(|(expr, field)| AliasedExpr {
250 expr: ExprDisplay {
251 expr,
252 input_schema: self.input.schema(),
253 },
254 alias: {
255 match expr {
256 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => None,
257 _ => Some(field.name.clone()),
258 }
259 },
260 })
261 .collect()
262 }
263
264 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
265 let exprs = &self.exprs;
266 let input_len = self.input.schema().len();
267 let mut map = vec![None; exprs.len()];
268 for (i, expr) in exprs.iter().enumerate() {
269 if let ExprImpl::InputRef(input) = expr {
270 map[i] = Some(input.index())
271 }
272 }
273 ColIndexMapping::new(map, input_len)
274 }
275
276 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
279 let exprs = &self.exprs;
280 let input_len = self.input.schema().len();
281 let mut map = vec![None; input_len];
282 for (i, expr) in exprs.iter().enumerate() {
283 if let ExprImpl::InputRef(input) = expr {
284 map[input.index()] = Some(i)
285 }
286 }
287 ColIndexMapping::new(map, exprs.len())
288 }
289
290 pub fn is_all_inputref(&self) -> bool {
291 self.exprs
292 .iter()
293 .all(|expr| matches!(expr, ExprImpl::InputRef(_)))
294 }
295
296 pub fn is_identity(&self) -> bool {
297 self.exprs.len() == self.input.schema().len()
298 && self
299 .exprs
300 .iter()
301 .zip_eq_fast(self.input.schema().fields())
302 .enumerate()
303 .all(|(i, (expr, field))| {
304 matches!(expr, ExprImpl::InputRef(input_ref) if **input_ref == InputRef::new(i, field.data_type()))
305 })
306 }
307
308 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
309 self.exprs
310 .iter()
311 .map(|expr| match expr {
312 ExprImpl::InputRef(input_ref) => Some(input_ref.index),
313 _ => None,
314 })
315 .collect::<Option<Vec<_>>>()
316 }
317}
318
319impl Project<StreamPlanRef> {
320 pub(crate) fn likely_produces_noop_updates(&self) -> bool {
323 if self.input.as_stream_now().is_some() {
326 return true;
327 }
328
329 struct HasJsonbAccess {
333 has: bool,
334 }
335
336 impl ExprVisitor for HasJsonbAccess {
337 fn visit_function_call(&mut self, func_call: &FunctionCall) {
338 if matches!(
339 func_call.func_type(),
340 ExprType::JsonbAccess
341 | ExprType::JsonbAccessStr
342 | ExprType::JsonbExtractPath
343 | ExprType::JsonbExtractPathVariadic
344 | ExprType::JsonbExtractPathText
345 | ExprType::JsonbExtractPathTextVariadic
346 | ExprType::JsonbPathExists
347 | ExprType::JsonbPathMatch
348 | ExprType::JsonbPathQueryArray
349 | ExprType::JsonbPathQueryFirst
350 ) {
351 self.has = true;
352 }
353 }
354 }
355
356 self.exprs.iter().any(|expr| {
357 let mut visitor = HasJsonbAccess { has: false };
361 visitor.visit_expr(expr);
362 visitor.has
363 })
364 }
365}
366
367#[derive(Default)]
370pub struct ProjectBuilder {
371 exprs: Vec<ExprImpl>,
372 exprs_index: HashMap<ExprImpl, usize>,
373}
374
375impl ProjectBuilder {
376 pub fn add_expr(&mut self, expr: &ExprImpl) -> std::result::Result<usize, &'static str> {
379 check_expr_type(expr)?;
380 if let Some(idx) = self.exprs_index.get(expr) {
381 Ok(*idx)
382 } else {
383 let index = self.exprs.len();
384 self.exprs.push(expr.clone());
385 self.exprs_index.insert(expr.clone(), index);
386 Ok(index)
387 }
388 }
389
390 pub fn get_expr(&self, index: usize) -> Option<&ExprImpl> {
391 self.exprs.get(index)
392 }
393
394 pub fn expr_index(&self, expr: &ExprImpl) -> Option<usize> {
395 check_expr_type(expr).ok()?;
396 self.exprs_index.get(expr).copied()
397 }
398
399 pub fn build<PlanRef: GenericPlanRef>(self, input: PlanRef) -> Project<PlanRef> {
401 Project::new(self.exprs, input)
402 }
403
404 pub fn exprs_len(&self) -> usize {
405 self.exprs.len()
406 }
407}
408
409pub struct AliasedExpr<'a> {
411 pub expr: ExprDisplay<'a>,
412 pub alias: Option<String>,
413}
414
415impl fmt::Debug for AliasedExpr<'_> {
416 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417 match &self.alias {
418 Some(alias) => write!(f, "{:?} as {}", self.expr, alias),
419 None => write!(f, "{:?}", self.expr),
420 }
421 }
422}
423
424pub fn ensure_sorted_required_cols(
430 required_cols: &[usize],
431 schema: &Schema,
432) -> (Vec<ExprImpl>, Vec<usize>) {
433 let mut required_cols_with_output_idx = required_cols.iter().copied().enumerate().collect_vec();
434 required_cols_with_output_idx.sort_by_key(|(_, col_idx)| *col_idx);
435 let mut output_indices = vec![0; required_cols.len()];
436 let mut sorted_col_idx = Vec::with_capacity(required_cols.len());
437
438 for (sorted_input_idx, (output_idx, col_idx)) in
439 required_cols_with_output_idx.into_iter().enumerate()
440 {
441 sorted_col_idx.push(col_idx);
442 output_indices[output_idx] = sorted_input_idx;
443 }
444
445 (
446 output_indices
447 .into_iter()
448 .map(|sorted_input_idx| {
449 InputRef::new(
450 sorted_input_idx,
451 schema[sorted_col_idx[sorted_input_idx]].data_type(),
452 )
453 .into()
454 })
455 .collect(),
456 sorted_col_idx,
457 )
458}
459
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::prelude::SliceRandom;
    use rand::rng;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::ensure_sorted_required_cols;

    #[test]
    fn test_ensure_sorted_required_cols() {
        let num_cols = 10;
        let schema = Schema::new(
            (0..num_cols)
                .map(|_| Field::unnamed(DataType::Int32))
                .collect(),
        );

        // A random subset of column indices; built by filtering a range, so it is sorted.
        let expected_sorted = (0..num_cols)
            .filter(|_| rand::random_bool(0.5))
            .collect_vec();
        // The same columns, requested in a random order.
        let mut requested = expected_sorted.clone();
        requested.shuffle(&mut rng());
        let requested = requested;

        let (exprs, sorted) = ensure_sorted_required_cols(&requested, &schema);
        assert_eq!(sorted, expected_sorted);

        // Following each output reference through the sorted columns must give back the column
        // requested at that position.
        let resolved = exprs
            .iter()
            .map(|e| expected_sorted[e.as_input_ref().unwrap().index])
            .collect_vec();
        assert_eq!(resolved, requested);
    }
}