risingwave_sqlsmith/sql_gen/
agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rand::Rng;
16use rand::seq::IndexedRandom;
17use risingwave_common::types::DataType;
18use risingwave_expr::aggregate::PbAggKind;
19use risingwave_expr::sig::SigDataType;
20use risingwave_sqlparser::ast::{
21    Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName, OrderByExpr,
22};
23
24use crate::sql_gen::types::AGG_FUNC_TABLE;
25use crate::sql_gen::{SqlGenerator, SqlGeneratorContext};
26
27impl<R: Rng> SqlGenerator<'_, R> {
28    pub fn gen_agg(&mut self, ret: &DataType) -> Expr {
29        let funcs = match AGG_FUNC_TABLE.get(ret) {
30            None => return self.gen_simple_scalar(ret),
31            Some(funcs) => funcs,
32        };
33        let func = funcs.choose(&mut self.rng).unwrap();
34        if matches!(func.name.as_aggregate(), PbAggKind::Min | PbAggKind::Max)
35            && matches!(
36                func.ret_type,
37                SigDataType::Exact(DataType::Boolean | DataType::Jsonb)
38            )
39        {
40            return self.gen_simple_scalar(ret);
41        }
42
43        let context = SqlGeneratorContext::new();
44        let context = context.set_inside_agg();
45        let exprs: Vec<Expr> = func
46            .inputs_type
47            .iter()
48            .map(|t| self.gen_expr(t.as_exact(), context))
49            .collect();
50
51        // DISTINCT now only works with agg kinds except `ApproxCountDistinct`, and with at least
52        // one argument and only the first being non-constant. See `Binder::bind_normal_agg`
53        // for more details.
54        let distinct_allowed = func.name.as_aggregate() != PbAggKind::ApproxCountDistinct
55            && !exprs.is_empty()
56            && exprs.iter().skip(1).all(|e| matches!(e, Expr::Value(_)));
57        let distinct = distinct_allowed && self.flip_coin();
58
59        let filter = if self.flip_coin() {
60            let context = SqlGeneratorContext::new_with_can_agg(false);
61            // ENABLE: https://github.com/risingwavelabs/risingwave/issues/4762
62            // Prevent correlated query with `FILTER`
63            let old_ctxt = self.new_local_context();
64            let expr = Some(Box::new(self.gen_expr(&DataType::Boolean, context)));
65            self.restore_context(old_ctxt);
66            expr
67        } else {
68            None
69        };
70
71        let order_by = if self.flip_coin() {
72            if distinct {
73                // can only generate order by clause with exprs in argument list, see
74                // `Binder::bind_normal_agg`
75                self.gen_order_by_within(&exprs)
76            } else {
77                self.gen_order_by()
78            }
79        } else {
80            vec![]
81        };
82        self.make_agg_expr(func.name.as_aggregate(), &exprs, distinct, filter, order_by)
83            .unwrap_or_else(|| self.gen_simple_scalar(ret))
84    }
85
86    /// Generates aggregate expressions. For internal / unsupported aggregators, we return `None`.
87    fn make_agg_expr(
88        &mut self,
89        func: PbAggKind,
90        exprs: &[Expr],
91        distinct: bool,
92        filter: Option<Box<Expr>>,
93        order_by: Vec<OrderByExpr>,
94    ) -> Option<Expr> {
95        use PbAggKind as A;
96        match func {
97            kind @ (A::FirstValue | A::LastValue) => {
98                if order_by.is_empty() {
99                    // `first/last_value` only works when ORDER BY is provided
100                    None
101                } else {
102                    Some(Expr::Function(make_agg_func(
103                        &kind.as_str_name().to_lowercase(),
104                        exprs,
105                        distinct,
106                        filter,
107                        order_by,
108                    )))
109                }
110            }
111            other => Some(Expr::Function(make_agg_func(
112                &other.as_str_name().to_lowercase(),
113                exprs,
114                distinct,
115                filter,
116                order_by,
117            ))),
118        }
119    }
120}
121
122/// This is the function that generate aggregate function.
123/// DISTINCT, ORDER BY or FILTER is allowed in aggregation functions。
124fn make_agg_func(
125    func_name: &str,
126    exprs: &[Expr],
127    distinct: bool,
128    filter: Option<Box<Expr>>,
129    order_by: Vec<OrderByExpr>,
130) -> Function {
131    let args = if exprs.is_empty() {
132        // The only agg without args is `count`.
133        // `select proname from pg_proc where array_length(proargtypes, 1) = 0 and prokind = 'a';`
134        vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))]
135    } else {
136        exprs
137            .iter()
138            .map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e.clone())))
139            .collect()
140    };
141
142    Function {
143        scalar_as_agg: false,
144        name: ObjectName(vec![Ident::new_unchecked(func_name)]),
145        arg_list: FunctionArgList::for_agg(distinct, args, order_by),
146        over: None,
147        filter,
148        within_group: None,
149    }
150}