1use std::str::FromStr;
16use std::sync::Arc;
17
18use futures_util::FutureExt;
19use itertools::Itertools;
20use risingwave_common::array::{DataChunk, ListRef, ListValue, StructRef, StructValue, VectorVal};
21use risingwave_common::cast;
22use risingwave_common::row::OwnedRow;
23use risingwave_common::types::{
24 DataType, F64, Int256, JsonbRef, MapRef, MapValue, ScalarRef as _, Serial, Timestamptz, ToText,
25};
26use risingwave_common::util::iter_util::ZipEqFast;
27use risingwave_common::util::row_id::row_id_to_unix_millis;
28use risingwave_expr::expr::{Context, ExpressionBoxExt, InputRefExpression, build_func};
29use risingwave_expr::{ExprError, Result, function};
30use risingwave_pb::expr::expr_node::PbType;
31use thiserror_ext::AsReport;
32
33#[function("cast(varchar) -> *int")]
34#[function("cast(varchar) -> decimal")]
35#[function("cast(varchar) -> *float")]
36#[function("cast(varchar) -> int256")]
37#[function("cast(varchar) -> date")]
38#[function("cast(varchar) -> time")]
39#[function("cast(varchar) -> timestamp")]
40#[function("cast(varchar) -> interval")]
41#[function("cast(varchar) -> jsonb")]
42pub fn str_parse<T>(elem: &str, ctx: &Context) -> Result<T>
43where
44 T: FromStr,
45 <T as FromStr>::Err: std::fmt::Display,
46{
47 elem.trim().parse().map_err(|err: <T as FromStr>::Err| {
48 ExprError::Parse(format!("{} {}", ctx.return_type, err).into())
49 })
50}
51
52#[function("pgwire_recv(bytea) -> int8")]
54pub fn pgwire_recv(elem: &[u8]) -> Result<i64> {
55 let fixed_length =
56 <[u8; 8]>::try_from(elem).map_err(|e| ExprError::Parse(e.to_report_string().into()))?;
57 Ok(i64::from_be_bytes(fixed_length))
58}
59
60#[function("cast(int2) -> int256")]
61#[function("cast(int4) -> int256")]
62#[function("cast(int8) -> int256")]
63pub fn to_int256<T: TryInto<Int256>>(elem: T) -> Result<Int256> {
64 elem.try_into()
65 .map_err(|_| ExprError::CastOutOfRange("int256"))
66}
67
68#[function("cast(jsonb) -> boolean")]
69pub fn jsonb_to_bool(v: JsonbRef<'_>) -> Result<bool> {
70 v.as_bool().map_err(|e| ExprError::Parse(e.into()))
71}
72
73#[function("cast(jsonb) -> int2")]
76#[function("cast(jsonb) -> int4")]
77#[function("cast(jsonb) -> int8")]
78#[function("cast(jsonb) -> decimal")]
79#[function("cast(jsonb) -> float4")]
80#[function("cast(jsonb) -> float8")]
81pub fn jsonb_to_number<T: TryFrom<F64>>(v: JsonbRef<'_>) -> Result<T> {
82 v.as_number()
83 .map_err(|e| ExprError::Parse(e.into()))?
84 .try_into()
85 .map_err(|_| ExprError::NumericOutOfRange)
86}
87
88#[function("cast(int4) -> int2")]
89#[function("cast(int8) -> int2")]
90#[function("cast(int8) -> int4")]
91#[function("cast(int8) -> serial")]
92#[function("cast(serial) -> int8")]
93#[function("cast(float4) -> int2")]
94#[function("cast(float8) -> int2")]
95#[function("cast(float4) -> int4")]
96#[function("cast(float8) -> int4")]
97#[function("cast(float4) -> int8")]
98#[function("cast(float8) -> int8")]
99#[function("cast(float8) -> float4")]
100#[function("cast(decimal) -> int2")]
101#[function("cast(decimal) -> int4")]
102#[function("cast(decimal) -> int8")]
103#[function("cast(decimal) -> float4")]
104#[function("cast(decimal) -> float8")]
105#[function("cast(float4) -> decimal")]
106#[function("cast(float8) -> decimal")]
107pub fn try_cast<T1, T2>(elem: T1) -> Result<T2>
108where
109 T1: TryInto<T2> + std::fmt::Debug + Copy,
110{
111 elem.try_into()
112 .map_err(|_| ExprError::CastOutOfRange(std::any::type_name::<T2>()))
113}
114
115#[function("cast(boolean) -> int4")]
116#[function("cast(int2) -> int4")]
117#[function("cast(int2) -> int8")]
118#[function("cast(int2) -> float4")]
119#[function("cast(int2) -> float8")]
120#[function("cast(int2) -> decimal")]
121#[function("cast(int4) -> int8")]
122#[function("cast(int4) -> float4")]
123#[function("cast(int4) -> float8")]
124#[function("cast(int4) -> decimal")]
125#[function("cast(int8) -> float4")]
126#[function("cast(int8) -> float8")]
127#[function("cast(int8) -> decimal")]
128#[function("cast(float4) -> float8")]
129#[function("cast(date) -> timestamp")]
130#[function("cast(time) -> interval")]
131#[function("cast(timestamp) -> date")]
132#[function("cast(timestamp) -> time")]
133#[function("cast(interval) -> time")]
134#[function("cast(varchar) -> varchar")]
135#[function("cast(int256) -> float8")]
136pub fn cast<T1, T2>(elem: T1) -> T2
137where
138 T1: Into<T2>,
139{
140 elem.into()
141}
142
143#[function("cast(serial) -> timestamptz")]
145pub fn serial_to_timestamptz(elem: Serial) -> Result<Timestamptz> {
146 let unix_ms = row_id_to_unix_millis(elem.as_row_id()).ok_or(ExprError::NumericOutOfRange)?;
147 Timestamptz::from_millis(unix_ms).ok_or(ExprError::NumericOutOfRange)
148}
149
150#[function("cast(varchar) -> boolean")]
151pub fn str_to_bool(input: &str) -> Result<bool> {
152 cast::str_to_bool(input).map_err(|err| ExprError::Parse(err.into()))
153}
154
155#[function("cast(int4) -> boolean")]
156pub fn int_to_bool(input: i32) -> bool {
157 input != 0
158}
159
160#[function("cast(*int) -> varchar")]
163#[function("cast(decimal) -> varchar")]
164#[function("cast(*float) -> varchar")]
165#[function("cast(int256) -> varchar")]
166#[function("cast(time) -> varchar")]
167#[function("cast(date) -> varchar")]
168#[function("cast(interval) -> varchar")]
169#[function("cast(timestamp) -> varchar")]
170#[function("cast(jsonb) -> varchar")]
171#[function("cast(bytea) -> varchar")]
172#[function("cast(anyarray) -> varchar")]
173#[function("cast(vector) -> varchar")]
174pub fn general_to_text(elem: impl ToText, mut writer: &mut impl std::fmt::Write) {
175 elem.write(&mut writer).unwrap();
176}
177
178#[function("pgwire_send(int8) -> bytea")]
180fn pgwire_send(elem: i64, writer: &mut impl std::io::Write) {
181 writer.write_all(&elem.to_be_bytes()).unwrap();
182}
183
184#[function("cast(boolean) -> varchar")]
185pub fn bool_to_varchar(input: bool, writer: &mut impl std::fmt::Write) {
186 writer
187 .write_str(if input { "true" } else { "false" })
188 .unwrap();
189}
190
191#[function("bool_out(boolean) -> varchar")]
194pub fn bool_out(input: bool, writer: &mut impl std::fmt::Write) {
195 writer.write_str(if input { "t" } else { "f" }).unwrap();
196}
197
198#[function("cast(varchar) -> bytea")]
199pub fn str_to_bytea(elem: &str, writer: &mut impl std::io::Write) -> Result<()> {
200 cast::str_to_bytea(elem, writer).map_err(|err| ExprError::Parse(err.into()))
201}
202
203#[function("cast(varchar) -> anyarray", type_infer = "unreachable")]
204fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
205 ListValue::from_str(input, &ctx.return_type).map_err(|err| ExprError::Parse(err.into()))
206}
207
208#[function("cast(varchar) -> vector", type_infer = "unreachable")]
209fn str_to_vector(input: &str, ctx: &Context) -> Result<VectorVal> {
210 let DataType::Vector(size) = &ctx.return_type else {
211 unreachable!()
212 };
213 VectorVal::from_text(input, *size).map_err(|err| ExprError::Parse(err.into()))
214}
215
216#[function("cast(anyarray) -> anyarray", type_infer = "unreachable")]
218fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
219 let cast = build_func(
220 PbType::Cast,
221 ctx.return_type.as_list_elem().clone(),
222 vec![InputRefExpression::new(ctx.arg_types[0].as_list_elem().clone(), 0).boxed()],
223 )
224 .unwrap();
225 let items = Arc::new(input.to_owned_scalar().into_array());
226 let len = items.len();
227 let list = cast
228 .eval(&DataChunk::new(vec![items], len))
229 .now_or_never()
230 .unwrap()?;
231 Ok(ListValue::new(Arc::try_unwrap(list).unwrap()))
232}
233
234#[function("cast(struct) -> struct", type_infer = "unreachable")]
236fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
237 let fields = (input.iter_fields_ref())
238 .zip_eq_fast(ctx.arg_types[0].as_struct().types())
239 .zip_eq_fast(ctx.return_type.as_struct().types())
240 .map(|((datum_ref, source_field_type), target_field_type)| {
241 if source_field_type == target_field_type {
242 return Ok(datum_ref.map(|scalar_ref| scalar_ref.into_scalar_impl()));
243 }
244 let cast = build_func(
245 PbType::Cast,
246 target_field_type.clone(),
247 vec![InputRefExpression::new(source_field_type.clone(), 0).boxed()],
248 )
249 .unwrap();
250 let value = match datum_ref {
251 Some(scalar_ref) => cast
252 .eval_row(&OwnedRow::new(vec![Some(scalar_ref.into_scalar_impl())]))
253 .now_or_never()
254 .unwrap()?,
255 None => None,
256 };
257 Ok(value) as Result<_>
258 })
259 .try_collect()?;
260 Ok(StructValue::new(fields))
261}
262
263#[function("cast(anymap) -> anymap", type_infer = "unreachable")]
265fn map_cast(map: MapRef<'_>, ctx: &Context) -> Result<MapValue> {
266 let new_ctx = Context {
267 arg_types: vec![ctx.arg_types[0].clone().as_map().clone().into_list()],
268 return_type: ctx.return_type.as_map().clone().into_list(),
269 variadic: ctx.variadic,
270 };
271 list_cast(map.into_inner(), &new_ctx).map(MapValue::from_entries)
272}
273
/// Unit tests for the cast functions in this file, plus a few expression-level checks built
/// with `build_from_pretty`.
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;
    use risingwave_common::array::*;
    use risingwave_common::types::*;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Non-zero integers cast to `true`, zero casts to `false`.
    #[test]
    fn integer_cast_to_bool() {
        assert!(int_to_bool(32));
        assert!(int_to_bool(-32));
        assert!(!int_to_bool(0));
    }

    /// Checks the text produced for booleans, integers, floats and decimals.
    #[test]
    fn number_to_string() {
        // Calls `$fn($value, &mut String)` and compares the written text with `$right`.
        macro_rules! test {
            ($fn:ident($value:expr), $right:literal) => {
                let mut writer = String::new();
                $fn($value, &mut writer);
                assert_eq!(writer, $right);
            };
        }

        test!(bool_to_varchar(true), "true");
        test!(bool_to_varchar(true), "true");
        test!(bool_to_varchar(false), "false");

        test!(general_to_text(32), "32");
        test!(general_to_text(-32), "-32");
        test!(general_to_text(i32::MIN), "-2147483648");
        test!(general_to_text(i32::MAX), "2147483647");

        test!(general_to_text(i16::MIN), "-32768");
        test!(general_to_text(i16::MAX), "32767");

        test!(general_to_text(i64::MIN), "-9223372036854775808");
        test!(general_to_text(i64::MAX), "9223372036854775807");

        test!(general_to_text(F64::from(32.12)), "32.12");
        test!(general_to_text(F64::from(-32.14)), "-32.14");

        test!(general_to_text(F32::from(32.12_f32)), "32.12");
        test!(general_to_text(F32::from(-32.14_f32)), "-32.14");

        test!(general_to_text(Decimal::try_from(1.222).unwrap()), "1.222");

        test!(general_to_text(Decimal::NaN), "NaN");
    }

    /// Parses list literals of increasing nesting depth and compares them with lists built
    /// programmatically.
    #[test]
    fn test_str_to_list() {
        // Empty list.
        // NOTE(review): the expected value is built with a `Varchar` element type although
        // the target type is `int[]`; presumably empty lists compare equal regardless of
        // element type — confirm.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{}", &ctx).unwrap(),
            ListValue::empty(&DataType::Varchar)
        );

        let list123 = ListValue::from_iter([1, 2, 3]);

        // Single-level list.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert_eq!(str_to_list("{1, 2, 3}", &ctx).unwrap(), list123);

        let nested_list123 = ListValue::from_iter([list123]);
        // Nested list.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[][]").unwrap(),
            variadic: false,
        };
        assert_eq!(str_to_list("{{1, 2, 3}}", &ctx).unwrap(), nested_list123);

        let nested_list445566 = ListValue::from_iter([ListValue::from_iter([44, 55, 66])]);

        let double_nested_list123_445566 =
            ListValue::from_iter([nested_list123.clone(), nested_list445566.clone()]);

        // Doubly nested list.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[][][]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(),
            double_nested_list123_445566
        );

        // Build the expected varchar lists by casting the nested int lists with `list_cast`.
        let ctx = Context {
            arg_types: vec![DataType::from_str("int[][]").unwrap()],
            return_type: DataType::from_str("varchar[][]").unwrap(),
            variadic: false,
        };
        let double_nested_varchar_list123_445566 = ListValue::from_iter([
            list_cast(nested_list123.as_scalar_ref(), &ctx).unwrap(),
            list_cast(nested_list445566.as_scalar_ref(), &ctx).unwrap(),
        ]);

        // Parsing the same literal directly as `varchar[][][]` must give the same value.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("varchar[][][]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(),
            double_nested_varchar_list123_445566
        );
    }

    /// Malformed list literals (unbalanced braces, mixed nesting) must be rejected.
    #[test]
    fn test_invalid_str_to_list() {
        // All inputs below are parsed against an `int[]` target.
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert!(str_to_list("{{}", &ctx).is_err());
        assert!(str_to_list("{}}", &ctx).is_err());
        assert!(str_to_list("{{1, 2, 3}, {4, 5, 6}", &ctx).is_err());
        assert!(str_to_list("{{1, 2, 3}, 4, 5, 6}}", &ctx).is_err());
    }

    /// Casts `struct<a varchar, b float4>` to `struct<a int4, b int4>`.
    #[test]
    fn test_struct_cast() {
        let ctx = Context {
            arg_types: vec![DataType::Struct(StructType::new(vec![
                ("a", DataType::Varchar),
                ("b", DataType::Float32),
            ]))],
            return_type: DataType::Struct(StructType::new(vec![
                ("a", DataType::Int32),
                ("b", DataType::Int32),
            ])),
            variadic: false,
        };
        assert_eq!(
            struct_cast(
                StructValue::new(vec![
                    Some("1".into()),
                    Some(F32::from(0.0).to_scalar_value()),
                ])
                .as_scalar_ref(),
                &ctx,
            )
            .unwrap(),
            StructValue::new(vec![
                Some(1i32.to_scalar_value()),
                Some(0i32.to_scalar_value()),
            ])
        );
    }

    /// A date cast to a timestamp lands on midnight of that day.
    #[test]
    fn test_timestamp() {
        assert_eq!(
            try_cast::<_, Timestamp>(Date::from_ymd_uncheck(1994, 1, 1)).unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        )
    }

    /// Drives the generic unary helpers below for boolean NOT, date -> timestamp and
    /// varchar -> int2.
    #[tokio::test]
    async fn test_unary() {
        test_unary_bool::<BoolArray, _>(|x| !x, PbType::Not).await;
        test_unary_date::<TimestampArray, _>(|x| try_cast(x).unwrap(), PbType::Cast).await;
        let ctx_str_to_int16 = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::Int16,
            variadic: false,
        };
        test_str_to_int16::<I16Array, _>(|x| str_parse(x, &ctx_str_to_int16).unwrap()).await;
    }

    /// Evaluates `(cast:int4 $0:int2)` on a column with alternating values and NULLs, both
    /// chunk-wise and row by row.
    #[tokio::test]
    async fn test_i16_to_i32() {
        let mut input = Vec::<Option<i16>>::new();
        let mut target = Vec::<Option<i32>>::new();
        // Even indices carry a value, odd indices are NULL.
        for i in 0..100i16 {
            if i % 2 == 0 {
                target.push(Some(i as i32));
                input.push(Some(i));
            } else {
                input.push(None);
                target.push(None);
            }
        }
        let col1 = I16Array::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty("(cast:int4 $0:int2)");
        // Chunk-level evaluation.
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &I32Array = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        // Row-level evaluation must agree with the expected values as well.
        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].map(|int| int.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    /// Evaluates `(neg:int4 $0:int4)` on `[1, 0, -1]`, both chunk-wise and row by row.
    #[tokio::test]
    async fn test_neg() {
        let input = [Some(1), Some(0), Some(-1)];
        let target = [Some(-1), Some(0), Some(1)];

        let col1 = I32Array::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 3);
        let expr = build_from_pretty("(neg:int4 $0:int4)");
        // Chunk-level evaluation.
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &I32Array = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        // Row-level evaluation.
        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].map(|int| int.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    /// Compares `(cast:int2 $0:varchar)` against the reference closure `f`, both chunk-wise
    /// and row by row.
    async fn test_str_to_int16<A, F>(f: F)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(&str) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<Box<str>>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        // The range yields only `i = 0`, so a single non-NULL row ("0") is generated and the
        // NULL branch is never taken.
        for i in 0..1u32 {
            if i % 2 == 0 {
                let s = i.to_string().into_boxed_str();
                target.push(Some(f(&s)));
                input.push(Some(s));
            } else {
                input.push(None);
                target.push(None);
            }
        }
        let col1_data = &input.iter().map(|x| x.as_ref().map(|x| &**x)).collect_vec();
        let col1 = Utf8Array::from_iter(col1_data).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 1);
        let expr = build_from_pretty("(cast:int2 $0:varchar)");
        // Chunk-level evaluation.
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        // Row-level evaluation.
        for i in 0..input.len() {
            let row = OwnedRow::new(vec![
                input[i].as_ref().cloned().map(|str| str.to_scalar_value()),
            ]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    /// Compares the unary boolean expression `kind` against the reference closure `f` on a
    /// column mixing `true`, `false` and NULL, both chunk-wise and row by row.
    async fn test_unary_bool<A, F>(f: F, kind: PbType)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(bool) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<bool>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        // Even indices are `true`, odd multiples of three are `false`, the rest are NULL.
        for i in 0..100 {
            if i % 2 == 0 {
                input.push(Some(true));
                target.push(Some(f(true)));
            } else if i % 3 == 0 {
                input.push(Some(false));
                target.push(Some(f(false)));
            } else {
                input.push(None);
                target.push(None);
            }
        }

        let col1 = BoolArray::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty(format!("({kind:?}:boolean $0:boolean)"));
        // Chunk-level evaluation.
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        // Row-level evaluation.
        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|b| b.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    /// Compares the unary date -> timestamp expression `kind` against the reference closure
    /// `f` on a column alternating dates and NULLs, both chunk-wise and row by row.
    async fn test_unary_date<A, F>(f: F, kind: PbType)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(Date) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<Date>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        // Even indices carry a date derived from the index, odd indices are NULL.
        for i in 0..100 {
            if i % 2 == 0 {
                let date = Date::from_num_days_from_ce_uncheck(i);
                input.push(Some(date));
                target.push(Some(f(date)));
            } else {
                input.push(None);
                target.push(None);
            }
        }

        let col1 = DateArray::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty(format!("({kind:?}:timestamp $0:date)"));
        // Chunk-level evaluation.
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        // Row-level evaluation.
        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|d| d.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }
}