risingwave_expr_impl/aggregate/
general.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use num_traits::{CheckedAdd, CheckedSub};
16use risingwave_expr::{ExprError, Result, aggregate};
17
18#[aggregate("sum(int2) -> int8")]
19#[aggregate("sum(int4) -> int8")]
20#[aggregate("sum(int8) -> decimal")]
21#[aggregate("sum(float4) -> float4")]
22#[aggregate("sum(float8) -> float8")]
23#[aggregate("sum(decimal) -> decimal")]
24#[aggregate("sum(interval) -> interval")]
25#[aggregate("sum(int256) -> int256")]
26#[aggregate("sum(int8) -> int8", internal)] // used internally for 2-phase sum(int2) and sum(int4)
27#[aggregate("sum0(int8) -> int8", internal, init_state = "0i64")] // used internally for 2-phase count
28fn sum<S, T>(state: S, input: T, retract: bool) -> Result<S>
29where
30    S: Default + From<T> + CheckedAdd<Output = S> + CheckedSub<Output = S>,
31{
32    if retract {
33        state
34            .checked_sub(&S::from(input))
35            .ok_or_else(|| ExprError::NumericOutOfRange)
36    } else {
37        state
38            .checked_add(&S::from(input))
39            .ok_or_else(|| ExprError::NumericOutOfRange)
40    }
41}
42
// Signature-only registration for `avg`: every overload below is marked
// `rewritten` and the function body is empty, so no accumulation logic lives here.
// NOTE(review): presumably the planner rewrites `avg` into other aggregates
// (e.g. `sum` and `count`) before execution — confirm against the rewriter.
// Result types: integer/decimal inputs -> `decimal`, float/`int256` inputs ->
// `float8`, `interval` -> `interval`.
// The leading underscore presumably keeps this never-called stub from triggering
// dead-code warnings.
#[aggregate("avg(int2) -> decimal", rewritten)]
#[aggregate("avg(int4) -> decimal", rewritten)]
#[aggregate("avg(int8) -> decimal", rewritten)]
#[aggregate("avg(decimal) -> decimal", rewritten)]
#[aggregate("avg(float4) -> float8", rewritten)]
#[aggregate("avg(float8) -> float8", rewritten)]
#[aggregate("avg(int256) -> float8", rewritten)]
#[aggregate("avg(interval) -> interval", rewritten)]
fn _avg() {}
52
// Signature-only registration for `stddev_pop`: every overload is marked
// `rewritten` and the body is empty, so no computation happens here.
// NOTE(review): presumably rewritten by the planner into simpler aggregates
// (sums/counts) before execution — confirm.
// Result types: integer/decimal inputs -> `decimal`, float/`int256` inputs ->
// `float8`. No `interval` overload is registered.
#[aggregate("stddev_pop(int2) -> decimal", rewritten)]
#[aggregate("stddev_pop(int4) -> decimal", rewritten)]
#[aggregate("stddev_pop(int8) -> decimal", rewritten)]
#[aggregate("stddev_pop(decimal) -> decimal", rewritten)]
#[aggregate("stddev_pop(float4) -> float8", rewritten)]
#[aggregate("stddev_pop(float8) -> float8", rewritten)]
#[aggregate("stddev_pop(int256) -> float8", rewritten)]
fn _stddev_pop() {}
61
// Signature-only registration for `stddev_samp`: every overload is marked
// `rewritten` and the body is empty, so no computation happens here.
// NOTE(review): presumably rewritten by the planner into simpler aggregates
// (sums/counts) before execution — confirm.
// Result types: integer/decimal inputs -> `decimal`, float/`int256` inputs ->
// `float8`. No `interval` overload is registered.
#[aggregate("stddev_samp(int2) -> decimal", rewritten)]
#[aggregate("stddev_samp(int4) -> decimal", rewritten)]
#[aggregate("stddev_samp(int8) -> decimal", rewritten)]
#[aggregate("stddev_samp(decimal) -> decimal", rewritten)]
#[aggregate("stddev_samp(float4) -> float8", rewritten)]
#[aggregate("stddev_samp(float8) -> float8", rewritten)]
#[aggregate("stddev_samp(int256) -> float8", rewritten)]
fn _stddev_samp() {}
70
// Signature-only registration for `var_pop`: every overload is marked
// `rewritten` and the body is empty, so no computation happens here.
// NOTE(review): presumably rewritten by the planner into simpler aggregates
// (sums/counts) before execution — confirm.
// Result types: integer/decimal inputs -> `decimal`, float/`int256` inputs ->
// `float8`. No `interval` overload is registered.
#[aggregate("var_pop(int2) -> decimal", rewritten)]
#[aggregate("var_pop(int4) -> decimal", rewritten)]
#[aggregate("var_pop(int8) -> decimal", rewritten)]
#[aggregate("var_pop(decimal) -> decimal", rewritten)]
#[aggregate("var_pop(float4) -> float8", rewritten)]
#[aggregate("var_pop(float8) -> float8", rewritten)]
#[aggregate("var_pop(int256) -> float8", rewritten)]
fn _var_pop() {}
79
// Signature-only registration for `var_samp`: every overload is marked
// `rewritten` and the body is empty, so no computation happens here.
// NOTE(review): presumably rewritten by the planner into simpler aggregates
// (sums/counts) before execution — confirm.
// Result types: integer/decimal inputs -> `decimal`, float/`int256` inputs ->
// `float8`. No `interval` overload is registered.
#[aggregate("var_samp(int2) -> decimal", rewritten)]
#[aggregate("var_samp(int4) -> decimal", rewritten)]
#[aggregate("var_samp(int8) -> decimal", rewritten)]
#[aggregate("var_samp(decimal) -> decimal", rewritten)]
#[aggregate("var_samp(float4) -> float8", rewritten)]
#[aggregate("var_samp(float8) -> float8", rewritten)]
#[aggregate("var_samp(int256) -> float8", rewritten)]
fn _var_samp() {}
88
89// no `min(boolean)` and `min(jsonb)`
90#[aggregate("min(*int) -> auto", state = "ref")]
91#[aggregate("min(*float) -> auto", state = "ref")]
92#[aggregate("min(decimal) -> auto", state = "ref")]
93#[aggregate("min(int256) -> auto", state = "ref")]
94#[aggregate("min(serial) -> auto", state = "ref")]
95#[aggregate("min(date) -> auto", state = "ref")]
96#[aggregate("min(time) -> auto", state = "ref")]
97#[aggregate("min(interval) -> auto", state = "ref")]
98#[aggregate("min(timestamp) -> auto", state = "ref")]
99#[aggregate("min(timestamptz) -> auto", state = "ref")]
100#[aggregate("min(varchar) -> auto", state = "ref")]
101#[aggregate("min(bytea) -> auto", state = "ref")]
102#[aggregate("min(anyarray) -> auto", state = "ref")]
103#[aggregate("min(struct) -> auto", state = "ref")]
104fn min<T: Ord>(state: T, input: T) -> T {
105    state.min(input)
106}
107
108// no `max(boolean)` and `max(jsonb)`
109#[aggregate("max(*int) -> auto", state = "ref")]
110#[aggregate("max(*float) -> auto", state = "ref")]
111#[aggregate("max(decimal) -> auto", state = "ref")]
112#[aggregate("max(int256) -> auto", state = "ref")]
113#[aggregate("max(serial) -> auto", state = "ref")]
114#[aggregate("max(date) -> auto", state = "ref")]
115#[aggregate("max(time) -> auto", state = "ref")]
116#[aggregate("max(interval) -> auto", state = "ref")]
117#[aggregate("max(timestamp) -> auto", state = "ref")]
118#[aggregate("max(timestamptz) -> auto", state = "ref")]
119#[aggregate("max(varchar) -> auto", state = "ref")]
120#[aggregate("max(bytea) -> auto", state = "ref")]
121#[aggregate("max(anyarray) -> auto", state = "ref")]
122#[aggregate("max(struct) -> auto", state = "ref")]
123fn max<T: Ord>(state: T, input: T) -> T {
124    state.max(input)
125}
126
127/// Note the following corner cases:
128///
129/// ```slt
130/// statement ok
131/// create table t(v1 int);
132///
133/// statement ok
134/// insert into t values (null);
135///
136/// query I
137/// select count(*) from t;
138/// ----
139/// 1
140///
141/// query I
142/// select count(v1) from t;
143/// ----
144/// 0
145///
146/// query I
147/// select sum(v1) from t;
148/// ----
149/// NULL
150///
151/// statement ok
152/// drop table t;
153/// ```
154#[aggregate("count(*) -> int8", init_state = "0i64")]
155fn count<T>(state: i64, _: T, retract: bool) -> i64 {
156    if retract { state - 1 } else { state + 1 }
157}
158
159#[aggregate("count() -> int8", init_state = "0i64")]
160fn count_star(state: i64, retract: bool) -> i64 {
161    if retract { state - 1 } else { state + 1 }
162}
163
#[cfg(test)]
mod tests {
    extern crate test;

