risingwave_common/array/arrow/
arrow_udf.rsuse std::sync::Arc;
pub use super::arrow_53::{
arrow_array, arrow_buffer, arrow_cast, arrow_schema, FromArrow, ToArrow,
};
use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
#[derive(Default, Debug)]
pub struct UdfArrowConvert {
pub legacy: bool,
}
impl ToArrow for UdfArrowConvert {
fn decimal_to_arrow(
&self,
_data_type: &arrow_schema::DataType,
array: &DecimalArray,
) -> Result<arrow_array::ArrayRef, ArrayError> {
if self.legacy {
Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
} else {
Ok(Arc::new(arrow_array::StringArray::from(array)))
}
}
fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
if self.legacy {
Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
} else {
Ok(Arc::new(arrow_array::StringArray::from(array)))
}
}
fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
if self.legacy {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
} else {
arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
.with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
}
}
fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
if self.legacy {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
} else {
arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
.with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
}
}
}
impl FromArrow for UdfArrowConvert {
fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
if self.legacy {
Ok(DataType::Jsonb)
} else {
Ok(DataType::Varchar)
}
}
fn from_large_binary(&self) -> Result<DataType, ArrayError> {
if self.legacy {
Ok(DataType::Decimal)
} else {
Ok(DataType::Bytea)
}
}
fn from_large_utf8_array(
&self,
array: &arrow_array::LargeStringArray,
) -> Result<ArrayImpl, ArrayError> {
if self.legacy {
Ok(ArrayImpl::Jsonb(array.try_into()?))
} else {
Ok(ArrayImpl::Utf8(array.into()))
}
}
fn from_large_binary_array(
&self,
array: &arrow_array::LargeBinaryArray,
) -> Result<ArrayImpl, ArrayError> {
if self.legacy {
Ok(ArrayImpl::Decimal(array.try_into()?))
} else {
Ok(ArrayImpl::Bytea(array.into()))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::array::*;
#[test]
fn struct_array() {
let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
assert_eq!(
UdfArrowConvert::default()
.struct_to_arrow(
&arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
&test_arr
)
.unwrap()
.len(),
0
);
let test_arr_2 = arrow_array::StructArray::from(vec![]);
assert_eq!(
UdfArrowConvert::default()
.from_struct_array(&test_arr_2)
.unwrap()
.len(),
0
);
let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![
(
"a",
Arc::new(arrow_array::BooleanArray::from(vec![
Some(false),
Some(false),
Some(true),
None,
])) as arrow_array::ArrayRef,
),
(
"b",
Arc::new(arrow_array::Int32Array::from(vec![
Some(42),
Some(28),
Some(19),
None,
])) as arrow_array::ArrayRef,
),
])
.unwrap();
let actual_risingwave_struct_array = UdfArrowConvert::default()
.from_struct_array(&test_arrow_struct_array)
.unwrap()
.into_struct();
let expected_risingwave_struct_array = StructArray::new(
StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
vec![
BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
],
[true, true, true, true].into_iter().collect(),
);
assert_eq!(
expected_risingwave_struct_array,
actual_risingwave_struct_array
);
}
#[test]
fn list() {
let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);
let arrow = UdfArrowConvert::default()
.list_to_arrow(&data_type, &array)
.unwrap();
let rw_array = UdfArrowConvert::default()
.from_list_array(arrow.as_any().downcast_ref().unwrap())
.unwrap();
assert_eq!(rw_array.as_list(), &array);
}
#[test]
fn map() {
let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
let rw_map_type = DataType::Map(map_type.clone());
let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
builder.append_owned(Some(
MapValue::try_from_kv(
ListValue::from_str("{a,b,c}", &DataType::List(Box::new(DataType::Varchar)))
.unwrap(),
ListValue::from_str("{1,2,3}", &DataType::List(Box::new(DataType::Int32))).unwrap(),
)
.unwrap(),
));
builder.append_owned(None);
builder.append_owned(Some(
MapValue::try_from_kv(
ListValue::from_str("{a,c}", &DataType::List(Box::new(DataType::Varchar))).unwrap(),
ListValue::from_str("{1,3}", &DataType::List(Box::new(DataType::Int32))).unwrap(),
)
.unwrap(),
));
let rw_array = builder.finish();
let arrow_map_type = UdfArrowConvert::default()
.map_type_to_arrow(&map_type)
.unwrap();
expect_test::expect![[r#"
Map(
Field {
name: "entries",
data_type: Struct(
[
Field {
name: "key",
data_type: Utf8,
nullable: false,
dict_id: 0,
dict_is_ordered: false,
metadata: {},
},
Field {
name: "value",
data_type: Int32,
nullable: true,
dict_id: 0,
dict_is_ordered: false,
metadata: {},
},
],
),
nullable: false,
dict_id: 0,
dict_is_ordered: false,
metadata: {},
},
false,
)
"#]]
.assert_debug_eq(&arrow_map_type);
let rw_map_type_new = UdfArrowConvert::default()
.from_field(&arrow_schema::Field::new(
"map",
arrow_map_type.clone(),
true,
))
.unwrap();
assert_eq!(rw_map_type, rw_map_type_new);
let arrow = UdfArrowConvert::default()
.map_to_arrow(&arrow_map_type, &rw_array)
.unwrap();
expect_test::expect![[r#"
MapArray
[
StructArray
[
-- child 0: "key" (Utf8)
StringArray
[
"a",
"b",
"c",
]
-- child 1: "value" (Int32)
PrimitiveArray<Int32>
[
1,
2,
3,
]
],
null,
StructArray
[
-- child 0: "key" (Utf8)
StringArray
[
"a",
"c",
]
-- child 1: "value" (Int32)
PrimitiveArray<Int32>
[
1,
3,
]
],
]
"#]]
.assert_debug_eq(&arrow);
let rw_array_new = UdfArrowConvert::default()
.from_map_array(arrow.as_any().downcast_ref().unwrap())
.unwrap();
assert_eq!(&rw_array, rw_array_new.as_map());
}
}