1use num_traits::{CheckedAdd, CheckedSub};
16use risingwave_expr::{ExprError, Result, aggregate};
17
18#[aggregate("sum(int2) -> int8")]
19#[aggregate("sum(int4) -> int8")]
20#[aggregate("sum(int8) -> decimal")]
21#[aggregate("sum(float4) -> float4")]
22#[aggregate("sum(float8) -> float8")]
23#[aggregate("sum(decimal) -> decimal")]
24#[aggregate("sum(interval) -> interval")]
25#[aggregate("sum(int256) -> int256")]
26#[aggregate("sum(int8) -> int8", internal)] #[aggregate("sum0(int8) -> int8", internal, init_state = "0i64")] fn sum<S, T>(state: S, input: T, retract: bool) -> Result<S>
29where
30 S: Default + From<T> + CheckedAdd<Output = S> + CheckedSub<Output = S>,
31{
32 if retract {
33 state
34 .checked_sub(&S::from(input))
35 .ok_or_else(|| ExprError::NumericOutOfRange)
36 } else {
37 state
38 .checked_add(&S::from(input))
39 .ok_or_else(|| ExprError::NumericOutOfRange)
40 }
41}
42
43#[aggregate("avg(int2) -> decimal", rewritten)]
44#[aggregate("avg(int4) -> decimal", rewritten)]
45#[aggregate("avg(int8) -> decimal", rewritten)]
46#[aggregate("avg(decimal) -> decimal", rewritten)]
47#[aggregate("avg(float4) -> float8", rewritten)]
48#[aggregate("avg(float8) -> float8", rewritten)]
49#[aggregate("avg(int256) -> float8", rewritten)]
50#[aggregate("avg(interval) -> interval", rewritten)]
51fn _avg() {}
52
53#[aggregate("stddev_pop(int2) -> decimal", rewritten)]
54#[aggregate("stddev_pop(int4) -> decimal", rewritten)]
55#[aggregate("stddev_pop(int8) -> decimal", rewritten)]
56#[aggregate("stddev_pop(decimal) -> decimal", rewritten)]
57#[aggregate("stddev_pop(float4) -> float8", rewritten)]
58#[aggregate("stddev_pop(float8) -> float8", rewritten)]
59#[aggregate("stddev_pop(int256) -> float8", rewritten)]
60fn _stddev_pop() {}
61
62#[aggregate("stddev_samp(int2) -> decimal", rewritten)]
63#[aggregate("stddev_samp(int4) -> decimal", rewritten)]
64#[aggregate("stddev_samp(int8) -> decimal", rewritten)]
65#[aggregate("stddev_samp(decimal) -> decimal", rewritten)]
66#[aggregate("stddev_samp(float4) -> float8", rewritten)]
67#[aggregate("stddev_samp(float8) -> float8", rewritten)]
68#[aggregate("stddev_samp(int256) -> float8", rewritten)]
69fn _stddev_samp() {}
70
71#[aggregate("var_pop(int2) -> decimal", rewritten)]
72#[aggregate("var_pop(int4) -> decimal", rewritten)]
73#[aggregate("var_pop(int8) -> decimal", rewritten)]
74#[aggregate("var_pop(decimal) -> decimal", rewritten)]
75#[aggregate("var_pop(float4) -> float8", rewritten)]
76#[aggregate("var_pop(float8) -> float8", rewritten)]
77#[aggregate("var_pop(int256) -> float8", rewritten)]
78fn _var_pop() {}
79
80#[aggregate("var_samp(int2) -> decimal", rewritten)]
81#[aggregate("var_samp(int4) -> decimal", rewritten)]
82#[aggregate("var_samp(int8) -> decimal", rewritten)]
83#[aggregate("var_samp(decimal) -> decimal", rewritten)]
84#[aggregate("var_samp(float4) -> float8", rewritten)]
85#[aggregate("var_samp(float8) -> float8", rewritten)]
86#[aggregate("var_samp(int256) -> float8", rewritten)]
87fn _var_samp() {}
88
89#[aggregate("min(*int) -> auto", state = "ref")]
91#[aggregate("min(*float) -> auto", state = "ref")]
92#[aggregate("min(decimal) -> auto", state = "ref")]
93#[aggregate("min(int256) -> auto", state = "ref")]
94#[aggregate("min(serial) -> auto", state = "ref")]
95#[aggregate("min(date) -> auto", state = "ref")]
96#[aggregate("min(time) -> auto", state = "ref")]
97#[aggregate("min(interval) -> auto", state = "ref")]
98#[aggregate("min(timestamp) -> auto", state = "ref")]
99#[aggregate("min(timestamptz) -> auto", state = "ref")]
100#[aggregate("min(varchar) -> auto", state = "ref")]
101#[aggregate("min(bytea) -> auto", state = "ref")]
102#[aggregate("min(anyarray) -> auto", state = "ref")]
103#[aggregate("min(struct) -> auto", state = "ref")]
104fn min<T: Ord>(state: T, input: T) -> T {
105 state.min(input)
106}
107
108#[aggregate("max(*int) -> auto", state = "ref")]
110#[aggregate("max(*float) -> auto", state = "ref")]
111#[aggregate("max(decimal) -> auto", state = "ref")]
112#[aggregate("max(int256) -> auto", state = "ref")]
113#[aggregate("max(serial) -> auto", state = "ref")]
114#[aggregate("max(date) -> auto", state = "ref")]
115#[aggregate("max(time) -> auto", state = "ref")]
116#[aggregate("max(interval) -> auto", state = "ref")]
117#[aggregate("max(timestamp) -> auto", state = "ref")]
118#[aggregate("max(timestamptz) -> auto", state = "ref")]
119#[aggregate("max(varchar) -> auto", state = "ref")]
120#[aggregate("max(bytea) -> auto", state = "ref")]
121#[aggregate("max(anyarray) -> auto", state = "ref")]
122#[aggregate("max(struct) -> auto", state = "ref")]
123fn max<T: Ord>(state: T, input: T) -> T {
124 state.max(input)
125}
126
127#[aggregate("count(*) -> int8", init_state = "0i64")]
155fn count<T>(state: i64, _: T, retract: bool) -> i64 {
156 if retract { state - 1 } else { state + 1 }
157}
158
159#[aggregate("count() -> int8", init_state = "0i64")]
160fn count_star(state: i64, retract: bool) -> i64 {
161 if retract { state - 1 } else { state + 1 }
162}
163
#[cfg(test)]
mod tests {
    extern crate test;