    use std::sync::Arc;

    use futures_util::FutureExt;
    use risingwave_common::array::*;
    use risingwave_common::test_utils::{rand_bitmap, rand_stream_chunk};
    use risingwave_common::types::{Datum, Decimal};
    use risingwave_expr::aggregate::{AggCall, build_append_only};
    use test::Bencher;

    /// Number of rows in every generated benchmark chunk.
    const CHUNK_SIZE: usize = 1024;

    /// Builds the append-only aggregate described by `pretty`, feeds it `chunk`
    /// once, and asserts that the final result equals `expected`.
    fn test_agg(pretty: &str, chunk: StreamChunk, expected: Datum) {
        let call = AggCall::from_pretty(pretty);
        let agg = build_append_only(&call).unwrap();
        let mut state = agg.create_state().unwrap();

        let updated = agg.update(&mut state, &chunk).now_or_never();
        updated.unwrap().unwrap();

        let result = agg.get_result(&state).now_or_never();
        assert_eq!(result.unwrap().unwrap(), expected);
    }

    #[test]
    fn sum_int4() {
        // 3 - 1; rows marked `D` are invisible and must not contribute.
        let chunk = StreamChunk::from_pretty(
            " i
            + 3
            - 1
            - 3 D
            + 1 D",
        );
        test_agg("(sum:int8 $0:int4)", chunk, Some(2i64.into()));
    }

