risingwave_frontend/optimizer/rule/
apply_agg_transpose_rule.rs1use risingwave_common::types::DataType;
16use risingwave_expr::aggregate::{AggType, PbAggKind};
17use risingwave_pb::plan_common::JoinType;
18
19use super::{ApplyOffsetRewriter, BoxedRule, Rule};
20use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
21use crate::optimizer::PlanRef;
22use crate::optimizer::plan_node::generic::Agg;
23use crate::optimizer::plan_node::{LogicalAgg, LogicalApply, LogicalFilter, LogicalProject};
24use crate::utils::{Condition, IndexSet};
25
/// Optimizer rule that swaps a `LogicalApply` with the `LogicalAgg` on its right side.
///
/// The rule matches an inner `LogicalApply` whose right input is a `LogicalAgg` and produces
/// `LogicalFilter(on)` over an aggregation over a new apply of the left input and the
/// aggregation's input. The new aggregation groups by every column of the left input followed by
/// the original group keys shifted past the left columns.
pub struct ApplyAggTransposeRule {}
48impl Rule for ApplyAggTransposeRule {
49 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
50 let apply: &LogicalApply = plan.as_logical_apply()?;
51 let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
52 apply.clone().decompose();
53 assert_eq!(join_type, JoinType::Inner);
54 let agg: &LogicalAgg = right.as_logical_agg()?;
55 let (mut agg_calls, agg_group_key, grouping_sets, agg_input, enable_two_phase) =
56 agg.clone().decompose();
57 assert!(grouping_sets.is_empty());
58 let is_scalar_agg = agg_group_key.is_empty();
59 let apply_left_len = left.schema().len();
60
61 if !is_scalar_agg && max_one_row {
62 return None;
64 }
65
66 let input = if is_scalar_agg {
67 let mut exprs: Vec<ExprImpl> = agg_input
69 .schema()
70 .data_types()
71 .into_iter()
72 .enumerate()
73 .map(|(i, data_type)| InputRef::new(i, data_type).into())
74 .collect();
75 exprs.push(ExprImpl::literal_int(1));
76 LogicalProject::create(agg_input, exprs)
77 } else {
78 agg_input
79 };
80
81 let node = if is_scalar_agg {
82 let left_len = left.schema().len();
84 let eq_predicates = left
85 .schema()
86 .data_types()
87 .into_iter()
88 .enumerate()
89 .map(|(i, data_type)| {
90 let left = InputRef::new(i, data_type.clone());
91 let right = InputRef::new(i + left_len, data_type);
92 FunctionCall::new_unchecked(
94 ExprType::IsNotDistinctFrom,
95 vec![left.into(), right.into()],
96 DataType::Boolean,
97 )
98 .into()
99 })
100 .collect();
101 LogicalApply::new(
102 left.clone(),
103 input,
104 JoinType::LeftOuter,
105 Condition::true_cond(),
106 correlated_id,
107 correlated_indices.clone(),
108 false,
109 false,
110 )
111 .translate_apply(left, eq_predicates)
112 } else {
113 LogicalApply::create(
114 left,
115 input,
116 JoinType::Inner,
117 Condition::true_cond(),
118 correlated_id,
119 correlated_indices.clone(),
120 false,
121 )
122 };
123
124 let group_agg = {
125 let offset = apply_left_len as isize;
127 let mut rewriter =
128 ApplyOffsetRewriter::new(apply_left_len, &correlated_indices, correlated_id);
129 agg_calls.iter_mut().for_each(|agg_call| {
130 agg_call.inputs.iter_mut().for_each(|input_ref| {
131 input_ref.shift_with_offset(offset);
132 });
133 agg_call
134 .order_by
135 .iter_mut()
136 .for_each(|o| o.shift_with_offset(offset));
137 agg_call.filter = agg_call.filter.clone().rewrite_expr(&mut rewriter);
138 });
139 if is_scalar_agg {
140 let pos_of_constant_column = node.schema().len() - 1;
142 agg_calls.iter_mut().for_each(|agg_call| {
143 match agg_call.agg_type {
144 AggType::Builtin(PbAggKind::Count) if agg_call.inputs.is_empty() => {
145 let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
146 agg_call.inputs.push(input_ref);
147 }
148 AggType::Builtin(PbAggKind::ArrayAgg
149 | PbAggKind::JsonbAgg
150 | PbAggKind::JsonbObjectAgg)
151 | AggType::UserDefined(_)
152 | AggType::WrapScalar(_) => {
153 let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
154 let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap();
155 agg_call.filter.conjunctions.push(cond.into());
156 }
157 AggType::Builtin(PbAggKind::Count
158 | PbAggKind::Sum
159 | PbAggKind::Sum0
160 | PbAggKind::Avg
161 | PbAggKind::Min
162 | PbAggKind::Max
163 | PbAggKind::BitAnd
164 | PbAggKind::BitOr
165 | PbAggKind::BitXor
166 | PbAggKind::BoolAnd
167 | PbAggKind::BoolOr
168 | PbAggKind::StringAgg
169 | PbAggKind::ApproxCountDistinct
171 | PbAggKind::FirstValue
172 | PbAggKind::LastValue
173 | PbAggKind::InternalLastSeenValue
174 | PbAggKind::ApproxPercentile
176 | PbAggKind::VarPop
177 | PbAggKind::VarSamp
178 | PbAggKind::StddevPop
179 | PbAggKind::StddevSamp
180 | PbAggKind::PercentileCont
182 | PbAggKind::PercentileDisc
183 | PbAggKind::Mode
184 | PbAggKind::Grouping)
186 => {
187 }
189 AggType::Builtin(PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar) => {
190 panic!("Unexpected aggregate function: {:?}", agg_call.agg_type)
191 }
192 }
193 });
194 }
195 let mut group_keys: IndexSet = (0..apply_left_len).collect();
196 group_keys.extend(agg_group_key.indices().map(|key| key + apply_left_len));
197 Agg::new(agg_calls, group_keys, node)
198 .with_enable_two_phase(enable_two_phase)
199 .into()
200 };
201
202 let filter = LogicalFilter::create(group_agg, on);
203 Some(filter)
204 }
205}
206
207impl ApplyAggTransposeRule {
208 pub fn create() -> BoxedRule {
209 Box::new(ApplyAggTransposeRule {})
210 }
211}