risingwave_sqlsmith/sql_gen/
types.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module contains datatypes and functions which can be generated by sqlsmith.
16
17use std::collections::{HashMap, HashSet};
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use risingwave_common::types::{DataType, DataTypeName};
22use risingwave_expr::aggregate::PbAggKind;
23use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncSign};
24use risingwave_frontend::expr::{CastContext, CastSig as RwCastSig, ExprType, cast_sigs};
25use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType};
26
27pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType {
28    use risingwave_frontend::DataTypeToAst as _;
29    data_type.to_ast()
30}
31
32fn data_type_name_to_ast_data_type(data_type_name: &DataTypeName) -> Option<DataType> {
33    use DataTypeName as T;
34    match data_type_name {
35        T::Boolean => Some(DataType::Boolean),
36        T::Int16 => Some(DataType::Int16),
37        T::Int32 => Some(DataType::Int32),
38        T::Int64 => Some(DataType::Int64),
39        T::Decimal => Some(DataType::Decimal),
40        T::Float32 => Some(DataType::Float32),
41        T::Float64 => Some(DataType::Float64),
42        T::Varchar => Some(DataType::Varchar),
43        T::Date => Some(DataType::Date),
44        T::Timestamp => Some(DataType::Timestamp),
45        T::Timestamptz => Some(DataType::Timestamptz),
46        T::Time => Some(DataType::Time),
47        T::Interval => Some(DataType::Interval),
48        _ => None,
49    }
50}
51
/// Provide internal `CastSig` which can be used for `struct` and `list`.
///
/// Unlike the frontend's cast signature (imported here as `RwCastSig`), whose
/// endpoints are `DataTypeName`s, this one stores full [`DataType`] values.
#[derive(Clone)]
pub struct CastSig {
    /// Type the cast converts from.
    pub from_type: DataType,
    /// Type the cast converts to.
    pub to_type: DataType,
    /// Context of the cast, copied from the frontend signature; the cast tables in
    /// this module select on `CastContext::Explicit` and `CastContext::Implicit`.
    pub context: CastContext,
}
59
60impl TryFrom<RwCastSig> for CastSig {
61    type Error = String;
62
63    fn try_from(value: RwCastSig) -> Result<Self, Self::Error> {
64        if let Some(from_type) = data_type_name_to_ast_data_type(&value.from_type)
65            && let Some(to_type) = data_type_name_to_ast_data_type(&value.to_type)
66        {
67            Ok(CastSig {
68                from_type,
69                to_type,
70                context: value.context,
71            })
72        } else {
73            Err(format!("unsupported cast sig: {:?}", value))
74        }
75    }
76}
77
78/// Function ban list.
79/// These functions should be generated eventually, by adding expression constraints.
80/// If we naively generate arguments for these functions, it will affect sqlsmith
81/// effectiveness, e.g. cause it to crash.
82static FUNC_BAN_LIST: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
83    [
84        // FIXME: https://github.com/risingwavelabs/risingwave/issues/8003
85        ExprType::Repeat,
86        // The format argument needs to be handled specially. It is still generated in `gen_special_func`.
87        ExprType::Decode,
88        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
89        ExprType::Sqrt,
90        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
91        ExprType::Pow,
92        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/7328
93        ExprType::Position,
94        // ENABLE: https://github.com/risingwavelabs/risingwave/issues/7328
95        #[expect(deprecated)]
96        ExprType::Strpos,
97    ]
98    .into_iter()
99    .collect()
100});
101
102/// Table which maps functions' return types to possible function signatures.
103// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
104// TODO: Create a `SPECIAL_FUNC` table.
105// Otherwise when we dump the function table, we won't include those functions in
106// gen_special_func.
107pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
108    LazyLock::new(|| {
109        let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
110        FUNCTION_REGISTRY
111            .iter_scalars()
112            .filter(|func| {
113                func.inputs_type.iter().all(|t| {
114                    t.is_exact()
115                        && t.as_exact() != &DataType::Timestamptz
116                        && t.as_exact() != &DataType::Serial
117                }) && func.ret_type.is_exact()
118                    && !FUNC_BAN_LIST.contains(&func.name.as_scalar())
119                    && !func.deprecated // deprecated functions are not accepted by frontend
120            })
121            .for_each(|func| {
122                funcs
123                    .entry(func.ret_type.as_exact().clone())
124                    .or_default()
125                    .push(func)
126            });
127        funcs
128    });
129
130/// Set of invariant functions
131// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
132pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
133    FUNCTION_REGISTRY
134        .iter_scalars()
135        .map(|sig| sig.name.as_scalar())
136        .counts()
137        .into_iter()
138        .filter(|(_key, count)| *count == 1)
139        .map(|(key, _)| key)
140        .collect()
141});
142
143/// Table which maps aggregate functions' return types to possible function signatures.
144// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
145pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
146    LazyLock::new(|| {
147        let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
148        FUNCTION_REGISTRY
149            .iter_aggregates()
150            .filter(|func| {
151                func.inputs_type
152                    .iter()
153                    .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz && t.as_exact() != &DataType::Serial)
154                    && func.ret_type.is_exact()
155                    // Ignored functions
156                    && ![
157                        PbAggKind::InternalLastSeenValue, // Use internally
158                        PbAggKind::Sum0, // Used internally
159                        PbAggKind::ApproxCountDistinct,
160                        PbAggKind::BitAnd,
161                        PbAggKind::BitOr,
162                        PbAggKind::BoolAnd,
163                        PbAggKind::BoolOr,
164                        PbAggKind::PercentileCont,
165                        PbAggKind::PercentileDisc,
166                        PbAggKind::Mode,
167                        PbAggKind::ApproxPercentile, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
168                        PbAggKind::JsonbObjectAgg, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
169                        PbAggKind::StddevSamp, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
170                        PbAggKind::VarSamp, // ENABLE: https://github.com/risingwavelabs/risingwave/issues/16293
171                    ]
172                    .contains(&func.name.as_aggregate())
173                    // Exclude 2 phase agg global sum.
174                    // Sum(Int64) -> Int64.
175                    // Otherwise it conflicts with normal aggregation:
176                    // Sum(Int64) -> Decimal.
177                    // And sqlsmith will generate expressions with wrong types.
178                    && if func.name.as_aggregate() == PbAggKind::Sum {
179                       !(func.inputs_type[0].as_exact() == &DataType::Int64 && func.ret_type.as_exact() == &DataType::Int64)
180                    } else {
181                       true
182                    }
183            })
184            .for_each(|func| {
185                funcs.entry(func.ret_type.as_exact().clone()).or_default().push(func)
186            });
187        funcs
188    });
189
190/// Build a cast map from return types to viable cast-signatures.
191/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
192/// This is because arbitrary strings may not be able to cast,
193/// creating large number of invalid queries.
194pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
195    LazyLock::new(|| {
196        let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
197        cast_sigs()
198            .filter_map(|cast| cast.try_into().ok())
199            .filter(|cast: &CastSig| cast.context == CastContext::Explicit)
200            .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
201            .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
202        casts
203    });
204
205/// Build a cast map from return types to viable cast-signatures.
206/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
207/// This is because arbitrary strings may not be able to cast,
208/// creating large number of invalid queries.
209pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
210    LazyLock::new(|| {
211        let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
212        cast_sigs()
213            .filter_map(|cast| cast.try_into().ok())
214            .filter(|cast: &CastSig| cast.context == CastContext::Implicit)
215            .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
216            .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
217        casts
218    });
219
220fn expr_type_to_inequality_op(typ: ExprType) -> Option<BinaryOperator> {
221    match typ {
222        ExprType::GreaterThan => Some(BinaryOperator::Gt),
223        ExprType::GreaterThanOrEqual => Some(BinaryOperator::GtEq),
224        ExprType::LessThan => Some(BinaryOperator::Lt),
225        ExprType::LessThanOrEqual => Some(BinaryOperator::LtEq),
226        ExprType::NotEqual => Some(BinaryOperator::NotEq),
227        _ => None,
228    }
229}
230
231/// Build set of binary inequality functions like `>`, `<`, etc...
232/// Maps from LHS, RHS argument to Inequality Operation
233/// For instance:
234/// GreaterThanOrEqual(Int16, Int64) -> Boolean
235/// Will store an entry of:
236/// Key: Int16, Int64
237/// Value: `BinaryOp::GreaterThanOrEqual`
238/// in the table.
239pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock<
240    HashMap<(DataType, DataType), Vec<BinaryOperator>>,
241> = LazyLock::new(|| {
242    let mut funcs = HashMap::<(DataType, DataType), Vec<BinaryOperator>>::new();
243    FUNCTION_REGISTRY
244        .iter_scalars()
245        .filter(|func| {
246            !FUNC_BAN_LIST.contains(&func.name.as_scalar())
247                && func.ret_type == DataType::Boolean.into()
248                && func.inputs_type.len() == 2
249                && func
250                    .inputs_type
251                    .iter()
252                    .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
253        })
254        .filter_map(|func| {
255            let lhs = func.inputs_type[0].as_exact().clone();
256            let rhs = func.inputs_type[1].as_exact().clone();
257            let op = expr_type_to_inequality_op(func.name.as_scalar())?;
258            Some(((lhs, rhs), op))
259        })
260        .for_each(|(args, op)| funcs.entry(args).or_default().push(op));
261    funcs
262});