    #[test]
    fn sum_int8() {
        // Same data as `sum_int4`, but `sum(int8)` accumulates into a decimal.
        let chunk = StreamChunk::from_pretty(
            " I
            + 3
            - 1
            - 3 D
            + 1 D",
        );
        test_agg(
            "(sum:decimal $0:int8)",
            chunk,
            Some(Decimal::from(2).into()),
        );
    }

    #[test]
    fn sum_float8() {
        // 1 + 2 + 3 - 4
        let chunk = StreamChunk::from_pretty(
            " F
            + 1.0
            + 2.0
            + 3.0
            - 4.0",
        );
        test_agg("(sum:float8 $0:float8)", chunk, Some(2.0f64.into()));

        // Infinity absorbs any finite insertions and retractions.
        let chunk = StreamChunk::from_pretty(
            " F
            + 1.0
            + inf
            + 3.0
            - 3.0",
        );
        test_agg("(sum:float8 $0:float8)", chunk, Some(f64::INFINITY.into()));

        // Retracting negative infinity yields positive infinity.
        let chunk = StreamChunk::from_pretty(
            " F
            + 0.0
            - -inf",
        );
        test_agg("(sum:float8 $0:float8)", chunk, Some(f64::INFINITY.into()));

        // NaN propagates through subsequent additions.
        let chunk = StreamChunk::from_pretty(
            " F
            + 1.0
            + nan
            + 1926.0",
        );
        test_agg("(sum:float8 $0:float8)", chunk, Some(f64::NAN.into()));
    }

    /// An empty chunk leaves `sum` at `None`, but once rows have been seen, cancelling
    /// them all out with matching retractions yields `0` rather than `None`.
    #[test]
    fn sum_no_none() {
        test_agg("(sum:int8 $0:int8)", StreamChunk::from_pretty("I"), None);

        let chunk = StreamChunk::from_pretty(
            " I
            + 2
            - 1
            + 1
            - 2",
        );
        test_agg("(sum:int8 $0:int8)", chunk, Some(0i64.into()));

        // Only the visible `+ 1` / `- 1` pair is applied.
        let chunk = StreamChunk::from_pretty(
            " I
            - 3 D
            + 1
            - 3 D
            - 1",
        );
        test_agg("(sum:int8 $0:int8)", chunk, Some(0i64.into()));
    }

