risingwave_sqlsmith/sql_gen/
types.rs1use std::collections::{HashMap, HashSet};
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use risingwave_common::types::{DataType, DataTypeName};
22use risingwave_expr::aggregate::PbAggKind;
23use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncSign};
24use risingwave_frontend::expr::{CastContext, CastSig as RwCastSig, ExprType, cast_sigs};
25use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField};
26
27pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType {
28 match data_type {
29 DataType::Boolean => AstDataType::Boolean,
30 DataType::Int16 => AstDataType::SmallInt,
31 DataType::Int32 => AstDataType::Int,
32 DataType::Int64 => AstDataType::BigInt,
33 DataType::Int256 => AstDataType::Custom(vec!["rw_int256".into()].into()),
34 DataType::Serial => unreachable!("serial should not be generated"),
35 DataType::Decimal => AstDataType::Decimal(None, None),
36 DataType::Float32 => AstDataType::Real,
37 DataType::Float64 => AstDataType::Double,
38 DataType::Varchar => AstDataType::Varchar,
39 DataType::Bytea => AstDataType::Bytea,
40 DataType::Date => AstDataType::Date,
41 DataType::Timestamp => AstDataType::Timestamp(false),
42 DataType::Timestamptz => AstDataType::Timestamp(true),
43 DataType::Time => AstDataType::Time(false),
44 DataType::Interval => AstDataType::Interval,
45 DataType::Jsonb => AstDataType::Custom(vec!["JSONB".into()].into()),
46 DataType::Struct(inner) => AstDataType::Struct(
47 inner
48 .iter()
49 .map(|(name, typ)| StructField {
50 name: name.into(),
51 data_type: data_type_to_ast_data_type(typ),
52 })
53 .collect(),
54 ),
55 DataType::List(list) => {
56 AstDataType::Array(Box::new(data_type_to_ast_data_type(list.elem())))
57 }
58 DataType::Vector(n) => AstDataType::Vector(*n as _),
59 DataType::Map(_) => todo!(),
60 }
61}
62
63fn data_type_name_to_ast_data_type(data_type_name: &DataTypeName) -> Option<DataType> {
64 use DataTypeName as T;
65 match data_type_name {
66 T::Boolean => Some(DataType::Boolean),
67 T::Int16 => Some(DataType::Int16),
68 T::Int32 => Some(DataType::Int32),
69 T::Int64 => Some(DataType::Int64),
70 T::Decimal => Some(DataType::Decimal),
71 T::Float32 => Some(DataType::Float32),
72 T::Float64 => Some(DataType::Float64),
73 T::Varchar => Some(DataType::Varchar),
74 T::Date => Some(DataType::Date),
75 T::Timestamp => Some(DataType::Timestamp),
76 T::Timestamptz => Some(DataType::Timestamptz),
77 T::Time => Some(DataType::Time),
78 T::Interval => Some(DataType::Interval),
79 _ => None,
80 }
81}
82
83#[derive(Clone)]
85pub struct CastSig {
86 pub from_type: DataType,
87 pub to_type: DataType,
88 pub context: CastContext,
89}
90
91impl TryFrom<RwCastSig> for CastSig {
92 type Error = String;
93
94 fn try_from(value: RwCastSig) -> Result<Self, Self::Error> {
95 if let Some(from_type) = data_type_name_to_ast_data_type(&value.from_type)
96 && let Some(to_type) = data_type_name_to_ast_data_type(&value.to_type)
97 {
98 Ok(CastSig {
99 from_type,
100 to_type,
101 context: value.context,
102 })
103 } else {
104 Err(format!("unsupported cast sig: {:?}", value))
105 }
106 }
107}
108
109static FUNC_BAN_LIST: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
114 [
115 ExprType::Repeat,
117 ExprType::Decode,
119 ExprType::Sqrt,
121 ExprType::Pow,
123 ExprType::Position,
125 ExprType::Strpos,
127 ]
128 .into_iter()
129 .collect()
130});
131
132pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
138 LazyLock::new(|| {
139 let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
140 FUNCTION_REGISTRY
141 .iter_scalars()
142 .filter(|func| {
143 func.inputs_type.iter().all(|t| {
144 t.is_exact()
145 && t.as_exact() != &DataType::Timestamptz
146 && t.as_exact() != &DataType::Serial
147 }) && func.ret_type.is_exact()
148 && !FUNC_BAN_LIST.contains(&func.name.as_scalar())
149 && !func.deprecated })
151 .for_each(|func| {
152 funcs
153 .entry(func.ret_type.as_exact().clone())
154 .or_default()
155 .push(func)
156 });
157 funcs
158 });
159
160pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
163 FUNCTION_REGISTRY
164 .iter_scalars()
165 .map(|sig| sig.name.as_scalar())
166 .counts()
167 .into_iter()
168 .filter(|(_key, count)| *count == 1)
169 .map(|(key, _)| key)
170 .collect()
171});
172
173pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
176 LazyLock::new(|| {
177 let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
178 FUNCTION_REGISTRY
179 .iter_aggregates()
180 .filter(|func| {
181 func.inputs_type
182 .iter()
183 .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz && t.as_exact() != &DataType::Serial)
184 && func.ret_type.is_exact()
185 && ![
187 PbAggKind::InternalLastSeenValue, PbAggKind::Sum0, PbAggKind::ApproxCountDistinct,
190 PbAggKind::BitAnd,
191 PbAggKind::BitOr,
192 PbAggKind::BoolAnd,
193 PbAggKind::BoolOr,
194 PbAggKind::PercentileCont,
195 PbAggKind::PercentileDisc,
196 PbAggKind::Mode,
197 PbAggKind::ApproxPercentile, PbAggKind::JsonbObjectAgg, PbAggKind::StddevSamp, PbAggKind::VarSamp, ]
202 .contains(&func.name.as_aggregate())
203 && if func.name.as_aggregate() == PbAggKind::Sum {
209 !(func.inputs_type[0].as_exact() == &DataType::Int64 && func.ret_type.as_exact() == &DataType::Int64)
210 } else {
211 true
212 }
213 })
214 .for_each(|func| {
215 funcs.entry(func.ret_type.as_exact().clone()).or_default().push(func)
216 });
217 funcs
218 });
219
220pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
225 LazyLock::new(|| {
226 let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
227 cast_sigs()
228 .filter_map(|cast| cast.try_into().ok())
229 .filter(|cast: &CastSig| cast.context == CastContext::Explicit)
230 .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
231 .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
232 casts
233 });
234
235pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
240 LazyLock::new(|| {
241 let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
242 cast_sigs()
243 .filter_map(|cast| cast.try_into().ok())
244 .filter(|cast: &CastSig| cast.context == CastContext::Implicit)
245 .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
246 .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
247 casts
248 });
249
250fn expr_type_to_inequality_op(typ: ExprType) -> Option<BinaryOperator> {
251 match typ {
252 ExprType::GreaterThan => Some(BinaryOperator::Gt),
253 ExprType::GreaterThanOrEqual => Some(BinaryOperator::GtEq),
254 ExprType::LessThan => Some(BinaryOperator::Lt),
255 ExprType::LessThanOrEqual => Some(BinaryOperator::LtEq),
256 ExprType::NotEqual => Some(BinaryOperator::NotEq),
257 _ => None,
258 }
259}
260
261pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock<
270 HashMap<(DataType, DataType), Vec<BinaryOperator>>,
271> = LazyLock::new(|| {
272 let mut funcs = HashMap::<(DataType, DataType), Vec<BinaryOperator>>::new();
273 FUNCTION_REGISTRY
274 .iter_scalars()
275 .filter(|func| {
276 !FUNC_BAN_LIST.contains(&func.name.as_scalar())
277 && func.ret_type == DataType::Boolean.into()
278 && func.inputs_type.len() == 2
279 && func
280 .inputs_type
281 .iter()
282 .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
283 })
284 .filter_map(|func| {
285 let lhs = func.inputs_type[0].as_exact().clone();
286 let rhs = func.inputs_type[1].as_exact().clone();
287 let op = expr_type_to_inequality_op(func.name.as_scalar())?;
288 Some(((lhs, rhs), op))
289 })
290 .for_each(|(args, op)| funcs.entry(args).or_default().push(op));
291 funcs
292});