risingwave_frontend/optimizer/plan_node/generic/
project.rs1use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use educe::Educe;
19use fixedbitset::FixedBitSet;
20use itertools::Itertools;
21use pretty_xmlish::{Pretty, StrAssocArr};
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::util::iter_util::ZipEqFast;
24
25use super::{GenericPlanNode, GenericPlanRef};
26use crate::expr::{
27 Expr, ExprDisplay, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef,
28 assert_input_ref,
29};
30use crate::optimizer::optimizer_context::OptimizerContextRef;
31use crate::optimizer::plan_node::StreamPlanRef;
32use crate::optimizer::property::FunctionalDependencySet;
33use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt};
34
35fn check_expr_type(expr: &ExprImpl) -> std::result::Result<(), &'static str> {
36 if expr.has_subquery() {
37 return Err("subquery");
38 }
39 if expr.has_agg_call() {
40 return Err("aggregate function");
41 }
42 if expr.has_table_function() {
43 return Err("table function");
44 }
45 if expr.has_window_function() {
46 return Err("window function");
47 }
48 Ok(())
49}
50
/// Field name that [`Project::with_vnode_col`] registers for the vnode
/// expression it appends after the pass-through input columns.
pub const LOCAL_PHASE_VNODE_COLUMN_NAME: &str = "_vnode";
54
55#[derive(Debug, Clone, Educe)]
57#[educe(PartialEq, Eq, Hash)]
58#[allow(clippy::manual_non_exhaustive)]
59pub struct Project<PlanRef> {
60 pub exprs: Vec<ExprImpl>,
61 #[educe(PartialEq(ignore))]
63 #[educe(Hash(ignore))]
64 pub field_names: BTreeMap<usize, String>,
65 pub input: PlanRef,
66 _private: (),
68}
69
70impl<PlanRef> Project<PlanRef> {
71 pub fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Project<OtherPlanRef> {
72 Project {
73 exprs: self.exprs.clone(),
74 field_names: self.field_names.clone(),
75 input,
76 _private: (),
77 }
78 }
79
80 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
81 self.exprs = self
82 .exprs
83 .iter()
84 .map(|e| r.rewrite_expr(e.clone()))
85 .collect();
86 }
87
88 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
89 self.exprs.iter().for_each(|e| v.visit_expr(e));
90 }
91}
92
93impl<PlanRef: GenericPlanRef> GenericPlanNode for Project<PlanRef> {
94 fn schema(&self) -> Schema {
95 let o2i = self.o2i_col_mapping();
96 let exprs = &self.exprs;
97 let input_schema = self.input.schema();
98 let ctx = self.ctx();
99 let fields = exprs
100 .iter()
101 .enumerate()
102 .map(|(i, expr)| {
103 let name = match o2i.try_map(i) {
105 Some(input_idx) => {
106 if let Some(name) = self.field_names.get(&i) {
107 name.clone()
108 } else {
109 input_schema.fields()[input_idx].name.clone()
110 }
111 }
112 None => match expr {
113 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => {
114 format!("{:?}", ExprDisplay { expr, input_schema })
115 }
116 _ => {
117 if let Some(name) = self.field_names.get(&i) {
118 name.clone()
119 } else {
120 format!("$expr{}", ctx.next_expr_display_id())
121 }
122 }
123 },
124 };
125 Field::with_name(expr.return_type(), name)
126 })
127 .collect();
128 Schema { fields }
129 }
130
131 fn stream_key(&self) -> Option<Vec<usize>> {
132 let i2o = self.i2o_col_mapping();
133 self.input
134 .stream_key()?
135 .iter()
136 .map(|pk_col| i2o.try_map(*pk_col))
137 .collect::<Option<Vec<_>>>()
138 }
139
140 fn ctx(&self) -> OptimizerContextRef {
141 self.input.ctx()
142 }
143
144 fn functional_dependency(&self) -> FunctionalDependencySet {
145 let i2o = self.i2o_col_mapping();
146 i2o.rewrite_functional_dependency_set(self.input.functional_dependency().clone())
147 }
148}
149
150impl<PlanRef: GenericPlanRef> Project<PlanRef> {
151 pub fn new(exprs: Vec<ExprImpl>, input: PlanRef) -> Self {
152 for expr in &exprs {
153 assert_input_ref!(expr, input.schema().fields().len());
154 check_expr_type(expr)
155 .map_err(|expr| format!("{expr} should not in Project operator"))
156 .unwrap();
157 }
158 Project {
159 exprs,
160 field_names: Default::default(),
161 input,
162 _private: (),
163 }
164 }
165
166 pub fn with_mapping(input: PlanRef, mapping: ColIndexMapping) -> Self {
173 if mapping.target_size() == 0 {
174 return Self::new(vec![], input);
177 };
178 let mut input_refs = vec![None; mapping.target_size()];
179 for (src, tar) in mapping.mapping_pairs() {
180 assert_eq!(input_refs[tar], None);
181 input_refs[tar] = Some(src);
182 }
183 let input_schema = input.schema();
184 let exprs: Vec<ExprImpl> = input_refs
185 .into_iter()
186 .map(|i| i.unwrap())
187 .map(|i| InputRef::new(i, input_schema.fields()[i].data_type()).into())
188 .collect();
189
190 Self::new(exprs, input)
191 }
192
193 pub fn with_out_fields(input: PlanRef, out_fields: &FixedBitSet) -> Self {
195 Self::with_out_col_idx(input, out_fields.ones())
196 }
197
198 pub fn out_col_idx_exprs<'a>(
199 input: &'a PlanRef,
200 out_fields: impl Iterator<Item = usize> + 'a,
201 ) -> impl Iterator<Item = ExprImpl> + 'a {
202 let input_schema = input.schema();
203 out_fields.map(move |index| InputRef::new(index, input_schema[index].data_type()).into())
204 }
205
206 pub fn with_out_col_idx(input: PlanRef, out_fields: impl Iterator<Item = usize>) -> Self {
208 let exprs = Self::out_col_idx_exprs(&input, out_fields).collect();
209 Self::new(exprs, input)
210 }
211
212 pub fn with_vnode_col(input: PlanRef, dist_key: &[usize]) -> Self {
214 let input_fields = input.schema().fields();
215 let mut new_exprs: Vec<_> = input_fields
216 .iter()
217 .enumerate()
218 .map(|(idx, field)| InputRef::new(idx, field.data_type.clone()).into())
219 .collect();
220 new_exprs.push(
221 FunctionCall::new(
222 ExprType::Vnode,
223 dist_key
224 .iter()
225 .map(|idx| InputRef::new(*idx, input_fields[*idx].data_type()).into())
226 .collect(),
227 )
228 .expect("Vnode function call should be valid here")
229 .into(),
230 );
231 let vnode_expr_idx = new_exprs.len() - 1;
232
233 let mut new = Self::new(new_exprs, input);
234 new.field_names
235 .insert(vnode_expr_idx, LOCAL_PHASE_VNODE_COLUMN_NAME.to_owned());
236 new
237 }
238
239 pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
240 (self.exprs, self.input)
241 }
242
243 pub fn fields_pretty<'a>(&self, schema: &Schema) -> StrAssocArr<'a> {
244 let f = |t| Pretty::debug(&t);
245 let e = Pretty::Array(self.exprs_for_display(schema).iter().map(f).collect());
246 vec![("exprs", e)]
247 }
248
249 fn exprs_for_display<'a>(&'a self, schema: &Schema) -> Vec<AliasedExpr<'a>> {
250 self.exprs
251 .iter()
252 .zip_eq_fast(schema.fields().iter())
253 .map(|(expr, field)| AliasedExpr {
254 expr: ExprDisplay {
255 expr,
256 input_schema: self.input.schema(),
257 },
258 alias: {
259 match expr {
260 ExprImpl::InputRef(_) | ExprImpl::Literal(_) => None,
261 _ => Some(field.name.clone()),
262 }
263 },
264 })
265 .collect()
266 }
267
268 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
269 let exprs = &self.exprs;
270 let input_len = self.input.schema().len();
271 let mut map = vec![None; exprs.len()];
272 for (i, expr) in exprs.iter().enumerate() {
273 if let ExprImpl::InputRef(input) = expr {
274 map[i] = Some(input.index())
275 }
276 }
277 ColIndexMapping::new(map, input_len)
278 }
279
280 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
283 let exprs = &self.exprs;
284 let input_len = self.input.schema().len();
285 let mut map = vec![None; input_len];
286 for (i, expr) in exprs.iter().enumerate() {
287 if let ExprImpl::InputRef(input) = expr {
288 map[input.index()] = Some(i)
289 }
290 }
291 ColIndexMapping::new(map, exprs.len())
292 }
293
294 pub fn is_all_inputref(&self) -> bool {
295 self.exprs
296 .iter()
297 .all(|expr| matches!(expr, ExprImpl::InputRef(_)))
298 }
299
300 pub fn is_identity(&self) -> bool {
301 self.exprs.len() == self.input.schema().len()
302 && self
303 .exprs
304 .iter()
305 .zip_eq_fast(self.input.schema().fields())
306 .enumerate()
307 .all(|(i, (expr, field))| {
308 matches!(expr, ExprImpl::InputRef(input_ref) if **input_ref == InputRef::new(i, field.data_type()))
309 })
310 }
311
312 pub fn try_as_projection(&self) -> Option<Vec<usize>> {
313 self.exprs
314 .iter()
315 .map(|expr| match expr {
316 ExprImpl::InputRef(input_ref) => Some(input_ref.index),
317 _ => None,
318 })
319 .collect::<Option<Vec<_>>>()
320 }
321}
322
323impl Project<StreamPlanRef> {
324 pub(crate) fn likely_produces_noop_updates(&self) -> bool {
327 if self.input.as_stream_now().is_some() {
330 return true;
331 }
332
333 struct HasJsonbAccess {
337 has: bool,
338 }
339
340 impl ExprVisitor for HasJsonbAccess {
341 fn visit_function_call(&mut self, func_call: &FunctionCall) {
342 if matches!(
343 func_call.func_type(),
344 ExprType::JsonbAccess
345 | ExprType::JsonbAccessStr
346 | ExprType::JsonbExtractPath
347 | ExprType::JsonbExtractPathVariadic
348 | ExprType::JsonbExtractPathText
349 | ExprType::JsonbExtractPathTextVariadic
350 | ExprType::JsonbPathExists
351 | ExprType::JsonbPathMatch
352 | ExprType::JsonbPathQueryArray
353 | ExprType::JsonbPathQueryFirst
354 ) {
355 self.has = true;
356 }
357 }
358 }
359
360 self.exprs.iter().any(|expr| {
361 let mut visitor = HasJsonbAccess { has: false };
365 visitor.visit_expr(expr);
366 visitor.has
367 })
368 }
369}
370
371#[derive(Default)]
374pub struct ProjectBuilder {
375 exprs: Vec<ExprImpl>,
376 exprs_index: HashMap<ExprImpl, usize>,
377}
378
379impl ProjectBuilder {
380 pub fn add_expr(&mut self, expr: &ExprImpl) -> std::result::Result<usize, &'static str> {
383 check_expr_type(expr)?;
384 if let Some(idx) = self.exprs_index.get(expr) {
385 Ok(*idx)
386 } else {
387 let index = self.exprs.len();
388 self.exprs.push(expr.clone());
389 self.exprs_index.insert(expr.clone(), index);
390 Ok(index)
391 }
392 }
393
394 pub fn get_expr(&self, index: usize) -> Option<&ExprImpl> {
395 self.exprs.get(index)
396 }
397
398 pub fn expr_index(&self, expr: &ExprImpl) -> Option<usize> {
399 check_expr_type(expr).ok()?;
400 self.exprs_index.get(expr).copied()
401 }
402
403 pub fn build<PlanRef: GenericPlanRef>(self, input: PlanRef) -> Project<PlanRef> {
405 Project::new(self.exprs, input)
406 }
407
408 pub fn exprs_len(&self) -> usize {
409 self.exprs.len()
410 }
411}
412
413pub struct AliasedExpr<'a> {
415 pub expr: ExprDisplay<'a>,
416 pub alias: Option<String>,
417}
418
419impl fmt::Debug for AliasedExpr<'_> {
420 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421 match &self.alias {
422 Some(alias) => write!(f, "{:?} as {}", self.expr, alias),
423 None => write!(f, "{:?}", self.expr),
424 }
425 }
426}
427
428pub fn ensure_sorted_required_cols(
434 required_cols: &[usize],
435 schema: &Schema,
436) -> (Vec<ExprImpl>, Vec<usize>) {
437 let mut required_cols_with_output_idx = required_cols.iter().copied().enumerate().collect_vec();
438 required_cols_with_output_idx.sort_by_key(|(_, col_idx)| *col_idx);
439 let mut output_indices = vec![0; required_cols.len()];
440 let mut sorted_col_idx = Vec::with_capacity(required_cols.len());
441
442 for (sorted_input_idx, (output_idx, col_idx)) in
443 required_cols_with_output_idx.into_iter().enumerate()
444 {
445 sorted_col_idx.push(col_idx);
446 output_indices[output_idx] = sorted_input_idx;
447 }
448
449 (
450 output_indices
451 .into_iter()
452 .map(|sorted_input_idx| {
453 InputRef::new(
454 sorted_input_idx,
455 schema[sorted_col_idx[sorted_input_idx]].data_type(),
456 )
457 .into()
458 })
459 .collect(),
460 sorted_col_idx,
461 )
462}
463
#[cfg(test)]
mod tests {