    #[test]
    fn min_int8() {
        // The invisible `1` and the NULL are both ignored.
        let chunk = StreamChunk::from_pretty(
            " I
            + 1  D
            + 10
            + .
            + 5",
        );
        test_agg("(min:int8 $0:int8)", chunk, Some(5i64.into()));
    }

    #[test]
    fn min_float4() {
        let chunk = StreamChunk::from_pretty(
            " f
            + 1.0  D
            + 10.0
            + .
            + 5.0",
        );
        test_agg("(min:float4 $0:float4)", chunk, Some(5.0f32.into()));
    }

    #[test]
    fn min_char() {
        // Strings compare lexicographically: "aa" < "b".
        let chunk = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        test_agg("(min:varchar $0:varchar)", chunk, Some("aa".into()));
    }

    #[test]
    fn min_list() {
        let chunk = StreamChunk::from_pretty(
            " i[]
            + {0}
            + {1}
            + {2}",
        );
        test_agg(
            "(min:int4[] $0:int4[])",
            chunk,
            Some(ListValue::from_iter([0]).into()),
        );
    }

    #[test]
    fn max_int8() {
        // The invisible `10` and the NULL are both ignored.
        let chunk = StreamChunk::from_pretty(
            " I
            + 1
            + 10 D
            + .
            + 5",
        );
        test_agg("(max:int8 $0:int8)", chunk, Some(5i64.into()));
    }

    #[test]
    fn max_char() {
        let chunk = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        test_agg("(max:varchar $0:varchar)", chunk, Some("b".into()));
    }

    #[test]
    fn count_int4() {
        let chunk = StreamChunk::from_pretty(
            " i
            + 1
            + 2
            + 3",
        );
        test_agg("(count:int8 $0:int4)", chunk, Some(3i64.into()));

        // NULL is not counted; the retraction cancels one insertion.
        let chunk = StreamChunk::from_pretty(
            " i
            + 1
            + .
            + 3
            - 1",
        );
        test_agg("(count:int8 $0:int4)", chunk, Some(1i64.into()));

        // Only a NULL is visible, so nothing is counted.
        let chunk = StreamChunk::from_pretty(
            " i
            - 1 D
            - .
            - 3 D
            - 1 D",
        );
        test_agg("(count:int8 $0:int4)", chunk, Some(0i64.into()));

        let chunk = StreamChunk::from_pretty("i");
        test_agg("(count:int8 $0:int4)", chunk, Some(0i64.into()));

        let chunk = StreamChunk::from_pretty(
            " i
            + .",
        );
        test_agg("(count:int8 $0:int4)", chunk, Some(0i64.into()));
    }

    #[test]
    fn count_star() {
        // No rows at all still produces `0`, not NULL.
        let chunk = StreamChunk::from_pretty("i");
        test_agg("(count:int8)", chunk, Some(0i64.into()));

        // A single insertion.
        let chunk = StreamChunk::from_pretty(
            " i
            + 0",
        );
        test_agg("(count:int8)", chunk, Some(1i64.into()));

        // An insertion followed by its retraction.
        let chunk = StreamChunk::from_pretty(
            " i
            + 0
            - 0",
        );
        test_agg("(count:int8)", chunk, Some(0i64.into()));

        // More visible retractions than insertions drive the count negative.
        let chunk = StreamChunk::from_pretty(
            " i
            - 0
            - 0 D
            + 1
            - 1",
        );
        test_agg("(count:int8)", chunk, Some((-1i64).into()));
    }

    #[test]
    fn bitxor_int8() {
        // 1 ^ 5 == 4; the invisible row and the NULL are skipped.
        let chunk = StreamChunk::from_pretty(
            " I
            + 1
            - 10 D
            + .
            - 5",
        );
        test_agg("(bit_xor:int8 $0:int8)", chunk, Some(4i64.into()));
    }

