risingwave_frontend/optimizer/plan_node/
logical_union.rs1use std::cmp::max;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19use risingwave_common::types::{DataType, Scalar};
20
21use super::utils::impl_distill_by_unit;
22use super::{
23 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PredicatePushdown, ToBatch, ToStream,
24};
25use crate::Explain;
26use crate::error::Result;
27use crate::expr::{ExprImpl, InputRef, Literal};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::generic::GenericPlanRef;
30use crate::optimizer::plan_node::stream_union::StreamUnion;
31use crate::optimizer::plan_node::{
32 BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
33 PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
34};
35use crate::optimizer::property::RequiredDist;
36use crate::utils::{ColIndexMapping, Condition};
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct LogicalUnion {
42 pub base: PlanBase<Logical>,
43 core: generic::Union<PlanRef>,
44}
45
46impl LogicalUnion {
47 pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
48 assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
49 Self::new_with_source_col(all, inputs, None)
50 }
51
52 pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
55 let core = generic::Union {
56 all,
57 inputs,
58 source_col,
59 };
60 let base = PlanBase::new_logical_with_core(&core);
61 LogicalUnion { base, core }
62 }
63
64 pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
65 LogicalUnion::new(all, inputs).into()
66 }
67
68 pub fn all(&self) -> bool {
69 self.core.all
70 }
71
72 pub fn source_col(&self) -> Option<usize> {
73 self.core.source_col
74 }
75}
76
77impl PlanTreeNode for LogicalUnion {
78 fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
79 self.core.inputs.clone().into_iter().collect()
80 }
81
82 fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
83 Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
84 }
85}
86
// Implements plan distillation for `LogicalUnion` by delegating to its `core` field, using
// "LogicalUnion" as the node label.
impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");
88
89impl ColPrunable for LogicalUnion {
90 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
91 let new_inputs = self
92 .inputs()
93 .iter()
94 .map(|input| input.prune_col(required_cols, ctx))
95 .collect_vec();
96 self.clone_with_inputs(&new_inputs)
97 }
98}
99
// `generic::Union` carries no expressions of its own (only `all`, `inputs` and `source_col`),
// so the traits' default methods are used as-is.
impl ExprRewritable for LogicalUnion {}

