risingwave_frontend/optimizer/plan_node/generic/
project.rs1use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use fixedbitset::FixedBitSet;
19use itertools::Itertools;
20use pretty_xmlish::{Pretty, StrAssocArr};
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::iter_util::ZipEqFast;
23
24use super::{GenericPlanNode, GenericPlanRef};
25use crate::expr::{
26 Expr, ExprDisplay, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef,
27 assert_input_ref,
28};
29use crate::optimizer::optimizer_context::OptimizerContextRef;
30use crate::optimizer::plan_node::StreamPlanRef;
31use crate::optimizer::property::FunctionalDependencySet;
32use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt};
33
34fn check_expr_type(expr: &ExprImpl) -> std::result::Result<(), &'static str> {
35 if expr.has_subquery() {
36 return Err("subquery");
37 }
38 if expr.has_agg_call() {
39 return Err("aggregate function");
40 }
41 if expr.has_table_function() {
42 return Err("table function");
43 }
44 if expr.has_window_function() {
45 return Err("window function");
46 }
47 Ok(())
48}
49
/// Generic projection node: evaluates `exprs` over each row of `input`.
///
/// The i-th expression produces the i-th output column (see `schema()`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// `_private` below is deliberately a private unit field rather than `#[non_exhaustive]`:
// `#[non_exhaustive]` only restricts other crates. The private field forbids struct-literal
// construction anywhere outside this module, routing callers through the validating constructors.
#[allow(clippy::manual_non_exhaustive)]
pub struct Project<PlanRef> {
    /// Output expressions, in output-column order.
    pub exprs: Vec<ExprImpl>,
    /// Optional explicit output names, keyed by output column index.
    ///
    /// `schema()` consults this for input-reference columns and for computed (non-literal)
    /// columns. Literal columns are always named after their rendered value.
    pub field_names: BTreeMap<usize, String>,
    /// The child plan whose rows are projected.
    pub input: PlanRef,
    // Prevents construction by struct literal outside this module (see the attribute above).
    _private: (),
}
61
62impl<PlanRef> Project<PlanRef> {
63 pub fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Project<OtherPlanRef> {
64 Project {
65 exprs: self.exprs.clone(),
66 field_names: self.field_names.clone(),
67 input,
68 _private: (),
69 }
70 }
71
72 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
73 self.exprs = self
74 .exprs
75 .iter()
76 .map(|e| r.rewrite_expr(e.clone()))
77 .collect();
78 }
79
80 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
81 self.exprs.iter().for_each(|e| v.visit_expr(e));
82 }
83}
84
85impl<PlanRef: GenericPlanRef> GenericPlanNode for Project<PlanRef> {
86 fn schema(&self) -> Schema {
87 let o2i = self.o2i_col_mapping();
88 let exprs = &self.exprs;
89 let input_schema = self.input.schema();
90 let ctx = self.ctx();
91 let fields = exprs
92 .iter()
93 .enumerate()
94 .map(|(i, expr)| {
95 let name = match o2i.try_map(i) {
97 Some(input_idx) => {
98 if let Some(name) = self.field_names.get(&i) {
99 name.clone()
100 } else {
101 input_schema.fields()[input_idx].name.clone()
102 }
103 }
104 None => match expr {
105 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => {
106 format!("{:?}", ExprDisplay { expr, input_schema })
107 }
108 _ => {
109 if let Some(name) = self.field_names.get(&i) {
110 name.clone()
111 } else {
112 format!("$expr{}", ctx.next_expr_display_id())
113 }
114 }
115 },
116 };
117 Field::with_name(expr.return_type(), name)
118 })
119 .collect();
120 Schema { fields }
121 }
122
123 fn stream_key(&self) -> Option<Vec<usize>> {
124 let i2o = self.i2o_col_mapping();
125 self.input
126 .stream_key()?
127 .iter()
128 .map(|pk_col| i2o.try_map(*pk_col))
129 .collect::<Option<Vec<_>>>()
130 }
131
132 fn ctx(&self) -> OptimizerContextRef {
133 self.input.ctx()
134 }
135
136 fn functional_dependency(&self) -> FunctionalDependencySet {
137 let i2o = self.i2o_col_mapping();
138 i2o.rewrite_functional_dependency_set(self.input.functional_dependency().clone())
139 }
140}
141
142impl<PlanRef: GenericPlanRef> Project<PlanRef> {
143 pub fn new(exprs: Vec<ExprImpl>, input: PlanRef) -> Self {
144 for expr in &exprs {
145 assert_input_ref!(expr, input.schema().fields().len());
146 check_expr_type(expr)
147 .map_err(|expr| format!("{expr} should not in Project operator"))
148 .unwrap();
149 }
150 Project {
151 exprs,
152 field_names: Default::default(),
153 input,
154 _private: (),
155 }
156 }
157
158 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
165 if mapping.target_size() == 0 {
166 return Self::new(vec![], input);
169 };
170 let mut input_refs = vec![None; mapping.target_size()];
171 for (src, tar) in mapping.mapping_pairs() {
172 assert_eq!(input_refs[tar], None);
173 input_refs[tar] = Some(src);
174 }
175 let input_schema = input.schema();
176 let exprs: Vec<ExprImpl> = input_refs
177 .into_iter()
178 .map(|i| i.unwrap())
179 .map(|i| InputRef::new(i, input_schema.fields()[i].data_type()).into())
180 .collect();
181
182 Self::new(exprs, input)
183 }
184
185 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
187 Self::with_out_col_idx(input, out_fields.ones())
188 }
189
190 pub fn out_col_idx_exprs<'a>(
191 input: &'a PlanRef,
192 out_fields: impl Iterator<Item = usize> + 'a,
193 ) -> impl Iterator<Item = ExprImpl> + 'a {
194 let input_schema = input.schema();
195 out_fields.map(move |index| InputRef::new(index, input_schema[index].data_type()).into())
196 }
197
198 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
200 let exprs = Self::out_col_idx_exprs(&input, out_fields).collect();
201 Self::new(exprs, input)
202 }
203
204 pub fn with_vnode_col(input: PlanRef, dist_key: &[usize]) -> Self {
206 let input_fields = input.schema().fields();
207 let mut new_exprs: Vec<_> = input_fields
208 .iter()
209 .enumerate()
210 .map(|(idx, field)| InputRef::new(idx, field.data_type.clone()).into())
211 .collect();
212 new_exprs.push(
213 FunctionCall::new(
214 ExprType::Vnode,
215 dist_key
216 .iter()
217 .map(|idx| InputRef::new(*idx, input_fields[*idx].data_type()).into())
218 .collect(),
219 )
220 .expect("Vnode function call should be valid here")
221 .into(),
222 );
223 let vnode_expr_idx = new_exprs.len() - 1;
224
225 let mut new = Self::new(new_exprs, input);
226 new.field_names.insert(vnode_expr_idx, "_vnode".to_owned());
227 new
228 }
229
230 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
231 (self.exprs, self.input)
232 }
233
234 pub fn fields_pretty<'a>(&self, schema: &Schema) -> StrAssocArr<'a> {
235 let f = |t| Pretty::debug(&t);
236 let e = Pretty::Array(self.exprs_for_display(schema).iter().map(f).collect());
237 vec![("exprs", e)]
238 }
239
240 fn exprs_for_display<'a>(&'a self, schema: &Schema) -> Vec<AliasedExpr<'a>> {
241 self.exprs
242 .iter()
243 .zip_eq_fast(schema.fields().iter())
244 .map(|(expr, field)| AliasedExpr {
245 expr: ExprDisplay {
246 expr,
247 input_schema: self.input.schema(),
248 },
249 alias: {
250 match expr {
251 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => None,
252 _ => Some(field.name.clone()),
253 }
254 },
255 })
256 .collect()
257 }
258
259 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
260 let exprs = &self.exprs;
261 let input_len = self.input.schema().len();
262 let mut map = vec![None; exprs.len()];
263 for (i, expr) in exprs.iter().enumerate() {
264 if let ExprImpl::InputRef(input) = expr {
265 map[i] = Some(input.index())
266 }
267 }
268 ColIndexMapping::new(map, input_len)
269 }
270
271 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
274 let exprs = &self.exprs;
275 let input_len = self.input.schema().len();
276 let mut map = vec![None; input_len];
277 for (i, expr) in exprs.iter().enumerate() {
278 if let ExprImpl::InputRef(input) = expr {
279 map[input.index()] = Some(i)
280 }
281 }
282 ColIndexMapping::new(map, exprs.len())
283 }
284
285 pub fn is_all_inputref(&self) -> bool {
286 self.exprs
287 .iter()
288 .all(|expr| matches!(expr, ExprImpl::InputRef(_)))
289 }
290
291 pub fn is_identity(&self) -> bool {
292 self.exprs.len() == self.input.schema().len()
293 && self
294 .exprs
295 .iter()
296 .zip_eq_fast(self.input.schema().fields())
297 .enumerate()
298 .all(|(i, (expr, field))| {
299 matches!(expr, ExprImpl::InputRef(input_ref) if **input_ref == InputRef::new(i, field.data_type()))
300 })
301 }
302
303 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
304 self.exprs
305 .iter()
306 .map(|expr| match expr {
307 ExprImpl::InputRef(input_ref) => Some(input_ref.index),
308 _ => None,
309 })
310 .collect::<Option<Vec<_>>>()
311 }
312}
313
314impl Project<StreamPlanRef> {
315 pub(crate) fn likely_produces_noop_updates(&self) -> bool {
318 if self.input.as_stream_now().is_some() {
321 return true;
322 }
323
324 struct HasJsonbAccess {
328 has: bool,
329 }
330
331 impl ExprVisitor for HasJsonbAccess {
332 fn visit_function_call(&mut self, func_call: &FunctionCall) {
333 if matches!(
334 func_call.func_type(),
335 ExprType::JsonbAccess
336 | ExprType::JsonbAccessStr
337 | ExprType::JsonbExtractPath
338 | ExprType::JsonbExtractPathVariadic
339 | ExprType::JsonbExtractPathText
340 | ExprType::JsonbExtractPathTextVariadic
341 | ExprType::JsonbPathExists
342 | ExprType::JsonbPathMatch
343 | ExprType::JsonbPathQueryArray
344 | ExprType::JsonbPathQueryFirst
345 ) {
346 self.has = true;
347 }
348 }
349 }
350
351 self.exprs.iter().any(|expr| {
352 let mut visitor = HasJsonbAccess { has: false };
356 visitor.visit_expr(expr);
357 visitor.has
358 })
359 }
360}
361
/// Incrementally collects de-duplicated expressions, then builds a [`Project`] from them.
#[derive(Default)]
pub struct ProjectBuilder {
    /// Collected expressions in insertion order; an expression's position is its output index.
    exprs: Vec<ExprImpl>,
    /// Reverse lookup from an expression to its position in `exprs`, used for de-duplication.
    exprs_index: HashMap<ExprImpl, usize>,
}
369
370impl ProjectBuilder {
371 pub fn add_expr(&mut self, expr: &ExprImpl) -> std::result::Result<usize, &'static str> {
374 check_expr_type(expr)?;
375 if let Some(idx) = self.exprs_index.get(expr) {
376 Ok(*idx)
377 } else {
378 let index = self.exprs.len();
379 self.exprs.push(expr.clone());
380 self.exprs_index.insert(expr.clone(), index);
381 Ok(index)
382 }
383 }
384
385 pub fn get_expr(&self, index: usize) -> Option<&ExprImpl> {
386 self.exprs.get(index)
387 }
388
389 pub fn expr_index(&self, expr: &ExprImpl) -> Option<usize> {
390 check_expr_type(expr).ok()?;
391 self.exprs_index.get(expr).copied()
392 }
393
394 pub fn build<PlanRef: GenericPlanRef>(self, input: PlanRef) -> Project<PlanRef> {
396 Project::new(self.exprs, input)
397 }
398
399 pub fn exprs_len(&self) -> usize {
400 self.exprs.len()
401 }
402}
403
/// An expression paired with an optional alias for plan display.
///
/// Its `Debug` impl renders `<expr>` or `<expr> as <alias>`.
pub struct AliasedExpr<'a> {
    /// The expression together with the schema it is rendered against.
    pub expr: ExprDisplay<'a>,
    /// Name printed after `as`; `None` prints the expression alone.
    pub alias: Option<String>,
}
409
410impl fmt::Debug for AliasedExpr<'_> {
411 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412 match &self.alias {
413 Some(alias) => write!(f, "{:?} as {}", self.expr, alias),
414 None => write!(f, "{:?}", self.expr),
415 }
416 }
417}
418
419pub fn ensure_sorted_required_cols(
425 required_cols: &[usize],
426 schema: &Schema,
427) -> (Vec<ExprImpl>, Vec<usize>) {
428 let mut required_cols_with_output_idx = required_cols.iter().copied().enumerate().collect_vec();
429 required_cols_with_output_idx.sort_by_key(|(_, col_idx)| *col_idx);
430 let mut output_indices = vec![0; required_cols.len()];
431 let mut sorted_col_idx = Vec::with_capacity(required_cols.len());
432
433 for (sorted_input_idx, (output_idx, col_idx)) in
434 required_cols_with_output_idx.into_iter().enumerate()
435 {
436 sorted_col_idx.push(col_idx);
437 output_indices[output_idx] = sorted_input_idx;
438 }
439
440 (
441 output_indices
442 .into_iter()
443 .map(|sorted_input_idx| {
444 InputRef::new(
445 sorted_input_idx,
446 schema[sorted_col_idx[sorted_input_idx]].data_type(),
447 )
448 .into()
449 })
450 .collect(),
451 sorted_col_idx,
452 )
453}
454
#[cfg(test)]
mod tests {
    use rand::prelude::SliceRandom;
    use rand::rng;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::ensure_sorted_required_cols;

    /// Picks a random subset of columns and shuffles it. Then checks that
    /// `ensure_sorted_required_cols` sorts the subset and that its expressions map every
    /// requested position back to the originally requested column.
    #[test]
    fn test_ensure_sorted_required_cols() {
        let input_len = 10;
        let fields: Vec<Field> = (0..input_len)
            .map(|_| Field::unnamed(DataType::Int32))
            .collect();
        let schema = Schema::new(fields);

        // Each column is kept with probability 1/2; the result is already ascending.
        let sorted_required_cols: Vec<usize> = (0..input_len)
            .filter(|_| rand::random_bool(0.5))
            .collect();
        let mut required_cols = sorted_required_cols.clone();
        required_cols.shuffle(&mut rng());

        let (output_exprs, sorted) = ensure_sorted_required_cols(&required_cols, &schema);
        assert_eq!(sorted, sorted_required_cols);

        let restored: Vec<usize> = output_exprs
            .iter()
            .map(|expr| sorted_required_cols[expr.as_input_ref().unwrap().index])
            .collect();
        assert_eq!(restored, required_cols);
    }
}