risingwave_sqlsmith/sql_gen/
types.rs1use std::collections::{HashMap, HashSet};
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use risingwave_common::types::{DataType, DataTypeName};
22use risingwave_expr::aggregate::PbAggKind;
23use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncSign};
24use risingwave_frontend::expr::{CastContext, CastSig as RwCastSig, ExprType, cast_sigs};
25use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType};
26
27pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType {
28 use risingwave_frontend::DataTypeToAst as _;
29 data_type.to_ast()
30}
31
32fn data_type_name_to_ast_data_type(data_type_name: &DataTypeName) -> Option<DataType> {
33 use DataTypeName as T;
34 match data_type_name {
35 T::Boolean => Some(DataType::Boolean),
36 T::Int16 => Some(DataType::Int16),
37 T::Int32 => Some(DataType::Int32),
38 T::Int64 => Some(DataType::Int64),
39 T::Decimal => Some(DataType::Decimal),
40 T::Float32 => Some(DataType::Float32),
41 T::Float64 => Some(DataType::Float64),
42 T::Varchar => Some(DataType::Varchar),
43 T::Date => Some(DataType::Date),
44 T::Timestamp => Some(DataType::Timestamp),
45 T::Timestamptz => Some(DataType::Timestamptz),
46 T::Time => Some(DataType::Time),
47 T::Interval => Some(DataType::Interval),
48 _ => None,
49 }
50}
51
52#[derive(Clone)]
54pub struct CastSig {
55 pub from_type: DataType,
56 pub to_type: DataType,
57 pub context: CastContext,
58}
59
60impl TryFrom<RwCastSig> for CastSig {
61 type Error = String;
62
63 fn try_from(value: RwCastSig) -> Result<Self, Self::Error> {
64 if let Some(from_type) = data_type_name_to_ast_data_type(&value.from_type)
65 && let Some(to_type) = data_type_name_to_ast_data_type(&value.to_type)
66 {
67 Ok(CastSig {
68 from_type,
69 to_type,
70 context: value.context,
71 })
72 } else {
73 Err(format!("unsupported cast sig: {:?}", value))
74 }
75 }
76}
77
78static FUNC_BAN_LIST: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
83 [
84 ExprType::Repeat,
86 ExprType::Decode,
88 ExprType::Sqrt,
90 ExprType::Pow,
92 ExprType::Position,
94 #[expect(deprecated)]
96 ExprType::Strpos,
97 ]
98 .into_iter()
99 .collect()
100});
101
102pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
108 LazyLock::new(|| {
109 let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
110 FUNCTION_REGISTRY
111 .iter_scalars()
112 .filter(|func| {
113 func.inputs_type.iter().all(|t| {
114 t.is_exact()
115 && t.as_exact() != &DataType::Timestamptz
116 && t.as_exact() != &DataType::Serial
117 }) && func.ret_type.is_exact()
118 && !FUNC_BAN_LIST.contains(&func.name.as_scalar())
119 && !func.deprecated })
121 .for_each(|func| {
122 funcs
123 .entry(func.ret_type.as_exact().clone())
124 .or_default()
125 .push(func)
126 });
127 funcs
128 });
129
130pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
133 FUNCTION_REGISTRY
134 .iter_scalars()
135 .map(|sig| sig.name.as_scalar())
136 .counts()
137 .into_iter()
138 .filter(|(_key, count)| *count == 1)
139 .map(|(key, _)| key)
140 .collect()
141});
142
143pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
146 LazyLock::new(|| {
147 let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
148 FUNCTION_REGISTRY
149 .iter_aggregates()
150 .filter(|func| {
151 func.inputs_type
152 .iter()
153 .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz && t.as_exact() != &DataType::Serial)
154 && func.ret_type.is_exact()
155 && ![
157 PbAggKind::InternalLastSeenValue, PbAggKind::Sum0, PbAggKind::ApproxCountDistinct,
160 PbAggKind::BitAnd,
161 PbAggKind::BitOr,
162 PbAggKind::BoolAnd,
163 PbAggKind::BoolOr,
164 PbAggKind::PercentileCont,
165 PbAggKind::PercentileDisc,
166 PbAggKind::Mode,
167 PbAggKind::ApproxPercentile, PbAggKind::JsonbObjectAgg, PbAggKind::StddevSamp, PbAggKind::VarSamp, ]
172 .contains(&func.name.as_aggregate())
173 && if func.name.as_aggregate() == PbAggKind::Sum {
179 !(func.inputs_type[0].as_exact() == &DataType::Int64 && func.ret_type.as_exact() == &DataType::Int64)
180 } else {
181 true
182 }
183 })
184 .for_each(|func| {
185 funcs.entry(func.ret_type.as_exact().clone()).or_default().push(func)
186 });
187 funcs
188 });
189
190pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
195 LazyLock::new(|| {
196 let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
197 cast_sigs()
198 .filter_map(|cast| cast.try_into().ok())
199 .filter(|cast: &CastSig| cast.context == CastContext::Explicit)
200 .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
201 .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
202 casts
203 });
204
205pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
210 LazyLock::new(|| {
211 let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
212 cast_sigs()
213 .filter_map(|cast| cast.try_into().ok())
214 .filter(|cast: &CastSig| cast.context == CastContext::Implicit)
215 .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
216 .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
217 casts
218 });
219
220fn expr_type_to_inequality_op(typ: ExprType) -> Option<BinaryOperator> {
221 match typ {
222 ExprType::GreaterThan => Some(BinaryOperator::Gt),
223 ExprType::GreaterThanOrEqual => Some(BinaryOperator::GtEq),
224 ExprType::LessThan => Some(BinaryOperator::Lt),
225 ExprType::LessThanOrEqual => Some(BinaryOperator::LtEq),
226 ExprType::NotEqual => Some(BinaryOperator::NotEq),
227 _ => None,
228 }
229}
230
231pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock<
240 HashMap<(DataType, DataType), Vec<BinaryOperator>>,
241> = LazyLock::new(|| {
242 let mut funcs = HashMap::<(DataType, DataType), Vec<BinaryOperator>>::new();
243 FUNCTION_REGISTRY
244 .iter_scalars()
245 .filter(|func| {
246 !FUNC_BAN_LIST.contains(&func.name.as_scalar())
247 && func.ret_type == DataType::Boolean.into()
248 && func.inputs_type.len() == 2
249 && func
250 .inputs_type
251 .iter()
252 .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
253 })
254 .filter_map(|func| {
255 let lhs = func.inputs_type[0].as_exact().clone();
256 let rhs = func.inputs_type[1].as_exact().clone();
257 let op = expr_type_to_inequality_op(func.name.as_scalar())?;
258 Some(((lhs, rhs), op))
259 })
260 .for_each(|(args, op)| funcs.entry(args).or_default().push(op));
261 funcs
262});