impl ExprVisitable for LogicalUnion {}
103
104impl PredicatePushdown for LogicalUnion {
105 fn predicate_pushdown(
106 &self,
107 predicate: Condition,
108 ctx: &mut PredicatePushdownContext,
109 ) -> PlanRef {
110 let new_inputs = self
111 .inputs()
112 .iter()
113 .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
114 .collect_vec();
115 self.clone_with_inputs(&new_inputs)
116 }
117}
118
119impl ToBatch for LogicalUnion {
120 fn to_batch(&self) -> Result<PlanRef> {
121 let new_inputs = self
122 .inputs()
123 .iter()
124 .map(|input| input.to_batch())
125 .try_collect()?;
126 let new_logical = generic::Union {
127 all: true,
128 inputs: new_inputs,
129 source_col: None,
130 };
131 if !self.all() {
135 let batch_union = BatchUnion::new(new_logical).into();
136 Ok(BatchHashAgg::new(
137 generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
138 .with_enable_two_phase(false),
139 )
140 .into())
141 } else {
142 Ok(BatchUnion::new(new_logical).into())
143 }
144 }
145}
146
147impl ToStream for LogicalUnion {
148 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
149 let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
151 panic!(
152 "should always have a stream key in the stream plan but not, sub plan: {}",
153 PlanRef::from(self.clone()).explain_to_string()
154 )
155 }));
156 let new_inputs: Result<Vec<_>> = self
157 .inputs()
158 .iter()
159 .map(|input| input.to_stream_with_dist_required(&dist, ctx))
160 .collect();
161 let new_logical = generic::Union {
162 all: true,
163 inputs: new_inputs?,
164 ..self.core
165 };
166 assert!(
167 self.all(),
168 "After UnionToDistinctRule, union should become union all"
169 );
170 Ok(StreamUnion::new(new_logical).into())
171 }
172
173 fn logical_rewrite_for_stream(
174 &self,
175 ctx: &mut RewriteStreamContext,
176 ) -> Result<(PlanRef, ColIndexMapping)> {
177 type FixedState = std::hash::BuildHasherDefault<std::hash::DefaultHasher>;
178 type TypeMap<T> = std::collections::HashMap<DataType, T, FixedState>;
179
180 let original_schema = self.base.schema().clone();
181 let original_schema_len = original_schema.len();
182 let mut rewrites = vec![];
183 for input in &self.core.inputs {
184 rewrites.push(input.logical_rewrite_for_stream(ctx)?);
185 }
186
187 let original_schema_contain_all_input_stream_keys =
188 rewrites.iter().all(|(new_input, col_index_mapping)| {
189 let original_schema_new_pos = (0..original_schema_len)
190 .map(|x| col_index_mapping.map(x))
191 .collect_vec();
192 new_input
193 .expect_stream_key()
194 .iter()
195 .all(|x| original_schema_new_pos.contains(x))
196 });
197
198 if original_schema_contain_all_input_stream_keys {
199 let new_inputs = rewrites
202 .into_iter()
203 .enumerate()
204 .map(|(i, (new_input, col_index_mapping))| {
205 let mut exprs = (0..original_schema_len)
207 .map(|x| {
208 ExprImpl::InputRef(
209 InputRef::new(
210 col_index_mapping.map(x),
211 original_schema.fields[x].data_type.clone(),
212 )
213 .into(),
214 )
215 })
216 .collect_vec();
217 exprs.push(ExprImpl::Literal(
219 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
220 ));
221 LogicalProject::create(new_input, exprs)
222 })
223 .collect_vec();
224 let new_union = LogicalUnion::new_with_source_col(
225 self.all(),
226 new_inputs,
227 Some(original_schema_len),
228 );
229 let out_col_change =
232 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
233 Ok((new_union.into(), out_col_change))
234 } else {
235 let (merged_stream_key_types, types_offset) = {
241 let mut max_types_counter = TypeMap::default();
242 for (new_input, _) in &rewrites {
243 let mut types_counter = TypeMap::default();
244 for x in new_input.expect_stream_key() {
245 types_counter
246 .entry(new_input.schema().fields[*x].data_type())
247 .and_modify(|x| *x += 1)
248 .or_insert(1);
249 }
250 for (key, val) in types_counter {
251 max_types_counter
252 .entry(key)
253 .and_modify(|x| *x = max(*x, val))
254 .or_insert(val);
255 }
256 }
257
258 let mut merged_stream_key_types = vec![];
259 let mut types_offset = TypeMap::default();
260 let mut offset = 0;
261 for (key, val) in max_types_counter {
262 let _ = types_offset.insert(key.clone(), offset);
263 offset += val;
264 merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
265 }
266
267 (merged_stream_key_types, types_offset)
268 };
269
270 let input_stream_key_nulls = merged_stream_key_types
271 .iter()
272 .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
273 .collect_vec();
274
275 let new_inputs = rewrites
276 .into_iter()
277 .enumerate()
278 .map(|(i, (new_input, col_index_mapping))| {
279 let mut exprs = (0..original_schema_len)
281 .map(|x| {
282 ExprImpl::InputRef(
283 InputRef::new(
284 col_index_mapping.map(x),
285 original_schema.fields[x].data_type.clone(),
286 )
287 .into(),
288 )
289 })
290 .collect_vec();
291 let mut input_stream_keys = input_stream_key_nulls.clone();
293 let mut types_counter = TypeMap::default();
294 for stream_key_idx in new_input.expect_stream_key() {
295 let data_type =
296 new_input.schema().fields[*stream_key_idx].data_type.clone();
297 let count = *types_counter
298 .entry(data_type.clone())
299 .and_modify(|x| *x += 1)
300 .or_insert(1);
301 let type_start_offset = *types_offset.get(&data_type).unwrap();
302
303 input_stream_keys[type_start_offset + count - 1] =
304 ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
305 }
306 exprs.extend(input_stream_keys);
307 exprs.push(ExprImpl::Literal(
309 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
310 ));
311 LogicalProject::create(new_input, exprs)
312 })
313 .collect_vec();
314
315 let new_union = LogicalUnion::new_with_source_col(
316 self.all(),
317 new_inputs,
318 Some(original_schema_len + merged_stream_key_types.len()),
319 );
320 let out_col_change =
323 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
324 Ok((new_union.into(), out_col_change))
325 }
326 }
327}
328
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    /// Builds an empty `LogicalValues` with three `Int32` columns `v1`, `v2`, `v3`.
    async fn empty_int32_values() -> LogicalValues {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        LogicalValues::new(vec![], Schema { fields }, ctx)
    }

    #[tokio::test]
    async fn test_prune_union() {
        let left = empty_int32_values().await;
        let right = left.clone();
        let union: PlanRef = LogicalUnion::new(false, vec![left.into(), right.into()]).into();

        // Keep only `v2` and `v3`.
        let required_cols: [usize; 2] = [1, 2];
        let mut pruning_ctx = ColumnPruningContext::new(union.clone());
        let pruned = union.prune_col(&required_cols, &mut pruning_ctx);

        let pruned_union = pruned.as_logical_union().unwrap();
        assert_eq!(pruned_union.base.schema().len(), 2);
    }

    #[tokio::test]
    async fn test_union_to_batch() {
        let left = empty_int32_values().await;
        let right = left.clone();
        let union = LogicalUnion::new(false, vec![left.into(), right.into()]);

        // A distinct union lowers to a hash agg sitting on top of a batch union.
        let plan = union.to_batch().unwrap();
        let agg: &BatchHashAgg = plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let batch_union = agg_input.as_batch_union().unwrap();

        assert_eq!(batch_union.inputs().len(), 2);
    }
}