risingwave_sqlsmith/sql_gen/
types.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module contains datatypes and functions which can be generated by sqlsmith.
16
17use std::collections::{HashMap, HashSet};
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use risingwave_common::types::{DataType, DataTypeName};
22use risingwave_expr::aggregate::PbAggKind;
23use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncSign};
24use risingwave_frontend::expr::{CastContext, CastSig as RwCastSig, ExprType, cast_sigs};
25use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField};
26
27pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType {
28    match data_type {
29        DataType::Boolean => AstDataType::Boolean,
30        DataType::Int16 => AstDataType::SmallInt,
31        DataType::Int32 => AstDataType::Int,
32        DataType::Int64 => AstDataType::BigInt,
33        DataType::Int256 => AstDataType::Custom(vec!["rw_int256".into()].into()),
34        DataType::Serial => unreachable!("serial should not be generated"),
35        DataType::Decimal => AstDataType::Decimal(None, None),
36        DataType::Float32 => AstDataType::Real,
37        DataType::Float64 => AstDataType::Double,
38        DataType::Varchar => AstDataType::Varchar,
39        DataType::Bytea => AstDataType::Bytea,
40        DataType::Date => AstDataType::Date,
41        DataType::Timestamp => AstDataType::Timestamp(false),
42        DataType::Timestamptz => AstDataType::Timestamp(true),
43        DataType::Time => AstDataType::Time(false),
44        DataType::Interval => AstDataType::Interval,
45        DataType::Jsonb => AstDataType::Custom(vec!["JSONB".into()].into()),
46        DataType::Struct(inner) => AstDataType::Struct(
47            inner
48                .iter()
49                .map(|(name, typ)| StructField {
50                    name: name.into(),
51                    data_type: data_type_to_ast_data_type(typ),
52                })
53                .collect(),
54        ),
55        DataType::List(typ) => AstDataType::Array(Box::new(data_type_to_ast_data_type(typ))),
56        DataType::Vector(n) => AstDataType::Vector(*n as _),
57        DataType::Map(_) => todo!(),
58    }
59}
60
61fn data_type_name_to_ast_data_type(data_type_name: &DataTypeName) -> Option<DataType> {
62    use DataTypeName as T;
63    match data_type_name {
64        T::Boolean => Some(DataType::Boolean),
65        T::Int16 => Some(DataType::Int16),
66        T::Int32 => Some(DataType::Int32),
67        T::Int64 => Some(DataType::Int64),
68        T::Decimal => Some(DataType::Decimal),
69        T::Float32 => Some(DataType::Float32),
70        T::Float64 => Some(DataType::Float64),
71        T::Varchar => Some(DataType::Varchar),
72        T::Date => Some(DataType::Date),
73        T::Timestamp => Some(DataType::Timestamp),
74        T::Timestamptz => Some(DataType::Timestamptz),
75        T::Time => Some(DataType::Time),
76        T::Interval => Some(DataType::Interval),
77        _ => None,
78    }
79}
80
81/// Provide internal `CastSig` which can be used for `struct` and `list`.
82#[derive(Clone)]
83pub struct CastSig {
84    pub from_type: DataType,
85    pub to_type: DataType,
86    pub context: CastContext,
87}
88
89impl TryFrom<RwCastSig> for CastSig {
90    type Error = String;
91
92    fn try_from(value: RwCastSig) -> Result<Self, Self::Error> {
93        if let Some(from_type) = data_type_name_to_ast_data_type(&value.from_type)
94            && let Some(to_type) = data_type_name_to_ast_data_type(&value.to_type)
95        {
96            Ok(CastSig {
97                from_type,
98                to_type,
99                context: value.context,
100            })
101        } else {
102            Err(format!("unsupported cast sig: {:?}", value))
103        }
104    }
105}
106
107/// Function ban list.
108/// These functions should be generated eventually, by adding expression constraints.
109/// If we naively generate arguments for these functions, it will affect sqlsmith
110/// effectiveness, e.g. cause it to crash.
111static FUNC_BAN_LIST: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
112    [
113        // FIXME: https://github.com/risingwavelabs/risingwave/issues/8003
114        ExprType::Repeat,
115        // The format argument needs to be handled specially. It is still generated in `gen_special_func`.
116        ExprType::Decode,
117        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
118        ExprType::Sqrt,
119        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
120        ExprType::Pow,
121        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/7328
122        ExprType::Position,
123        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/7328
124        ExprType::Strpos,
125    ]
126    .into_iter()
127    .collect()
128});
129
130/// Table which maps functions' return types to possible function signatures.
131// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
132// TODO: Create a `SPECIAL_FUNC` table.
133// Otherwise when we dump the function table, we won't include those functions in
134// gen_special_func.
135pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
136    LazyLock::new(|| {
137        let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
138        FUNCTION_REGISTRY
139            .iter_scalars()
140            .filter(|func| {
141                func.inputs_type.iter().all(|t| {
142                    t.is_exact()
143                        && t.as_exact() != &DataType::Timestamptz
144                        && t.as_exact() != &DataType::Serial
145                }) && func.ret_type.is_exact()
146                    && !FUNC_BAN_LIST.contains(&func.name.as_scalar())
147                    && !func.deprecated // deprecated functions are not accepted by frontend
148            })
149            .for_each(|func| {
150                funcs
151                    .entry(func.ret_type.as_exact().clone())
152                    .or_default()
153                    .push(func)
154            });
155        funcs
156    });
157
158/// Set of invariant functions
159// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
160pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
161    FUNCTION_REGISTRY
162        .iter_scalars()
163        .map(|sig| sig.name.as_scalar())
164        .counts()
165        .into_iter()
166        .filter(|(_key, count)| *count == 1)
167        .map(|(key, _)| key)
168        .collect()
169});
170
171/// Table which maps aggregate functions' return types to possible function signatures.
172// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
173pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
174    LazyLock::new(|| {
175        let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
176        FUNCTION_REGISTRY
177            .iter_aggregates()
178            .filter(|func| {
179                func.inputs_type
180                    .iter()
181                    .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz && t.as_exact() != &DataType::Serial)
182                    && func.ret_type.is_exact()
183                    // Ignored functions
184                    && ![
185                        PbAggKind::InternalLastSeenValue, // Use internally
186                        PbAggKind::Sum0, // Used internally
187                        PbAggKind::ApproxCountDistinct,
188                        PbAggKind::BitAnd,
189                        PbAggKind::BitOr,
190                        PbAggKind::BoolAnd,
191                        PbAggKind::BoolOr,
192                        PbAggKind::PercentileCont,
193                        PbAggKind::PercentileDisc,
194                        PbAggKind::Mode,
195                        PbAggKind::ApproxPercentile, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
196                        PbAggKind::JsonbObjectAgg, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
197                        PbAggKind::StddevSamp, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
198                        PbAggKind::VarSamp, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
199                    ]
200                    .contains(&func.name.as_aggregate())
201                    // Exclude 2 phase agg global sum.
202                    // Sum(Int64) -> Int64.
203                    // Otherwise it conflicts with normal aggregation:
204                    // Sum(Int64) -> Decimal.
205                    // And sqlsmith will generate expressions with wrong types.
206                    && if func.name.as_aggregate() == PbAggKind::Sum {
207                       !(func.inputs_type[0].as_exact() == &DataType::Int64 && func.ret_type.as_exact() == &DataType::Int64)
208                    } else {
209                       true
210                    }
211            })
212            .for_each(|func| {
213                funcs.entry(func.ret_type.as_exact().clone()).or_default().push(func)
214            });
215        funcs
216    });
217
218/// Build a cast map from return types to viable cast-signatures.
219/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
220/// This is because arbitrary strings may not be able to cast,
221/// creating large number of invalid queries.
222pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
223    LazyLock::new(|| {
224        let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
225        cast_sigs()
226            .filter_map(|cast| cast.try_into().ok())
227            .filter(|cast: &CastSig| cast.context == CastContext::Explicit)
228            .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
229            .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
230        casts
231    });
232
233/// Build a cast map from return types to viable cast-signatures.
234/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
235/// This is because arbitrary strings may not be able to cast,
236/// creating large number of invalid queries.
237pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
238    LazyLock::new(|| {
239        let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
240        cast_sigs()
241            .filter_map(|cast| cast.try_into().ok())
242            .filter(|cast: &CastSig| cast.context == CastContext::Implicit)
243            .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
244            .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
245        casts
246    });
247
248fn expr_type_to_inequality_op(typ: ExprType) -> Option<BinaryOperator> {
249    match typ {
250        ExprType::GreaterThan => Some(BinaryOperator::Gt),
251        ExprType::GreaterThanOrEqual => Some(BinaryOperator::GtEq),
252        ExprType::LessThan => Some(BinaryOperator::Lt),
253        ExprType::LessThanOrEqual => Some(BinaryOperator::LtEq),
254        ExprType::NotEqual => Some(BinaryOperator::NotEq),
255        _ => None,
256    }
257}
258
259/// Build set of binary inequality functions like `>`, `<`, etc...
260/// Maps from LHS, RHS argument to Inequality Operation
261/// For instance:
262/// GreaterThanOrEqual(Int16, Int64) -> Boolean
263/// Will store an entry of:
264/// Key: Int16, Int64
265/// Value: `BinaryOp::GreaterThanOrEqual`
266/// in the table.
267pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock<
268    HashMap<(DataType, DataType), Vec<BinaryOperator>>,
269> = LazyLock::new(|| {
270    let mut funcs = HashMap::<(DataType, DataType), Vec<BinaryOperator>>::new();
271    FUNCTION_REGISTRY
272        .iter_scalars()
273        .filter(|func| {
274            !FUNC_BAN_LIST.contains(&func.name.as_scalar())
275                && func.ret_type == DataType::Boolean.into()
276                && func.inputs_type.len() == 2
277                && func
278                    .inputs_type
279                    .iter()
280                    .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
281        })
282        .filter_map(|func| {
283            let lhs = func.inputs_type[0].as_exact().clone();
284            let rhs = func.inputs_type[1].as_exact().clone();
285            let op = expr_type_to_inequality_op(func.name.as_scalar())?;
286            Some(((lhs, rhs), op))
287        })
288        .for_each(|(args, op)| funcs.entry(args).or_default().push(op));
289    funcs
290});