1use std::str::FromStr;
16use std::sync::Arc;
17
18use futures_util::FutureExt;
19use itertools::Itertools;
20use risingwave_common::array::{DataChunk, ListRef, ListValue, StructRef, StructValue, VectorVal};
21use risingwave_common::cast;
22use risingwave_common::row::OwnedRow;
23use risingwave_common::types::{
24 DataType, F64, Int256, JsonbRef, MapRef, MapValue, ScalarRef as _, Serial, Timestamptz, ToText,
25};
26use risingwave_common::util::iter_util::ZipEqFast;
27use risingwave_common::util::row_id::row_id_to_unix_millis;
28use risingwave_expr::expr::{Context, ExpressionBoxExt, InputRefExpression, build_func};
29use risingwave_expr::{ExprError, Result, function};
30use risingwave_pb::expr::expr_node::PbType;
31use thiserror_ext::AsReport;
32
33#[function("cast(varchar) -> *int")]
34#[function("cast(varchar) -> decimal")]
35#[function("cast(varchar) -> *float")]
36#[function("cast(varchar) -> int256")]
37#[function("cast(varchar) -> date")]
38#[function("cast(varchar) -> time")]
39#[function("cast(varchar) -> timestamp")]
40#[function("cast(varchar) -> interval")]
41#[function("cast(varchar) -> jsonb")]
42pub fn str_parse<T>(elem: &str, ctx: &Context) -> Result<T>
43where
44 T: FromStr,
45 <T as FromStr>::Err: std::fmt::Display,
46{
47 elem.trim().parse().map_err(|err: <T as FromStr>::Err| {
48 ExprError::Parse(format!("{} {}", ctx.return_type, err).into())
49 })
50}
51
52#[function("pgwire_recv(bytea) -> int8")]
54pub fn pgwire_recv(elem: &[u8]) -> Result<i64> {
55 let fixed_length =
56 <[u8; 8]>::try_from(elem).map_err(|e| ExprError::Parse(e.to_report_string().into()))?;
57 Ok(i64::from_be_bytes(fixed_length))
58}
59
60#[function("cast(int2) -> int256")]
61#[function("cast(int4) -> int256")]
62#[function("cast(int8) -> int256")]
63pub fn to_int256<T: TryInto<Int256>>(elem: T) -> Result<Int256> {
64 elem.try_into()
65 .map_err(|_| ExprError::CastOutOfRange("int256"))
66}
67
68#[function("cast(jsonb) -> boolean")]
69pub fn jsonb_to_bool(v: JsonbRef<'_>) -> Result<Option<bool>> {
70 if v.is_jsonb_null() {
71 Ok(None)
72 } else {
73 v.as_bool()
74 .map(Some)
75 .map_err(|e| ExprError::Parse(e.into()))
76 }
77}
78
79#[function("cast(jsonb) -> int2")]
82#[function("cast(jsonb) -> int4")]
83#[function("cast(jsonb) -> int8")]
84#[function("cast(jsonb) -> decimal")]
85#[function("cast(jsonb) -> float4")]
86#[function("cast(jsonb) -> float8")]
87pub fn jsonb_to_number<T: TryFrom<F64>>(v: JsonbRef<'_>) -> Result<Option<T>> {
88 if v.is_jsonb_null() {
89 Ok(None)
90 } else {
91 v.as_number()
92 .map_err(|e| ExprError::Parse(e.into()))?
93 .try_into()
94 .map(Some)
95 .map_err(|_| ExprError::NumericOutOfRange)
96 }
97}
98
99#[function("cast(int4) -> int2")]
100#[function("cast(int8) -> int2")]
101#[function("cast(int8) -> int4")]
102#[function("cast(int8) -> serial")]
103#[function("cast(serial) -> int8")]
104#[function("cast(float4) -> int2")]
105#[function("cast(float8) -> int2")]
106#[function("cast(float4) -> int4")]
107#[function("cast(float8) -> int4")]
108#[function("cast(float4) -> int8")]
109#[function("cast(float8) -> int8")]
110#[function("cast(float8) -> float4")]
111#[function("cast(decimal) -> int2")]
112#[function("cast(decimal) -> int4")]
113#[function("cast(decimal) -> int8")]
114#[function("cast(decimal) -> float4")]
115#[function("cast(decimal) -> float8")]
116#[function("cast(float4) -> decimal")]
117#[function("cast(float8) -> decimal")]
118pub fn try_cast<T1, T2>(elem: T1) -> Result<T2>
119where
120 T1: TryInto<T2> + std::fmt::Debug + Copy,
121{
122 elem.try_into()
123 .map_err(|_| ExprError::CastOutOfRange(std::any::type_name::<T2>()))
124}
125
126#[function("cast(boolean) -> int4")]
127#[function("cast(int2) -> int4")]
128#[function("cast(int2) -> int8")]
129#[function("cast(int2) -> float4")]
130#[function("cast(int2) -> float8")]
131#[function("cast(int2) -> decimal")]
132#[function("cast(int4) -> int8")]
133#[function("cast(int4) -> float4")]
134#[function("cast(int4) -> float8")]
135#[function("cast(int4) -> decimal")]
136#[function("cast(int8) -> float4")]
137#[function("cast(int8) -> float8")]
138#[function("cast(int8) -> decimal")]
139#[function("cast(float4) -> float8")]
140#[function("cast(date) -> timestamp")]
141#[function("cast(time) -> interval")]
142#[function("cast(timestamp) -> date")]
143#[function("cast(timestamp) -> time")]
144#[function("cast(interval) -> time")]
145#[function("cast(varchar) -> varchar")]
146#[function("cast(int256) -> float8")]
147pub fn cast<T1, T2>(elem: T1) -> T2
148where
149 T1: Into<T2>,
150{
151 elem.into()
152}
153
154#[function("cast(serial) -> timestamptz")]
156pub fn serial_to_timestamptz(elem: Serial) -> Result<Timestamptz> {
157 let unix_ms = row_id_to_unix_millis(elem.as_row_id()).ok_or(ExprError::NumericOutOfRange)?;
158 Timestamptz::from_millis(unix_ms).ok_or(ExprError::NumericOutOfRange)
159}
160
161#[function("cast(varchar) -> boolean")]
162pub fn str_to_bool(input: &str) -> Result<bool> {
163 cast::str_to_bool(input).map_err(|err| ExprError::Parse(err.into()))
164}
165
166#[function("cast(int4) -> boolean")]
167pub fn int_to_bool(input: i32) -> bool {
168 input != 0
169}
170
171#[function("cast(*int) -> varchar")]
174#[function("cast(decimal) -> varchar")]
175#[function("cast(*float) -> varchar")]
176#[function("cast(int256) -> varchar")]
177#[function("cast(time) -> varchar")]
178#[function("cast(date) -> varchar")]
179#[function("cast(interval) -> varchar")]
180#[function("cast(timestamp) -> varchar")]
181#[function("cast(jsonb) -> varchar")]
182#[function("cast(bytea) -> varchar")]
183#[function("cast(anyarray) -> varchar")]
184#[function("cast(vector) -> varchar")]
185pub fn general_to_text(elem: impl ToText, mut writer: &mut impl std::fmt::Write) {
186 elem.write(&mut writer).unwrap();
187}
188
189#[function("pgwire_send(int8) -> bytea")]
191fn pgwire_send(elem: i64, writer: &mut impl std::io::Write) {
192 writer.write_all(&elem.to_be_bytes()).unwrap();
193}
194
195#[function("cast(boolean) -> varchar")]
196pub fn bool_to_varchar(input: bool, writer: &mut impl std::fmt::Write) {
197 writer
198 .write_str(if input { "true" } else { "false" })
199 .unwrap();
200}
201
202#[function("bool_out(boolean) -> varchar")]
205pub fn bool_out(input: bool, writer: &mut impl std::fmt::Write) {
206 writer.write_str(if input { "t" } else { "f" }).unwrap();
207}
208
209#[function("cast(varchar) -> bytea")]
210pub fn str_to_bytea(elem: &str, writer: &mut impl std::io::Write) -> Result<()> {
211 cast::str_to_bytea(elem, writer).map_err(|err| ExprError::Parse(err.into()))
212}
213
214#[function("cast(varchar) -> anyarray", type_infer = "unreachable")]
215fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
216 ListValue::from_str(input, &ctx.return_type).map_err(|err| ExprError::Parse(err.into()))
217}
218
219#[function("cast(varchar) -> vector", type_infer = "unreachable")]
220fn str_to_vector(input: &str, ctx: &Context) -> Result<VectorVal> {
221 let DataType::Vector(size) = &ctx.return_type else {
222 unreachable!()
223 };
224 VectorVal::from_text(input, *size).map_err(|err| ExprError::Parse(err.into()))
225}
226
227#[function("cast(anyarray) -> anyarray", type_infer = "unreachable")]
229fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
230 let cast = build_func(
231 PbType::Cast,
232 ctx.return_type.as_list_elem().clone(),
233 vec![InputRefExpression::new(ctx.arg_types[0].as_list_elem().clone(), 0).boxed()],
234 )
235 .unwrap();
236 let items = Arc::new(input.to_owned_scalar().into_array());
237 let len = items.len();
238 let list = cast
239 .eval(&DataChunk::new(vec![items], len))
240 .now_or_never()
241 .unwrap()?;
242 Ok(ListValue::new(Arc::try_unwrap(list).unwrap()))
243}
244
245#[function("cast(struct) -> struct", type_infer = "unreachable")]
247fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
248 let fields = (input.iter_fields_ref())
249 .zip_eq_fast(ctx.arg_types[0].as_struct().types())
250 .zip_eq_fast(ctx.return_type.as_struct().types())
251 .map(|((datum_ref, source_field_type), target_field_type)| {
252 if source_field_type == target_field_type {
253 return Ok(datum_ref.map(|scalar_ref| scalar_ref.into_scalar_impl()));
254 }
255 let cast = build_func(
256 PbType::Cast,
257 target_field_type.clone(),
258 vec![InputRefExpression::new(source_field_type.clone(), 0).boxed()],
259 )
260 .unwrap();
261 let value = match datum_ref {
262 Some(scalar_ref) => cast
263 .eval_row(&OwnedRow::new(vec![Some(scalar_ref.into_scalar_impl())]))
264 .now_or_never()
265 .unwrap()?,
266 None => None,
267 };
268 Ok(value) as Result<_>
269 })
270 .try_collect()?;
271 Ok(StructValue::new(fields))
272}
273
274#[function("cast(anymap) -> anymap", type_infer = "unreachable")]
276fn map_cast(map: MapRef<'_>, ctx: &Context) -> Result<MapValue> {
277 let new_ctx = Context {
278 arg_types: vec![ctx.arg_types[0].clone().as_map().clone().into_list()],
279 return_type: ctx.return_type.as_map().clone().into_list(),
280 variadic: ctx.variadic,
281 };
282 list_cast(map.into_inner(), &new_ctx).map(MapValue::from_entries)
283}
284
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;
    use risingwave_common::array::*;
    use risingwave_common::types::*;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    #[test]
    fn integer_cast_to_bool() {
        assert!(int_to_bool(32));
        assert!(int_to_bool(-32));
        assert!(!int_to_bool(0));
    }

    #[test]
    fn number_to_string() {
        // Invokes a writer-style cast function on `$value` and checks the text it emitted.
        macro_rules! test {
            ($fn:ident($value:expr), $right:literal) => {
                let mut writer = String::new();
                $fn($value, &mut writer);
                assert_eq!(writer, $right);
            };
        }

        test!(bool_to_varchar(true), "true");
        test!(bool_to_varchar(false), "false");

        // `bool_out` is the single-character variant.
        test!(bool_out(true), "t");
        test!(bool_out(false), "f");

        test!(general_to_text(32), "32");
        test!(general_to_text(-32), "-32");
        test!(general_to_text(i32::MIN), "-2147483648");
        test!(general_to_text(i32::MAX), "2147483647");

        test!(general_to_text(i16::MIN), "-32768");
        test!(general_to_text(i16::MAX), "32767");

        test!(general_to_text(i64::MIN), "-9223372036854775808");
        test!(general_to_text(i64::MAX), "9223372036854775807");

        test!(general_to_text(F64::from(32.12)), "32.12");
        test!(general_to_text(F64::from(-32.14)), "-32.14");

        test!(general_to_text(F32::from(32.12_f32)), "32.12");
        test!(general_to_text(F32::from(-32.14_f32)), "-32.14");

        test!(general_to_text(Decimal::try_from(1.222).unwrap()), "1.222");

        test!(general_to_text(Decimal::NaN), "NaN");
    }

    #[test]
    fn pgwire_int8_round_trip() {
        for value in [0i64, 1, -1, i64::MIN, i64::MAX] {
            let mut buf: Vec<u8> = Vec::new();
            pgwire_send(value, &mut buf);
            assert_eq!(buf, value.to_be_bytes());
            assert_eq!(pgwire_recv(&buf).unwrap(), value);
        }
        // Only exactly 8 bytes can be decoded.
        assert!(pgwire_recv(&[1, 2, 3]).is_err());
        assert!(pgwire_recv(&[0; 9]).is_err());
    }

    #[test]
    fn test_str_to_list() {
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{}", &ctx).unwrap(),
            ListValue::empty(&DataType::Varchar)
        );

        let list123 = ListValue::from_iter([1, 2, 3]);

        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert_eq!(str_to_list("{1, 2, 3}", &ctx).unwrap(), list123);

        let nested_list123 = ListValue::from_iter([list123]);
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[][]").unwrap(),
            variadic: false,
        };
        assert_eq!(str_to_list("{{1, 2, 3}}", &ctx).unwrap(), nested_list123);

        let nested_list445566 = ListValue::from_iter([ListValue::from_iter([44, 55, 66])]);

        let double_nested_list123_445566 =
            ListValue::from_iter([nested_list123.clone(), nested_list445566.clone()]);

        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[][][]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(),
            double_nested_list123_445566
        );

        // Build the expected varchar result by casting the int lists element-wise.
        let ctx = Context {
            arg_types: vec![DataType::from_str("int[][]").unwrap()],
            return_type: DataType::from_str("varchar[][]").unwrap(),
            variadic: false,
        };
        let double_nested_varchar_list123_445566 = ListValue::from_iter([
            list_cast(nested_list123.as_scalar_ref(), &ctx).unwrap(),
            list_cast(nested_list445566.as_scalar_ref(), &ctx).unwrap(),
        ]);

        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("varchar[][][]").unwrap(),
            variadic: false,
        };
        assert_eq!(
            str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(),
            double_nested_varchar_list123_445566
        );
    }

    #[test]
    fn test_invalid_str_to_list() {
        let ctx = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::from_str("int[]").unwrap(),
            variadic: false,
        };
        assert!(str_to_list("{{}", &ctx).is_err());
        assert!(str_to_list("{}}", &ctx).is_err());
        assert!(str_to_list("{{1, 2, 3}, {4, 5, 6}", &ctx).is_err());
        assert!(str_to_list("{{1, 2, 3}, 4, 5, 6}}", &ctx).is_err());
    }

    #[test]
    fn test_struct_cast() {
        let ctx = Context {
            arg_types: vec![DataType::Struct(StructType::new(vec![
                ("a", DataType::Varchar),
                ("b", DataType::Float32),
            ]))],
            return_type: DataType::Struct(StructType::new(vec![
                ("a", DataType::Int32),
                ("b", DataType::Int32),
            ])),
            variadic: false,
        };
        assert_eq!(
            struct_cast(
                StructValue::new(vec![
                    Some("1".into()),
                    Some(F32::from(0.0).to_scalar_value()),
                ])
                .as_scalar_ref(),
                &ctx,
            )
            .unwrap(),
            StructValue::new(vec![
                Some(1i32.to_scalar_value()),
                Some(0i32.to_scalar_value()),
            ])
        );
    }

    #[test]
    fn test_timestamp() {
        assert_eq!(
            try_cast::<_, Timestamp>(Date::from_ymd_uncheck(1994, 1, 1)).unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        )
    }

    #[tokio::test]
    async fn test_unary() {
        test_unary_bool::<BoolArray, _>(|x| !x, PbType::Not).await;
        test_unary_date::<TimestampArray, _>(|x| try_cast(x).unwrap(), PbType::Cast).await;
        let ctx_str_to_int16 = Context {
            arg_types: vec![DataType::Varchar],
            return_type: DataType::Int16,
            variadic: false,
        };
        test_str_to_int16::<I16Array, _>(|x| str_parse(x, &ctx_str_to_int16).unwrap()).await;
    }

    #[tokio::test]
    async fn test_i16_to_i32() {
        let mut input = Vec::<Option<i16>>::new();
        let mut target = Vec::<Option<i32>>::new();
        for i in 0..100i16 {
            if i % 2 == 0 {
                target.push(Some(i as i32));
                input.push(Some(i));
            } else {
                input.push(None);
                target.push(None);
            }
        }
        let col1 = I16Array::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty("(cast:int4 $0:int2)");
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &I32Array = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].map(|int| int.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    #[tokio::test]
    async fn test_neg() {
        let input = [Some(1), Some(0), Some(-1)];
        let target = [Some(-1), Some(0), Some(1)];

        let col1 = I32Array::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 3);
        let expr = build_from_pretty("(neg:int4 $0:int4)");
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &I32Array = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].map(|int| int.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    // Checks `(cast:int2 $0:varchar)` against `f`, both on a chunk and row by row.
    async fn test_str_to_int16<A, F>(f: F)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(&str) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<Box<str>>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        for i in 0..1u32 {
            if i % 2 == 0 {
                let s = i.to_string().into_boxed_str();
                target.push(Some(f(&s)));
                input.push(Some(s));
            } else {
                input.push(None);
                target.push(None);
            }
        }
        let col1_data = &input.iter().map(|x| x.as_ref().map(|x| &**x)).collect_vec();
        let col1 = Utf8Array::from_iter(col1_data).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 1);
        let expr = build_from_pretty("(cast:int2 $0:varchar)");
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..input.len() {
            let row = OwnedRow::new(vec![
                input[i].as_ref().cloned().map(|str| str.to_scalar_value()),
            ]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    // Checks a unary boolean expression `kind` against `f`, both on a chunk and row by row.
    async fn test_unary_bool<A, F>(f: F, kind: PbType)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(bool) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<bool>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        for i in 0..100 {
            if i % 2 == 0 {
                input.push(Some(true));
                target.push(Some(f(true)));
            } else if i % 3 == 0 {
                input.push(Some(false));
                target.push(Some(f(false)));
            } else {
                input.push(None);
                target.push(None);
            }
        }

        let col1 = BoolArray::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty(format!("({kind:?}:boolean $0:boolean)"));
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|b| b.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    // Checks a unary date -> timestamp expression `kind` against `f`, both on a chunk and row by row.
    async fn test_unary_date<A, F>(f: F, kind: PbType)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        F: Fn(Date) -> <A as Array>::OwnedItem,
    {
        let mut input = Vec::<Option<Date>>::new();
        let mut target = Vec::<Option<<A as Array>::OwnedItem>>::new();
        for i in 0..100 {
            if i % 2 == 0 {
                let date = Date::from_num_days_from_ce_uncheck(i);
                input.push(Some(date));
                target.push(Some(f(date)));
            } else {
                input.push(None);
                target.push(None);
            }
        }

        let col1 = DateArray::from_iter(&input).into_ref();
        let data_chunk = DataChunk::new(vec![col1], 100);
        let expr = build_from_pretty(format!("({kind:?}:timestamp $0:date)"));
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..input.len() {
            let row = OwnedRow::new(vec![input[i].map(|d| d.to_scalar_value())]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }
}