risingwave_stream/executor/test_utils/
agg_executor.rs1use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use futures::future;
19use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
20use risingwave_common::hash::SerializedKey;
21use risingwave_common::types::DataType;
22use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
23use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind};
24use risingwave_pb::stream_plan::PbAggNodeVersion;
25use risingwave_storage::StateStore;
26
27use crate::common::StateTableColumnMapping;
28use crate::common::table::state_table::StateTable;
29use crate::common::table::test_utils::gen_pbtable;
30use crate::executor::aggregate::{
31 AggExecutorArgs, AggStateStorage, HashAggExecutor, HashAggExecutorExtraArgs, SimpleAggExecutor,
32 SimpleAggExecutorExtraArgs,
33};
34use crate::executor::{ActorContext, ActorContextRef, Executor, ExecutorInfo, PkIndices};
35
36pub fn generate_agg_schema(
39 input_ref: &Executor,
40 agg_calls: &[AggCall],
41 group_key_indices: Option<&[usize]>,
42) -> Schema {
43 let aggs = agg_calls
44 .iter()
45 .map(|agg| Field::unnamed(agg.return_type.clone()));
46
47 let fields = if let Some(key_indices) = group_key_indices {
48 let keys = key_indices
49 .iter()
50 .map(|idx| input_ref.schema().fields[*idx].clone());
51
52 keys.chain(aggs).collect()
53 } else {
54 aggs.collect()
55 };
56
57 Schema { fields }
58}
59
60pub async fn create_agg_state_storage<S: StateStore>(
63 store: S,
64 table_id: TableId,
65 agg_call: &AggCall,
66 group_key_indices: &[usize],
67 pk_indices: &[usize],
68 input_fields: Vec<Field>,
69 is_append_only: bool,
70) -> AggStateStorage<S> {
71 match agg_call.agg_type {
72 AggType::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => {
73 let mut column_descs = Vec::new();
74 let mut order_types = Vec::new();
75 let mut upstream_columns = Vec::new();
76 let mut order_columns = Vec::new();
77
78 let mut next_column_id = 0;
79 let mut add_column = |upstream_idx: usize, data_type: DataType, order_type: Option<OrderType>| {
80 upstream_columns.push(upstream_idx);
81 column_descs.push(ColumnDesc::unnamed(
82 ColumnId::new(next_column_id),
83 data_type,
84 ));
85 if let Some(order_type) = order_type {
86 order_columns.push(ColumnOrder::new(upstream_idx as _, order_type));
87 order_types.push(order_type);
88 }
89 next_column_id += 1;
90 };
91
92 for idx in group_key_indices {
93 add_column(*idx, input_fields[*idx].data_type(), None);
94 }
95
96 add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Max)) {
97 Some(OrderType::descending())
98 } else {
99 Some(OrderType::ascending())
100 });
101
102 for idx in pk_indices {
103 add_column(*idx, input_fields[*idx].data_type(), Some(OrderType::ascending()));
104 }
105
106 let state_table = StateTable::from_table_catalog(
107 &gen_pbtable(
108 table_id,
109 column_descs,
110 order_types.clone(),
111 (0..order_types.len()).collect(),
112 0,
113 ),
114 store,
115 None,
116 ).await;
117
118 AggStateStorage::MaterializedInput { table: state_table, mapping: StateTableColumnMapping::new(upstream_columns, None), order_columns }
119 }
120 AggType::Builtin(
121 PbAggKind::Min | PbAggKind::Max | PbAggKind::Sum
124 | PbAggKind::Sum0
125 | PbAggKind::Count
126 | PbAggKind::Avg
127 | PbAggKind::ApproxCountDistinct
128 ) => {
129 AggStateStorage::Value
130 }
131 _ => {
132 panic!("no need to mock other agg kinds here");
133 }
134 }
135}
136
137pub async fn create_intermediate_state_table<S: StateStore>(
139 store: S,
140 table_id: TableId,
141 agg_calls: &[AggCall],
142 group_key_indices: &[usize],
143 input_fields: Vec<Field>,
144) -> StateTable<S> {
145 let mut column_descs = Vec::new();
146 let mut order_types = Vec::new();
147
148 let mut next_column_id = 0;
149 let mut add_column_desc = |data_type: DataType| {
150 column_descs.push(ColumnDesc::unnamed(
151 ColumnId::new(next_column_id),
152 data_type,
153 ));
154 next_column_id += 1;
155 };
156
157 group_key_indices.iter().for_each(|idx| {
158 add_column_desc(input_fields[*idx].data_type());
159 order_types.push(OrderType::ascending());
160 });
161
162 agg_calls.iter().for_each(|agg_call| {
163 add_column_desc(agg_call.return_type.clone());
164 });
165
166 StateTable::from_table_catalog_inconsistent_op(
167 &gen_pbtable(
168 table_id,
169 column_descs,
170 order_types,
171 (0..group_key_indices.len()).collect(),
172 0,
173 ),
174 store,
175 None,
176 )
177 .await
178}
179
180#[allow(clippy::too_many_arguments)]
182pub async fn new_boxed_hash_agg_executor<S: StateStore>(
183 store: S,
184 input: Executor,
185 is_append_only: bool,
186 agg_calls: Vec<AggCall>,
187 row_count_index: usize,
188 group_key_indices: Vec<usize>,
189 pk_indices: PkIndices,
190 extreme_cache_size: usize,
191 emit_on_window_close: bool,
192 executor_id: u64,
193) -> Executor {
194 let mut storages = Vec::with_capacity(agg_calls.iter().len());
195 for (idx, agg_call) in agg_calls.iter().enumerate() {
196 storages.push(
197 create_agg_state_storage(
198 store.clone(),
199 TableId::new(idx as u32),
200 agg_call,
201 &group_key_indices,
202 &pk_indices,
203 input.info.schema.fields.clone(),
204 is_append_only,
205 )
206 .await,
207 )
208 }
209
210 let intermediate_state_table = create_intermediate_state_table(
211 store,
212 TableId::new(agg_calls.len() as u32),
213 &agg_calls,
214 &group_key_indices,
215 input.info.schema.fields.clone(),
216 )
217 .await;
218
219 let schema = generate_agg_schema(&input, &agg_calls, Some(&group_key_indices));
220 let info = ExecutorInfo::new(
221 schema,
222 pk_indices,
223 "HashAggExecutor".to_owned(),
224 executor_id,
225 );
226
227 let exec = HashAggExecutor::<SerializedKey, S>::new(AggExecutorArgs {
228 version: PbAggNodeVersion::LATEST,
229
230 input,
231 actor_ctx: ActorContext::for_test(123),
232 info: info.clone(),
233
234 extreme_cache_size,
235
236 agg_calls,
237 row_count_index,
238 storages,
239 intermediate_state_table,
240 distinct_dedup_tables: Default::default(),
241 watermark_epoch: Arc::new(AtomicU64::new(0)),
242
243 extra: HashAggExecutorExtraArgs {
244 group_key_indices,
245 chunk_size: 1024,
246 max_dirty_groups_heap_size: 64 << 20,
247 emit_on_window_close,
248 },
249 })
250 .unwrap();
251 (info, exec).into()
252}
253
254#[allow(clippy::too_many_arguments)]
255pub async fn new_boxed_simple_agg_executor<S: StateStore>(
256 actor_ctx: ActorContextRef,
257 store: S,
258 input: Executor,
259 is_append_only: bool,
260 agg_calls: Vec<AggCall>,
261 row_count_index: usize,
262 pk_indices: PkIndices,
263 executor_id: u64,
264 must_output_per_barrier: bool,
265) -> Executor {
266 let storages = future::join_all(agg_calls.iter().enumerate().map(|(idx, agg_call)| {
267 create_agg_state_storage(
268 store.clone(),
269 TableId::new(idx as u32),
270 agg_call,
271 &[],
272 &pk_indices,
273 input.info.schema.fields.clone(),
274 is_append_only,
275 )
276 }))
277 .await;
278
279 let intermediate_state_table = create_intermediate_state_table(
280 store,
281 TableId::new(agg_calls.len() as u32),
282 &agg_calls,
283 &[],
284 input.info.schema.fields.clone(),
285 )
286 .await;
287
288 let schema = generate_agg_schema(&input, &agg_calls, None);
289 let info = ExecutorInfo::new(
290 schema,
291 pk_indices,
292 "SimpleAggExecutor".to_owned(),
293 executor_id,
294 );
295
296 let exec = SimpleAggExecutor::new(AggExecutorArgs {
297 version: PbAggNodeVersion::LATEST,
298
299 input,
300 actor_ctx,
301 info: info.clone(),
302
303 extreme_cache_size: 1024,
304
305 agg_calls,
306 row_count_index,
307 storages,
308 intermediate_state_table,
309 distinct_dedup_tables: Default::default(),
310 watermark_epoch: Arc::new(AtomicU64::new(0)),
311 extra: SimpleAggExecutorExtraArgs {
312 must_output_per_barrier,
313 },
314 })
315 .unwrap();
316 (info, exec).into()
317}