risingwave_frontend/optimizer/plan_node/
logical_union.rs1use std::cmp::max;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19use risingwave_common::types::{DataType, Scalar};
20
21use super::utils::impl_distill_by_unit;
22use super::{
23 ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase, PredicatePushdown,
24 ToBatch, ToStream,
25};
26use crate::Explain;
27use crate::error::Result;
28use crate::expr::{ExprImpl, InputRef, Literal};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::GenericPlanRef;
31use crate::optimizer::plan_node::stream_union::StreamUnion;
32use crate::optimizer::plan_node::{
33 BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
34 PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
35};
36use crate::optimizer::property::RequiredDist;
37use crate::utils::{ColIndexMapping, Condition};
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct LogicalUnion {
43 pub base: PlanBase<Logical>,
44 core: generic::Union<PlanRef>,
45}
46
47impl LogicalUnion {
48 pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
49 assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
50 Self::new_with_source_col(all, inputs, None)
51 }
52
53 pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
56 let core = generic::Union {
57 all,
58 inputs,
59 source_col,
60 };
61 let base = PlanBase::new_logical_with_core(&core);
62 LogicalUnion { base, core }
63 }
64
65 pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
66 LogicalUnion::new(all, inputs).into()
67 }
68
69 pub fn all(&self) -> bool {
70 self.core.all
71 }
72
73 pub fn source_col(&self) -> Option<usize> {
74 self.core.source_col
75 }
76}
77
78impl PlanTreeNode<Logical> for LogicalUnion {
79 fn inputs(&self) -> smallvec::SmallVec<[PlanRef; 2]> {
80 self.core.inputs.clone().into_iter().collect()
81 }
82
83 fn clone_with_inputs(&self, inputs: &[PlanRef]) -> PlanRef {
84 Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
85 }
86}
87
88impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");
89
90impl ColPrunable for LogicalUnion {
91 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
92 let new_inputs = self
93 .inputs()
94 .iter()
95 .map(|input| input.prune_col(required_cols, ctx))
96 .collect_vec();
97 self.clone_with_inputs(&new_inputs)
98 }
99}
100
101impl ExprRewritable<Logical> for LogicalUnion {}
102
103impl ExprVisitable for LogicalUnion {}
104
105impl PredicatePushdown for LogicalUnion {
106 fn predicate_pushdown(
107 &self,
108 predicate: Condition,
109 ctx: &mut PredicatePushdownContext,
110 ) -> PlanRef {
111 let new_inputs = self
112 .inputs()
113 .iter()
114 .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
115 .collect_vec();
116 self.clone_with_inputs(&new_inputs)
117 }
118}
119
120impl ToBatch for LogicalUnion {
121 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
122 let new_inputs = self
123 .inputs()
124 .iter()
125 .map(|input| input.to_batch())
126 .try_collect()?;
127 let new_logical = generic::Union {
128 all: true,
129 inputs: new_inputs,
130 source_col: None,
131 };
132 if !self.all() {
136 let batch_union = BatchUnion::new(new_logical).into();
137 Ok(BatchHashAgg::new(
138 generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
139 .with_enable_two_phase(false),
140 )
141 .into())
142 } else {
143 Ok(BatchUnion::new(new_logical).into())
144 }
145 }
146}
147
148impl ToStream for LogicalUnion {
149 fn to_stream(
150 &self,
151 ctx: &mut ToStreamContext,
152 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
153 let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
155 panic!(
156 "should always have a stream key in the stream plan but not, sub plan: {}",
157 PlanRef::from(self.clone()).explain_to_string()
158 )
159 }));
160 let new_inputs: Result<Vec<_>> = self
161 .inputs()
162 .iter()
163 .map(|input| input.to_stream_with_dist_required(&dist, ctx))
164 .collect();
165 let core = self.core.clone_with_inputs(new_inputs?);
166 assert!(
167 self.all(),
168 "After UnionToDistinctRule, union should become union all"
169 );
170 Ok(StreamUnion::new(core).into())
171 }
172
173 fn logical_rewrite_for_stream(
174 &self,
175 ctx: &mut RewriteStreamContext,
176 ) -> Result<(PlanRef, ColIndexMapping)> {
177 type FixedState = std::hash::BuildHasherDefault<std::hash::DefaultHasher>;
178 type TypeMap<T> = std::collections::HashMap<DataType, T, FixedState>;
179
180 let original_schema = self.base.schema().clone();
181 let original_schema_len = original_schema.len();
182 let mut rewrites = vec![];
183 for input in &self.core.inputs {
184 rewrites.push(input.logical_rewrite_for_stream(ctx)?);
185 }
186
187 let original_schema_contain_all_input_stream_keys =
188 rewrites.iter().all(|(new_input, col_index_mapping)| {
189 let original_schema_new_pos = (0..original_schema_len)
190 .map(|x| col_index_mapping.map(x))
191 .collect_vec();
192 new_input
193 .expect_stream_key()
194 .iter()
195 .all(|x| original_schema_new_pos.contains(x))
196 });
197
198 if original_schema_contain_all_input_stream_keys {
199 let new_inputs = rewrites
202 .into_iter()
203 .enumerate()
204 .map(|(i, (new_input, col_index_mapping))| {
205 let mut exprs = (0..original_schema_len)
207 .map(|x| {
208 ExprImpl::InputRef(
209 InputRef::new(
210 col_index_mapping.map(x),
211 original_schema.fields[x].data_type.clone(),
212 )
213 .into(),
214 )
215 })
216 .collect_vec();
217 exprs.push(ExprImpl::Literal(
219 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
220 ));
221 LogicalProject::create(new_input, exprs)
222 })
223 .collect_vec();
224 let new_union = LogicalUnion::new_with_source_col(
225 self.all(),
226 new_inputs,
227 Some(original_schema_len),
228 );
229 let out_col_change =
232 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
233 Ok((new_union.into(), out_col_change))
234 } else {
235 let (merged_stream_key_types, types_offset) = {
241 let mut max_types_counter = TypeMap::default();
242 for (new_input, _) in &rewrites {
243 let mut types_counter = TypeMap::default();
244 for x in new_input.expect_stream_key() {
245 types_counter
246 .entry(new_input.schema().fields[*x].data_type())
247 .and_modify(|x| *x += 1)
248 .or_insert(1);
249 }
250 for (key, val) in types_counter {
251 max_types_counter
252 .entry(key)
253 .and_modify(|x| *x = max(*x, val))
254 .or_insert(val);
255 }
256 }
257
258 let mut merged_stream_key_types = vec![];
259 let mut types_offset = TypeMap::default();
260 let mut offset = 0;
261 for (key, val) in max_types_counter {
262 let _ = types_offset.insert(key.clone(), offset);
263 offset += val;
264 merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
265 }
266
267 (merged_stream_key_types, types_offset)
268 };
269
270 let input_stream_key_nulls = merged_stream_key_types
271 .iter()
272 .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
273 .collect_vec();
274
275 let new_inputs = rewrites
276 .into_iter()
277 .enumerate()
278 .map(|(i, (new_input, col_index_mapping))| {
279 let mut exprs = (0..original_schema_len)
281 .map(|x| {
282 ExprImpl::InputRef(
283 InputRef::new(
284 col_index_mapping.map(x),
285 original_schema.fields[x].data_type.clone(),
286 )
287 .into(),
288 )
289 })
290 .collect_vec();
291 let mut input_stream_keys = input_stream_key_nulls.clone();
293 let mut types_counter = TypeMap::default();
294 for stream_key_idx in new_input.expect_stream_key() {
295 let data_type =
296 new_input.schema().fields[*stream_key_idx].data_type.clone();
297 let count = *types_counter
298 .entry(data_type.clone())
299 .and_modify(|x| *x += 1)
300 .or_insert(1);
301 let type_start_offset = *types_offset.get(&data_type).unwrap();
302
303 input_stream_keys[type_start_offset + count - 1] =
304 ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
305 }
306 exprs.extend(input_stream_keys);
307 exprs.push(ExprImpl::Literal(
309 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
310 ));
311 LogicalProject::create(new_input, exprs)
312 })
313 .collect_vec();
314
315 let new_union = LogicalUnion::new_with_source_col(
316 self.all(),
317 new_inputs,
318 Some(original_schema_len + merged_stream_key_types.len()),
319 );
320 let out_col_change =
323 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
324 Ok((new_union.into(), out_col_change))
325 }
326 }
327}
328
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    /// Three `Int32` fields named `v1`, `v2` and `v3`.
    fn int32_fields() -> Vec<Field> {
        ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect()
    }

    /// Pruning a three-column union down to two columns yields a two-column union.
    #[tokio::test]
    async fn test_prune_union() {
        let ctx = OptimizerContext::mock().await;
        let schema = Schema {
            fields: int32_fields(),
        };
        let left = LogicalValues::new(vec![], schema, ctx);
        let right = left.clone();

        let union: PlanRef = LogicalUnion::new(false, vec![left.into(), right.into()]).into();

        let required_cols = vec![1, 2];
        let mut prune_ctx = ColumnPruningContext::new(union.clone());
        let pruned = union.prune_col(&required_cols, &mut prune_ctx);

        let pruned_union = pruned.as_logical_union().unwrap();
        assert_eq!(pruned_union.base.schema().len(), 2);
    }

    /// A non-`all` union becomes a `BatchHashAgg` sitting on a two-input `BatchUnion`.
    #[tokio::test]
    async fn test_union_to_batch() {
        let ctx = OptimizerContext::mock().await;
        let schema = Schema {
            fields: int32_fields(),
        };
        let left = LogicalValues::new(vec![], schema, ctx);
        let right = left.clone();

        let logical_union = LogicalUnion::new(false, vec![left.into(), right.into()]);

        let batch_plan = logical_union.to_batch().unwrap();
        let agg: &BatchHashAgg = batch_plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let batch_union = agg_input.as_batch_union().unwrap();

        assert_eq!(batch_union.inputs().len(), 2);
    }
}
387}