risingwave_stream/executor/test_utils/
agg_executor.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use futures::future;
19use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
20use risingwave_common::hash::SerializedKey;
21use risingwave_common::types::DataType;
22use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
23use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind};
24use risingwave_pb::stream_plan::PbAggNodeVersion;
25use risingwave_storage::StateStore;
26
27use crate::common::StateTableColumnMapping;
28use crate::common::table::state_table::StateTable;
29use crate::common::table::test_utils::gen_pbtable;
30use crate::executor::aggregate::{
31    AggExecutorArgs, AggStateStorage, HashAggExecutor, HashAggExecutorExtraArgs, SimpleAggExecutor,
32    SimpleAggExecutorExtraArgs,
33};
34use crate::executor::{ActorContext, ActorContextRef, Executor, ExecutorInfo, PkIndices};
35
36/// Generate aggExecuter's schema from `input`, `agg_calls` and `group_key_indices`.
37/// For [`crate::executor::aggregate::HashAggExecutor`], the group key indices should be provided.
38pub fn generate_agg_schema(
39    input_ref: &Executor,
40    agg_calls: &[AggCall],
41    group_key_indices: Option<&[usize]>,
42) -> Schema {
43    let aggs = agg_calls
44        .iter()
45        .map(|agg| Field::unnamed(agg.return_type.clone()));
46
47    let fields = if let Some(key_indices) = group_key_indices {
48        let keys = key_indices
49            .iter()
50            .map(|idx| input_ref.schema().fields[*idx].clone());
51
52        keys.chain(aggs).collect()
53    } else {
54        aggs.collect()
55    };
56
57    Schema { fields }
58}
59
60/// Create state storage for the given agg call.
61/// Should infer the schema in the same way as `LogicalAgg::infer_stream_agg_state`.
62pub async fn create_agg_state_storage<S: StateStore>(
63    store: S,
64    table_id: TableId,
65    agg_call: &AggCall,
66    group_key_indices: &[usize],
67    pk_indices: &[usize],
68    input_fields: Vec<Field>,
69    is_append_only: bool,
70) -> AggStateStorage<S> {
71    match agg_call.agg_type {
72            AggType::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => {
73                let mut column_descs = Vec::new();
74                let mut order_types = Vec::new();
75                let mut upstream_columns = Vec::new();
76                let mut order_columns = Vec::new();
77
78                let mut next_column_id = 0;
79                let mut add_column = |upstream_idx: usize, data_type: DataType, order_type: Option<OrderType>| {
80                    upstream_columns.push(upstream_idx);
81                    column_descs.push(ColumnDesc::unnamed(
82                        ColumnId::new(next_column_id),
83                        data_type,
84                    ));
85                    if let Some(order_type) = order_type {
86                        order_columns.push(ColumnOrder::new(upstream_idx as _, order_type));
87                        order_types.push(order_type);
88                    }
89                    next_column_id += 1;
90                };
91
92                for idx in group_key_indices {
93                    add_column(*idx, input_fields[*idx].data_type(), None);
94                }
95
96                add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Max)) {
97                    Some(OrderType::descending())
98                } else {
99                    Some(OrderType::ascending())
100                });
101
102                for idx in pk_indices {
103                    add_column(*idx, input_fields[*idx].data_type(), Some(OrderType::ascending()));
104                }
105
106                let state_table = StateTable::from_table_catalog(
107                    &gen_pbtable(
108                        table_id,
109                        column_descs,
110                        order_types.clone(),
111                        (0..order_types.len()).collect(),
112                        0,
113                    ),
114                    store,
115                    None,
116                ).await;
117
118                AggStateStorage::MaterializedInput { table: state_table, mapping: StateTableColumnMapping::new(upstream_columns, None), order_columns }
119            }
120            AggType::Builtin(
121                PbAggKind::Min /* append only */
122                | PbAggKind::Max /* append only */
123                | PbAggKind::Sum
124                | PbAggKind::Sum0
125                | PbAggKind::Count
126                | PbAggKind::Avg
127                | PbAggKind::ApproxCountDistinct
128            ) => {
129                AggStateStorage::Value
130            }
131            _ => {
132                panic!("no need to mock other agg kinds here");
133            }
134        }
135}
136
137/// Create intermediate state table for agg executor.
138pub async fn create_intermediate_state_table<S: StateStore>(
139    store: S,
140    table_id: TableId,
141    agg_calls: &[AggCall],
142    group_key_indices: &[usize],
143    input_fields: Vec<Field>,
144) -> StateTable<S> {
145    let mut column_descs = Vec::new();
146    let mut order_types = Vec::new();
147
148    let mut next_column_id = 0;
149    let mut add_column_desc = |data_type: DataType| {
150        column_descs.push(ColumnDesc::unnamed(
151            ColumnId::new(next_column_id),
152            data_type,
153        ));
154        next_column_id += 1;
155    };
156
157    group_key_indices.iter().for_each(|idx| {
158        add_column_desc(input_fields[*idx].data_type());
159        order_types.push(OrderType::ascending());
160    });
161
162    agg_calls.iter().for_each(|agg_call| {
163        add_column_desc(agg_call.return_type.clone());
164    });
165
166    StateTable::from_table_catalog_inconsistent_op(
167        &gen_pbtable(
168            table_id,
169            column_descs,
170            order_types,
171            (0..group_key_indices.len()).collect(),
172            0,
173        ),
174        store,
175        None,
176    )
177    .await
178}
179
180/// NOTE(kwannoel): This should only be used by `test` or `bench`.
181#[allow(clippy::too_many_arguments)]
182pub async fn new_boxed_hash_agg_executor<S: StateStore>(
183    store: S,
184    input: Executor,
185    is_append_only: bool,
186    agg_calls: Vec<AggCall>,
187    row_count_index: usize,
188    group_key_indices: Vec<usize>,
189    pk_indices: PkIndices,
190    extreme_cache_size: usize,
191    emit_on_window_close: bool,
192    executor_id: u64,
193) -> Executor {
194    let mut storages = Vec::with_capacity(agg_calls.iter().len());
195    for (idx, agg_call) in agg_calls.iter().enumerate() {
196        storages.push(
197            create_agg_state_storage(
198                store.clone(),
199                TableId::new(idx as u32),
200                agg_call,
201                &group_key_indices,
202                &pk_indices,
203                input.info.schema.fields.clone(),
204                is_append_only,
205            )
206            .await,
207        )
208    }
209
210    let intermediate_state_table = create_intermediate_state_table(
211        store,
212        TableId::new(agg_calls.len() as u32),
213        &agg_calls,
214        &group_key_indices,
215        input.info.schema.fields.clone(),
216    )
217    .await;
218
219    let schema = generate_agg_schema(&input, &agg_calls, Some(&group_key_indices));
220    let info = ExecutorInfo::new(
221        schema,
222        pk_indices,
223        "HashAggExecutor".to_owned(),
224        executor_id,
225    );
226
227    let exec = HashAggExecutor::<SerializedKey, S>::new(AggExecutorArgs {
228        version: PbAggNodeVersion::LATEST,
229
230        input,
231        actor_ctx: ActorContext::for_test(123),
232        info: info.clone(),
233
234        extreme_cache_size,
235
236        agg_calls,
237        row_count_index,
238        storages,
239        intermediate_state_table,
240        distinct_dedup_tables: Default::default(),
241        watermark_epoch: Arc::new(AtomicU64::new(0)),
242
243        extra: HashAggExecutorExtraArgs {
244            group_key_indices,
245            chunk_size: 1024,
246            max_dirty_groups_heap_size: 64 << 20,
247            emit_on_window_close,
248        },
249    })
250    .unwrap();
251    (info, exec).into()
252}
253
254#[allow(clippy::too_many_arguments)]
255pub async fn new_boxed_simple_agg_executor<S: StateStore>(
256    actor_ctx: ActorContextRef,
257    store: S,
258    input: Executor,
259    is_append_only: bool,
260    agg_calls: Vec<AggCall>,
261    row_count_index: usize,
262    pk_indices: PkIndices,
263    executor_id: u64,
264    must_output_per_barrier: bool,
265) -> Executor {
266    let storages = future::join_all(agg_calls.iter().enumerate().map(|(idx, agg_call)| {
267        create_agg_state_storage(
268            store.clone(),
269            TableId::new(idx as u32),
270            agg_call,
271            &[],
272            &pk_indices,
273            input.info.schema.fields.clone(),
274            is_append_only,
275        )
276    }))
277    .await;
278
279    let intermediate_state_table = create_intermediate_state_table(
280        store,
281        TableId::new(agg_calls.len() as u32),
282        &agg_calls,
283        &[],
284        input.info.schema.fields.clone(),
285    )
286    .await;
287
288    let schema = generate_agg_schema(&input, &agg_calls, None);
289    let info = ExecutorInfo::new(
290        schema,
291        pk_indices,
292        "SimpleAggExecutor".to_owned(),
293        executor_id,
294    );
295
296    let exec = SimpleAggExecutor::new(AggExecutorArgs {
297        version: PbAggNodeVersion::LATEST,
298
299        input,
300        actor_ctx,
301        info: info.clone(),
302
303        extreme_cache_size: 1024,
304
305        agg_calls,
306        row_count_index,
307        storages,
308        intermediate_state_table,
309        distinct_dedup_tables: Default::default(),
310        watermark_epoch: Arc::new(AtomicU64::new(0)),
311        extra: SimpleAggExecutorExtraArgs {
312            must_output_per_barrier,
313        },
314    })
315    .unwrap();
316    (info, exec).into()
317}