    use std::sync::Arc;

    use futures_util::FutureExt;
    use risingwave_common::array::*;
    use risingwave_common::test_utils::{rand_bitmap, rand_stream_chunk};
    use risingwave_common::types::{Datum, Decimal};
    use risingwave_expr::aggregate::{AggCall, build_append_only};
    use test::Bencher;

    /// Builds the aggregate described by `pretty`, applies `input` to a fresh
    /// state in a single update, and asserts the final result equals `expected`.
    ///
    /// Both futures are polled once via `now_or_never`. The function panics if
    /// either future is not immediately ready, or if building or updating the
    /// aggregate fails.
    fn test_agg(pretty: &str, input: StreamChunk, expected: Datum) {
        let call = AggCall::from_pretty(pretty);
        let agg = build_append_only(&call).unwrap();
        let mut state = agg.create_state().unwrap();

        let updated = agg.update(&mut state, &input).now_or_never();
        updated.unwrap().unwrap();

        let actual = agg.get_result(&state).now_or_never().unwrap().unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn sum_int4() {
        // Rows tagged `D` are not visible, so only `+ 3` and `- 1` contribute.
        let input = StreamChunk::from_pretty(
            " i
            + 3
            - 1
            - 3 D
            + 1 D",
        );
        let expected: Datum = Some(2i64.into());
        test_agg("(sum:int8 $0:int4)", input, expected);
    }

    #[test]
    fn sum_int8() {
        let input = StreamChunk::from_pretty(
            " I
            + 3
            - 1
            - 3 D
            + 1 D",
        );
        let expected: Datum = Some(Decimal::from(2).into());
        test_agg("(sum:decimal $0:int8)", input, expected);
    }

    #[test]
    fn sum_float8() {
        let cases = [
            (
                " F
                + 1.0
                + 2.0
                + 3.0
                - 4.0",
                2.0f64,
            ),
            (
                " F
                + 1.0
                + inf
                + 3.0
                - 3.0",
                f64::INFINITY,
            ),
            (
                " F
                + 0.0
                - -inf",
                f64::INFINITY,
            ),
            (
                " F
                + 1.0
                + nan
                + 1926.0",
                f64::NAN,
            ),
        ];
        for (pretty, expected) in cases {
            test_agg(
                "(sum:float8 $0:float8)",
                StreamChunk::from_pretty(pretty),
                Some(expected.into()),
            );
        }
    }

    /// `sum` yields `0`, not NULL, once rows have been seen, even if every
    /// insertion was retracted again. Only a chunk with no rows yields NULL.
    #[test]
    fn sum_no_none() {
        const SUM_INT8: &str = "(sum:int8 $0:int8)";

        test_agg(SUM_INT8, StreamChunk::from_pretty("I"), None);

        for pretty in [
            " I
            + 2
            - 1
            + 1
            - 2",
            " I
            - 3 D
            + 1
            - 3 D
            - 1",
        ] {
            test_agg(SUM_INT8, StreamChunk::from_pretty(pretty), Some(0i64.into()));
        }
    }

    #[test]
    fn min_int8() {
        // The invisible `1` and the NULL row are both ignored.
        let input = StreamChunk::from_pretty(
            " I
            + 1 D
            + 10
            + .
            + 5",
        );
        let expected: Datum = Some(5i64.into());
        test_agg("(min:int8 $0:int8)", input, expected);
    }

    #[test]
    fn min_float4() {
        let input = StreamChunk::from_pretty(
            " f
            + 1.0 D
            + 10.0
            + .
            + 5.0",
        );
        let expected: Datum = Some(5.0f32.into());
        test_agg("(min:float4 $0:float4)", input, expected);
    }

    #[test]
    fn min_char() {
        let input = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        let expected: Datum = Some("aa".into());
        test_agg("(min:varchar $0:varchar)", input, expected);
    }

    #[test]
    fn min_list() {
        let input = StreamChunk::from_pretty(
            " i[]
            + {0}
            + {1}
            + {2}",
        );
        let expected: Datum = Some(ListValue::from_iter([0]).into());
        test_agg("(min:int4[] $0:int4[])", input, expected);
    }

    #[test]
    fn max_int8() {
        // The invisible `10` and the NULL row are both ignored.
        let input = StreamChunk::from_pretty(
            " I
            + 1
            + 10 D
            + .
            + 5",
        );
        let expected: Datum = Some(5i64.into());
        test_agg("(max:int8 $0:int8)", input, expected);
    }

    #[test]
    fn max_char() {
        let input = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        let expected: Datum = Some("b".into());
        test_agg("(max:varchar $0:varchar)", input, expected);
    }

    #[test]
    fn count_int4() {
        let cases = [
            (
                " i
                + 1
                + 2
                + 3",
                3i64,
            ),
            (
                " i
                + 1
                + .
                + 3
                - 1",
                1,
            ),
            (
                " i
                - 1 D
                - .
                - 3 D
                - 1 D",
                0,
            ),
            ("i", 0),
            (
                " i
                + .",
                0,
            ),
        ];
        for (pretty, expected) in cases {
            test_agg(
                "(count:int8 $0:int4)",
                StreamChunk::from_pretty(pretty),
                Some(expected.into()),
            );
        }
    }

    /// `count()` starts at `0` for an empty chunk. Unmatched retractions can
    /// drive it negative.
    #[test]
    fn count_star() {
        let cases = [
            ("i", 0i64),
            (
                " i
                + 0",
                1,
            ),
            (
                " i
                + 0
                - 0",
                0,
            ),
            (
                " i
                - 0
                - 0 D
                + 1
                - 1",
                -1,
            ),
        ];
        for (pretty, expected) in cases {
            test_agg("(count:int8)", StreamChunk::from_pretty(pretty), Some(expected.into()));
        }
    }

    #[test]
    fn bitxor_int8() {
        let input = StreamChunk::from_pretty(
            " I
            + 1
            - 10 D
            + .
            - 5",
        );
        let expected: Datum = Some(4i64.into());
        test_agg("(bit_xor:int8 $0:int8)", input, expected);
    }

    /// Benchmarks the aggregate `agg_desc` over a single int8 column.
    ///
    /// A random chunk of `chunk_size` rows is generated once and then folded
    /// repeatedly into one shared state. The value `chunk_size * vis_rate` is
    /// passed as the bitmap's second argument, presumably the number of
    /// visible rows. The constant `666` passed to both generators is presumably
    /// a fixed RNG seed. `append_only` is forwarded to `gen_legal_stream_chunk`.
    fn bench_i64(
        b: &mut Bencher,
        agg_desc: &str,
        chunk_size: usize,
        vis_rate: f64,
        append_only: bool,
    ) {
        println!(
            "benching {} agg, chunk_size={}, vis_rate={}",
            agg_desc, chunk_size, vis_rate
        );
        let visible_rows = (chunk_size as f64 * vis_rate) as usize;
        let visibility = rand_bitmap::gen_rand_bitmap(chunk_size, visible_rows, 666);
        let (ops, column) =
            rand_stream_chunk::gen_legal_stream_chunk(&visibility, chunk_size, append_only, 666);
        let data_chunk = DataChunk::new(vec![Arc::new(column)], visibility);
        let chunk = StreamChunk::from_parts(ops, data_chunk);

        let call = AggCall::from_pretty(format!("({agg_desc}:int8 $0:int8)"));
        let agg = build_append_only(&call).unwrap();
        let mut state = agg.create_state().unwrap();
        b.iter(|| {
            agg.update(&mut state, &chunk)
                .now_or_never()
                .unwrap()
                .unwrap();
        });
    }

    /// Declares one `#[bench]` function per entry. Each function runs
    /// `bench_i64` on a 1024-row chunk with the given aggregate name,
    /// visibility rate, and append-only flag.
    macro_rules! bench_i64_cases {
        ($($name:ident: $agg:literal, $vis_rate:literal, $append_only:literal;)*) => {
            $(
                #[bench]
                fn $name(b: &mut Bencher) {
                    bench_i64(b, $agg, 1024, $vis_rate, $append_only);
                }
            )*
        };
    }

    // Each entry has the form `name: aggregate, vis_rate, append_only;`.
    // `min`/`max` are only benchmarked on append-only input.
    bench_i64_cases! {
        sum_agg_without_vis: "sum", 1.0, false;
        sum_agg_vis_rate_0_75: "sum", 0.75, false;
        sum_agg_vis_rate_0_5: "sum", 0.5, false;
        sum_agg_vis_rate_0_25: "sum", 0.25, false;
        sum_agg_vis_rate_0_05: "sum", 0.05, false;

        count_agg_without_vis: "count", 1.0, false;
        count_agg_vis_rate_0_75: "count", 0.75, false;
        count_agg_vis_rate_0_5: "count", 0.5, false;
        count_agg_vis_rate_0_25: "count", 0.25, false;
        count_agg_vis_rate_0_05: "count", 0.05, false;

        min_agg_without_vis: "min", 1.0, true;
        min_agg_vis_rate_0_75: "min", 0.75, true;
        min_agg_vis_rate_0_5: "min", 0.5, true;
        min_agg_vis_rate_0_25: "min", 0.25, true;
        min_agg_vis_rate_0_05: "min", 0.05, true;

        max_agg_without_vis: "max", 1.0, true;
        max_agg_vis_rate_0_75: "max", 0.75, true;
        max_agg_vis_rate_0_5: "max", 0.5, true;
        max_agg_vis_rate_0_25: "max", 0.25, true;
        max_agg_vis_rate_0_05: "max", 0.05, true;
    }
}