risingwave_sqlsmith/sql_gen/
functions.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use rand::Rng;
17use rand::seq::IndexedRandom;
18use risingwave_common::types::DataType;
19use risingwave_frontend::expr::ExprType;
20use risingwave_sqlparser::ast::{
21    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident,
22    ObjectName, TrimWhereField, UnaryOperator, Value,
23};
24
25use crate::sql_gen::types::{FUNC_TABLE, IMPLICIT_CAST_TABLE, INVARIANT_FUNC_SET};
26use crate::sql_gen::{SqlGenerator, SqlGeneratorContext};
27
28impl<R: Rng> SqlGenerator<'_, R> {
29    pub fn gen_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
30        match self.rng.random_bool(0.1) {
31            true => self.gen_special_func(ret, context),
32            false => self.gen_fixed_func(ret, context),
33        }
34    }
35
36    /// Generates functions with special properties, e.g.
37    /// `CASE`, `COALESCE`, `CONCAT`, `CONCAT_WS`, `OVERLAY`.
38    /// These require custom logic for arguments.
39    /// For instance, `OVERLAY` requires a positive length argument,
40    /// and `CONCAT` and `CONCAT_WS` require variable number of arguments.
41    fn gen_special_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
42        use DataType as T;
43        match ret {
44            T::Varchar => match self.rng.random_range(0..=4) {
45                0 => self.gen_case(ret, context),
46                1 => self.gen_coalesce(ret, context),
47                2 => self.gen_concat(context),
48                3 => self.gen_concat_ws(context),
49                4 => self.gen_overlay(context),
50                _ => unreachable!(),
51            },
52            T::Bytea => self.gen_decode(context),
53            _ => match self.rng.random_bool(0.5) {
54                true => self.gen_case(ret, context),
55                false => self.gen_coalesce(ret, context),
56            },
57            // TODO: gen_regexpr
58            // TODO: gen functions which return list, struct
59        }
60    }
61
62    /// We do custom generation for the `OVERLAY` function call.
63    /// See: [`https://github.com/risingwavelabs/risingwave/issues/10695`] for rationale.
64    fn gen_overlay(&mut self, context: SqlGeneratorContext) -> Expr {
65        let expr = Box::new(self.gen_expr(&DataType::Varchar, context));
66        let new_substring = Box::new(self.gen_expr(&DataType::Varchar, context));
67        let start = Box::new(self.gen_range_scalar(&DataType::Int32, 1, 10).unwrap());
68        let count = if self.flip_coin() {
69            None
70        } else {
71            Some(Box::new(
72                self.gen_range_scalar(&DataType::Int32, 1, 10).unwrap(),
73            ))
74        };
75        Expr::Overlay {
76            expr,
77            new_substring,
78            start,
79            count,
80        }
81    }
82
83    fn gen_case(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
84        let n = self.rng.random_range(1..4);
85        Expr::Case {
86            operand: None,
87            conditions: self.gen_n_exprs_with_type(n, &DataType::Boolean, context),
88            results: self.gen_n_exprs_with_type(n, ret, context),
89            else_result: Some(Box::new(self.gen_expr(ret, context))),
90        }
91    }
92
93    fn gen_coalesce(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
94        let non_null = self.gen_expr(ret, context);
95        let position = self.rng.random_range(0..10);
96        let mut args = (0..10).map(|_| Expr::Value(Value::Null)).collect_vec();
97        args[position] = non_null;
98        Expr::Function(make_simple_func("coalesce", &args))
99    }
100
101    fn gen_concat(&mut self, context: SqlGeneratorContext) -> Expr {
102        Expr::Function(make_simple_func("concat", &self.gen_concat_args(context)))
103    }
104
105    fn gen_concat_ws(&mut self, context: SqlGeneratorContext) -> Expr {
106        let sep = self.gen_expr(&DataType::Varchar, context);
107        let mut args = self.gen_concat_args(context);
108        args.insert(0, sep);
109        Expr::Function(make_simple_func("concat_ws", &args))
110    }
111
112    fn gen_concat_args(&mut self, context: SqlGeneratorContext) -> Vec<Expr> {
113        let n = self.rng.random_range(1..4);
114        (0..n)
115            .map(|_| {
116                if self.rng.random_bool(0.1) {
117                    self.gen_explicit_cast(&DataType::Varchar, context)
118                } else {
119                    self.gen_expr(&DataType::Varchar, context)
120                }
121            })
122            .collect()
123    }
124
125    fn gen_decode(&mut self, context: SqlGeneratorContext) -> Expr {
126        let input_string = self.gen_expr(&DataType::Bytea, context);
127        let encoding = &["base64", "hex", "escape"].choose(&mut self.rng).unwrap();
128        let args = vec![
129            input_string,
130            Expr::Value(Value::SingleQuotedString(encoding.to_string())),
131        ];
132        let encoded_string = Expr::Function(make_simple_func("encode", &args));
133        let args = vec![
134            encoded_string,
135            Expr::Value(Value::SingleQuotedString(encoding.to_string())),
136        ];
137        Expr::Function(make_simple_func("decode", &args))
138    }
139
140    fn gen_fixed_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
141        let funcs = match FUNC_TABLE.get(ret) {
142            None => return self.gen_simple_scalar(ret),
143            Some(funcs) => funcs,
144        };
145        let func = funcs.choose(&mut self.rng).unwrap();
146        let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.name.as_scalar());
147        let exprs: Vec<Expr> = func
148            .inputs_type
149            .iter()
150            .map(|t| {
151                if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t.as_exact())
152                    && can_implicit_cast
153                    && self.flip_coin()
154                {
155                    let from_ty = &from_tys.choose(&mut self.rng).unwrap().from_type;
156                    self.gen_implicit_cast(from_ty, context)
157                } else {
158                    self.gen_expr(t.as_exact(), context)
159                }
160            })
161            .collect();
162        let expr = if exprs.len() == 1 {
163            make_unary_op(func.name.as_scalar(), &exprs[0])
164        } else if exprs.len() == 2 {
165            make_bin_op(func.name.as_scalar(), &exprs)
166        } else {
167            None
168        };
169        expr.or_else(|| make_general_expr(func.name.as_scalar(), exprs))
170            .unwrap_or_else(|| self.gen_simple_scalar(ret))
171    }
172}
173
174fn make_unary_op(func: ExprType, expr: &Expr) -> Option<Expr> {
175    use {ExprType as E, UnaryOperator as U};
176    let unary_op = match func {
177        E::Neg => U::Minus,
178        E::Not => U::Not,
179        E::BitwiseNot => U::PGBitwiseNot,
180        _ => return None,
181    };
182    Some(Expr::UnaryOp {
183        op: unary_op,
184        expr: Box::new(expr.clone()),
185    })
186}
187
188/// General expressions do not fall under unary / binary op, so they are constructed differently.
189fn make_general_expr(func: ExprType, exprs: Vec<Expr>) -> Option<Expr> {
190    use ExprType as E;
191
192    match func {
193        E::Trim | E::Ltrim | E::Rtrim => Some(make_trim(func, exprs)),
194        E::IsNull => Some(Expr::IsNull(Box::new(exprs[0].clone()))),
195        E::IsNotNull => Some(Expr::IsNotNull(Box::new(exprs[0].clone()))),
196        E::IsTrue => Some(Expr::IsTrue(Box::new(exprs[0].clone()))),
197        E::IsNotTrue => Some(Expr::IsNotTrue(Box::new(exprs[0].clone()))),
198        E::IsFalse => Some(Expr::IsFalse(Box::new(exprs[0].clone()))),
199        E::IsNotFalse => Some(Expr::IsNotFalse(Box::new(exprs[0].clone()))),
200        E::Position => Some(Expr::Function(make_simple_func("strpos", &exprs))),
201        E::RoundDigit => Some(Expr::Function(make_simple_func("round", &exprs))),
202        E::Pow => Some(Expr::Function(make_simple_func("pow", &exprs))),
203        E::Repeat => Some(Expr::Function(make_simple_func("repeat", &exprs))),
204        E::CharLength => Some(Expr::Function(make_simple_func("char_length", &exprs))),
205        E::Substr => Some(Expr::Function(make_simple_func("substr", &exprs))),
206        E::Length => Some(Expr::Function(make_simple_func("length", &exprs))),
207        E::Upper => Some(Expr::Function(make_simple_func("upper", &exprs))),
208        E::Lower => Some(Expr::Function(make_simple_func("lower", &exprs))),
209        E::Replace => Some(Expr::Function(make_simple_func("replace", &exprs))),
210        E::Md5 => Some(Expr::Function(make_simple_func("md5", &exprs))),
211        E::ToChar => Some(Expr::Function(make_simple_func("to_char", &exprs))),
212        E::SplitPart => Some(Expr::Function(make_simple_func("split_part", &exprs))),
213        E::Encode => Some(Expr::Function(make_simple_func("encode", &exprs))),
214        E::Decode => Some(Expr::Function(make_simple_func("decode", &exprs))),
215        E::Sha1 => Some(Expr::Function(make_simple_func("sha1", &exprs))),
216        E::Sha224 => Some(Expr::Function(make_simple_func("sha224", &exprs))),
217        E::Sha256 => Some(Expr::Function(make_simple_func("sha256", &exprs))),
218        E::Sha384 => Some(Expr::Function(make_simple_func("sha384", &exprs))),
219        E::Sha512 => Some(Expr::Function(make_simple_func("sha512", &exprs))),
220        // TODO: Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112
221        // E::Translate => Some(Expr::Function(make_simple_func("translate", &exprs))),
222        // NOTE(kwannoel): I disabled `Overlay`, its arguments require special handling.
223        // We generate it in `gen_special_func` instead.
224        // E::Overlay => Some(make_overlay(exprs)),
225        _ => None,
226    }
227}
228
229fn make_trim(func: ExprType, exprs: Vec<Expr>) -> Expr {
230    use ExprType as E;
231
232    let trim_type = match func {
233        E::Trim => TrimWhereField::Both,
234        E::Ltrim => TrimWhereField::Leading,
235        E::Rtrim => TrimWhereField::Trailing,
236        _ => unreachable!(),
237    };
238    let trim_what = if exprs.len() > 1 {
239        Some(Box::new(exprs[1].clone()))
240    } else {
241        None
242    };
243    Expr::Trim {
244        expr: Box::new(exprs[0].clone()),
245        trim_where: Some(trim_type),
246        trim_what,
247    }
248}
249
250/// Generates simple functions such as `length`, `round`, `to_char`. These operate on datums instead
251/// of columns / rows.
252pub fn make_simple_func(func_name: &str, exprs: &[Expr]) -> Function {
253    let args = exprs
254        .iter()
255        .map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e.clone())))
256        .collect();
257
258    Function {
259        scalar_as_agg: false,
260        name: ObjectName(vec![Ident::new_unchecked(func_name)]),
261        arg_list: FunctionArgList::args_only(args),
262        over: None,
263        filter: None,
264        within_group: None,
265    }
266}
267
268fn make_bin_op(func: ExprType, exprs: &[Expr]) -> Option<Expr> {
269    use {BinaryOperator as B, ExprType as E};
270    let bin_op = match func {
271        E::Add => B::Plus,
272        E::Subtract => B::Minus,
273        E::Multiply => B::Multiply,
274        E::Divide => B::Divide,
275        E::Modulus => B::Modulo,
276        E::GreaterThan => B::Gt,
277        E::GreaterThanOrEqual => B::GtEq,
278        E::LessThan => B::Lt,
279        E::LessThanOrEqual => B::LtEq,
280        E::Equal => B::Eq,
281        E::NotEqual => B::NotEq,
282        E::And => B::And,
283        E::Or => B::Or,
284        E::Like => B::PGLikeMatch,
285        E::BitwiseAnd => B::BitwiseAnd,
286        E::BitwiseOr => B::BitwiseOr,
287        E::BitwiseXor => B::PGBitwiseXor,
288        E::BitwiseShiftLeft => B::PGBitwiseShiftLeft,
289        E::BitwiseShiftRight => B::PGBitwiseShiftRight,
290        _ => return None,
291    };
292    Some(Expr::BinaryOp {
293        left: Box::new(exprs[0].clone()),
294        op: bin_op,
295        right: Box::new(exprs[1].clone()),
296    })
297}