    /// Benchmarks repeated updates of the `agg_desc` aggregate over a single random
    /// `int8` stream chunk with roughly `vis_rate` of its rows visible.
    fn bench_i64(
        b: &mut Bencher,
        agg_desc: &str,
        chunk_size: usize,
        vis_rate: f64,
        append_only: bool,
    ) {
        println!(
            "benching {} agg, chunk_size={}, vis_rate={}",
            agg_desc, chunk_size, vis_rate
        );
        let visible_rows = (chunk_size as f64 * vis_rate) as usize;
        let vis = rand_bitmap::gen_rand_bitmap(chunk_size, visible_rows, 666);
        let (ops, column) =
            rand_stream_chunk::gen_legal_stream_chunk(&vis, chunk_size, append_only, 666);
        let data_chunk = DataChunk::new(vec![Arc::new(column)], vis);
        let chunk = StreamChunk::from_parts(ops, data_chunk);

        let call = AggCall::from_pretty(format!("({agg_desc}:int8 $0:int8)"));
        let agg = build_append_only(&call).unwrap();
        let mut state = agg.create_state().unwrap();
        b.iter(|| {
            agg.update(&mut state, &chunk)
                .now_or_never()
                .unwrap()
                .unwrap();
        });
    }

    #[bench]
    fn sum_agg_without_vis(b: &mut Bencher) {
        bench_i64(b, "sum", CHUNK_SIZE, 1.0, false);
    }

    #[bench]
    fn sum_agg_vis_rate_0_75(b: &mut Bencher) {
        bench_i64(b, "sum", CHUNK_SIZE, 0.75, false);
    }

    #[bench]
    fn sum_agg_vis_rate_0_5(b: &mut Bencher) {
        bench_i64(b, "sum", CHUNK_SIZE, 0.5, false);
    }

    #[bench]
    fn sum_agg_vis_rate_0_25(b: &mut Bencher) {
        bench_i64(b, "sum", CHUNK_SIZE, 0.25, false);
    }

    #[bench]
    fn sum_agg_vis_rate_0_05(b: &mut Bencher) {
        bench_i64(b, "sum", CHUNK_SIZE, 0.05, false);
    }

    #[bench]
    fn count_agg_without_vis(b: &mut Bencher) {
        bench_i64(b, "count", CHUNK_SIZE, 1.0, false);
    }

    #[bench]
    fn count_agg_vis_rate_0_75(b: &mut Bencher) {
        bench_i64(b, "count", CHUNK_SIZE, 0.75, false);
    }

    #[bench]
    fn count_agg_vis_rate_0_5(b: &mut Bencher) {
        bench_i64(b, "count", CHUNK_SIZE, 0.5, false);
    }

    #[bench]
    fn count_agg_vis_rate_0_25(b: &mut Bencher) {
        bench_i64(b, "count", CHUNK_SIZE, 0.25, false);
    }

    #[bench]
    fn count_agg_vis_rate_0_05(b: &mut Bencher) {
        bench_i64(b, "count", CHUNK_SIZE, 0.05, false);
    }

    // min/max are benchmarked on append-only chunks (no retractions).
    #[bench]
    fn min_agg_without_vis(b: &mut Bencher) {
        bench_i64(b, "min", CHUNK_SIZE, 1.0, true);
    }

    #[bench]
    fn min_agg_vis_rate_0_75(b: &mut Bencher) {
        bench_i64(b, "min", CHUNK_SIZE, 0.75, true);
    }

    #[bench]
    fn min_agg_vis_rate_0_5(b: &mut Bencher) {
        bench_i64(b, "min", CHUNK_SIZE, 0.5, true);
    }

    #[bench]
    fn min_agg_vis_rate_0_25(b: &mut Bencher) {
        bench_i64(b, "min", CHUNK_SIZE, 0.25, true);
    }

    #[bench]
    fn min_agg_vis_rate_0_05(b: &mut Bencher) {
        bench_i64(b, "min", CHUNK_SIZE, 0.05, true);
    }

    #[bench]
    fn max_agg_without_vis(b: &mut Bencher) {
        bench_i64(b, "max", CHUNK_SIZE, 1.0, true);
    }

    #[bench]
    fn max_agg_vis_rate_0_75(b: &mut Bencher) {
        bench_i64(b, "max", CHUNK_SIZE, 0.75, true);
    }

    #[bench]
    fn max_agg_vis_rate_0_5(b: &mut Bencher) {
        bench_i64(b, "max", CHUNK_SIZE, 0.5, true);
    }

    #[bench]
    fn max_agg_vis_rate_0_25(b: &mut Bencher) {
        bench_i64(b, "max", CHUNK_SIZE, 0.25, true);
    }

    #[bench]
    fn max_agg_vis_rate_0_05(b: &mut Bencher) {
        bench_i64(b, "max", CHUNK_SIZE, 0.05, true);
    }
}