risingwave_frontend/optimizer/plan_node/
logical_union.rs1use std::cmp::max;
16use std::collections::BTreeMap;
17
18use itertools::Itertools;
19use risingwave_common::catalog::Schema;
20use risingwave_common::types::{DataType, Scalar};
21
22use super::utils::impl_distill_by_unit;
23use super::{
24 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PredicatePushdown, ToBatch, ToStream,
25};
26use crate::Explain;
27use crate::error::Result;
28use crate::expr::{ExprImpl, InputRef, Literal};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::GenericPlanRef;
31use crate::optimizer::plan_node::stream_union::StreamUnion;
32use crate::optimizer::plan_node::{
33 BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
34 PredicatePushdownContext, RewriteStreamContext, ToStreamContext, generic,
35};
36use crate::optimizer::property::RequiredDist;
37use crate::utils::{ColIndexMapping, Condition};
38
/// Logical plan node that combines the rows of several inputs: SQL `UNION ALL` when `all` is
/// `true`, or a de-duplicating `UNION` when `all` is `false`.
///
/// After `logical_rewrite_for_stream`, the node may also carry a `source_col`: the index of an
/// `Int32` column whose value is the position of the input each row came from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalUnion {
    /// Plan properties (schema, stream key, ...) derived from `core` via
    /// `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The `all` flag, the inputs, and the optional source column.
    core: generic::Union<PlanRef>,
}
46
47impl LogicalUnion {
48 pub fn new(all: bool, inputs: Vec<PlanRef>) -> Self {
49 assert!(Schema::all_type_eq(inputs.iter().map(|x| x.schema())));
50 Self::new_with_source_col(all, inputs, None)
51 }
52
53 pub fn new_with_source_col(all: bool, inputs: Vec<PlanRef>, source_col: Option<usize>) -> Self {
56 let core = generic::Union {
57 all,
58 inputs,
59 source_col,
60 };
61 let base = PlanBase::new_logical_with_core(&core);
62 LogicalUnion { base, core }
63 }
64
65 pub fn create(all: bool, inputs: Vec<PlanRef>) -> PlanRef {
66 LogicalUnion::new(all, inputs).into()
67 }
68
69 pub fn all(&self) -> bool {
70 self.core.all
71 }
72
73 pub fn source_col(&self) -> Option<usize> {
74 self.core.source_col
75 }
76}
77
78impl PlanTreeNode for LogicalUnion {
79 fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
80 self.core.inputs.clone().into_iter().collect()
81 }
82
83 fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
84 Self::new_with_source_col(self.all(), inputs.to_vec(), self.core.source_col).into()
85 }
86}
87
// Generate the distill implementation for this node from its `core`, labelled "LogicalUnion"
// (presumably what EXPLAIN renders; see `impl_distill_by_unit!` in `super::utils`).
impl_distill_by_unit!(LogicalUnion, core, "LogicalUnion");
89
90impl ColPrunable for LogicalUnion {
91 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
92 let new_inputs = self
93 .inputs()
94 .iter()
95 .map(|input| input.prune_col(required_cols, ctx))
96 .collect_vec();
97 self.clone_with_inputs(&new_inputs)
98 }
99}
100
// A union holds no expressions of its own (its core is just `all`, `inputs` and `source_col`),
// so the trait's default implementation is used.
impl ExprRewritable for LogicalUnion {}
102
// No expressions to visit on a union node itself; rely on the trait's default implementation.
impl ExprVisitable for LogicalUnion {}
104
105impl PredicatePushdown for LogicalUnion {
106 fn predicate_pushdown(
107 &self,
108 predicate: Condition,
109 ctx: &mut PredicatePushdownContext,
110 ) -> PlanRef {
111 let new_inputs = self
112 .inputs()
113 .iter()
114 .map(|input| input.predicate_pushdown(predicate.clone(), ctx))
115 .collect_vec();
116 self.clone_with_inputs(&new_inputs)
117 }
118}
119
120impl ToBatch for LogicalUnion {
121 fn to_batch(&self) -> Result<PlanRef> {
122 let new_inputs = self
123 .inputs()
124 .iter()
125 .map(|input| input.to_batch())
126 .try_collect()?;
127 let new_logical = generic::Union {
128 all: true,
129 inputs: new_inputs,
130 source_col: None,
131 };
132 if !self.all() {
136 let batch_union = BatchUnion::new(new_logical).into();
137 Ok(BatchHashAgg::new(
138 generic::Agg::new(vec![], (0..self.base.schema().len()).collect(), batch_union)
139 .with_enable_two_phase(false),
140 )
141 .into())
142 } else {
143 Ok(BatchUnion::new(new_logical).into())
144 }
145 }
146}
147
148impl ToStream for LogicalUnion {
149 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
150 let dist = RequiredDist::hash_shard(self.base.stream_key().unwrap_or_else(|| {
152 panic!(
153 "should always have a stream key in the stream plan but not, sub plan: {}",
154 PlanRef::from(self.clone()).explain_to_string()
155 )
156 }));
157 let new_inputs: Result<Vec<_>> = self
158 .inputs()
159 .iter()
160 .map(|input| input.to_stream_with_dist_required(&dist, ctx))
161 .collect();
162 let new_logical = generic::Union {
163 all: true,
164 inputs: new_inputs?,
165 ..self.core
166 };
167 assert!(
168 self.all(),
169 "After UnionToDistinctRule, union should become union all"
170 );
171 Ok(StreamUnion::new(new_logical).into())
172 }
173
174 fn logical_rewrite_for_stream(
175 &self,
176 ctx: &mut RewriteStreamContext,
177 ) -> Result<(PlanRef, ColIndexMapping)> {
178 let original_schema = self.base.schema().clone();
179 let original_schema_len = original_schema.len();
180 let mut rewrites = vec![];
181 for input in &self.core.inputs {
182 rewrites.push(input.logical_rewrite_for_stream(ctx)?);
183 }
184
185 let original_schema_contain_all_input_stream_keys =
186 rewrites.iter().all(|(new_input, col_index_mapping)| {
187 let original_schema_new_pos = (0..original_schema_len)
188 .map(|x| col_index_mapping.map(x))
189 .collect_vec();
190 new_input
191 .expect_stream_key()
192 .iter()
193 .all(|x| original_schema_new_pos.contains(x))
194 });
195
196 if original_schema_contain_all_input_stream_keys {
197 let new_inputs = rewrites
200 .into_iter()
201 .enumerate()
202 .map(|(i, (new_input, col_index_mapping))| {
203 let mut exprs = (0..original_schema_len)
205 .map(|x| {
206 ExprImpl::InputRef(
207 InputRef::new(
208 col_index_mapping.map(x),
209 original_schema.fields[x].data_type.clone(),
210 )
211 .into(),
212 )
213 })
214 .collect_vec();
215 exprs.push(ExprImpl::Literal(
217 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
218 ));
219 LogicalProject::create(new_input, exprs)
220 })
221 .collect_vec();
222 let new_union = LogicalUnion::new_with_source_col(
223 self.all(),
224 new_inputs,
225 Some(original_schema_len),
226 );
227 let out_col_change =
230 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
231 Ok((new_union.into(), out_col_change))
232 } else {
233 let (merged_stream_key_types, types_offset) = {
239 let mut max_types_counter = BTreeMap::default();
240 for (new_input, _) in &rewrites {
241 let mut types_counter = BTreeMap::default();
242 for x in new_input.expect_stream_key() {
243 types_counter
244 .entry(new_input.schema().fields[*x].data_type())
245 .and_modify(|x| *x += 1)
246 .or_insert(1);
247 }
248 for (key, val) in types_counter {
249 max_types_counter
250 .entry(key)
251 .and_modify(|x| *x = max(*x, val))
252 .or_insert(val);
253 }
254 }
255
256 let mut merged_stream_key_types = vec![];
257 let mut types_offset = BTreeMap::default();
258 let mut offset = 0;
259 for (key, val) in max_types_counter {
260 let _ = types_offset.insert(key.clone(), offset);
261 offset += val;
262 merged_stream_key_types.extend(std::iter::repeat_n(key.clone(), val));
263 }
264
265 (merged_stream_key_types, types_offset)
266 };
267
268 let input_stream_key_nulls = merged_stream_key_types
269 .iter()
270 .map(|t| ExprImpl::Literal(Literal::new(None, t.clone()).into()))
271 .collect_vec();
272
273 let new_inputs = rewrites
274 .into_iter()
275 .enumerate()
276 .map(|(i, (new_input, col_index_mapping))| {
277 let mut exprs = (0..original_schema_len)
279 .map(|x| {
280 ExprImpl::InputRef(
281 InputRef::new(
282 col_index_mapping.map(x),
283 original_schema.fields[x].data_type.clone(),
284 )
285 .into(),
286 )
287 })
288 .collect_vec();
289 let mut input_stream_keys = input_stream_key_nulls.clone();
291 let mut types_counter = BTreeMap::default();
292 for stream_key_idx in new_input.expect_stream_key() {
293 let data_type =
294 new_input.schema().fields[*stream_key_idx].data_type.clone();
295 let count = *types_counter
296 .entry(data_type.clone())
297 .and_modify(|x| *x += 1)
298 .or_insert(1);
299 let type_start_offset = *types_offset.get(&data_type).unwrap();
300
301 input_stream_keys[type_start_offset + count - 1] =
302 ExprImpl::InputRef(InputRef::new(*stream_key_idx, data_type).into());
303 }
304 exprs.extend(input_stream_keys);
305 exprs.push(ExprImpl::Literal(
307 Literal::new(Some((i as i32).to_scalar_value()), DataType::Int32).into(),
308 ));
309 LogicalProject::create(new_input, exprs)
310 })
311 .collect_vec();
312
313 let new_union = LogicalUnion::new_with_source_col(
314 self.all(),
315 new_inputs,
316 Some(original_schema_len + merged_stream_key_types.len()),
317 );
318 let out_col_change =
321 ColIndexMapping::identity_or_none(original_schema_len, new_union.schema().len());
322 Ok((new_union.into(), out_col_change))
323 }
324 }
325}
326
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::Field;

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{LogicalValues, PlanTreeNodeUnary};

    /// Builds an empty `LogicalValues` with three `Int32` columns `v1`, `v2`, `v3`.
    async fn mock_values() -> LogicalValues {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        LogicalValues::new(vec![], Schema { fields }, ctx)
    }

    #[tokio::test]
    async fn test_prune_union() {
        let values1 = mock_values().await;
        let values2 = values1.clone();

        let union: PlanRef = LogicalUnion::new(false, vec![values1.into(), values2.into()]).into();

        // Keep only `v2` and `v3`.
        let required_cols = [1, 2];
        let plan = union.prune_col(
            &required_cols,
            &mut ColumnPruningContext::new(union.clone()),
        );

        let union = plan.as_logical_union().unwrap();
        assert!(!union.all());
        assert_eq!(union.base.schema().len(), 2);
        // Every input must have been pruned, not just the union's own schema.
        let inputs = union.inputs();
        assert_eq!(inputs.len(), 2);
        for input in &inputs {
            assert_eq!(input.schema().len(), 2);
        }
    }

    #[tokio::test]
    async fn test_union_to_batch() {
        let values1 = mock_values().await;
        let values2 = values1.clone();

        let union = LogicalUnion::new(false, vec![values1.into(), values2.into()]);

        // A distinct union becomes a hash agg over all three columns on top of a union all.
        let plan = union.to_batch().unwrap();
        assert_eq!(plan.schema().len(), 3);
        let agg: &BatchHashAgg = plan.as_batch_hash_agg().unwrap();
        let agg_input = agg.input();
        let union = agg_input.as_batch_union().unwrap();

        assert_eq!(union.inputs().len(), 2);
    }
}