    use itertools::Itertools;
    use rand::prelude::SliceRandom;
    use rand::rng;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::ensure_sorted_required_cols;

    /// Picks a random subset of columns, shuffles it, and checks that
    /// `ensure_sorted_required_cols` returns the subset in sorted order together
    /// with expressions that map back to the shuffled order.
    #[test]
    fn test_ensure_sorted_required_cols() {
        let input_len = 10;
        let fields = (0..input_len)
            .map(|_| Field::unnamed(DataType::Int32))
            .collect();
        let schema = Schema::new(fields);

        // Filtering an ascending range yields the subset already sorted.
        let sorted_required_cols = (0..input_len)
            .filter(|_| rand::random_bool(0.5))
            .collect_vec();
        let required_cols = {
            let mut shuffled = sorted_required_cols.clone();
            shuffled.shuffle(&mut rng());
            shuffled
        };

        let (output_exprs, sorted) = ensure_sorted_required_cols(&required_cols, &schema);
        assert_eq!(sorted, sorted_required_cols);

        // Resolving each output expression through the sorted list must give
        // back the originally requested column.
        let resolved = output_exprs
            .iter()
            .map(|expr| sorted_required_cols[expr.as_input_ref().unwrap().index])
            .collect_vec();
        assert_eq!(resolved, required_cols);
    }
}