risingwave_common/array/arrow/
arrow_udf.rs1use std::sync::Arc;
22
23pub use super::arrow_58::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25};
26use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
27
/// Converts between RisingWave arrays and the Arrow representation exchanged with UDFs.
///
/// The only setting is [`legacy`](Self::legacy), which selects how `Decimal` and `Jsonb`
/// values are laid out on the Arrow side.
#[derive(Debug, Default)]
pub struct UdfArrowConvert {
    /// Selects the legacy Arrow encoding for `Decimal` and `Jsonb`.
    ///
    /// * `true`: `Decimal` <-> Arrow `LargeBinary` and `Jsonb` <-> Arrow `LargeUtf8`.
    /// * `false` (the `Default`): both are written as Arrow `Utf8` fields tagged with the
    ///   `ARROW:extension:name` metadata key (`arrowudf.decimal` / `arrowudf.json`), and an
    ///   incoming Arrow `LargeUtf8` / `LargeBinary` is read back as `Varchar` / `Bytea`.
    pub legacy: bool,
}
38
39impl ToArrow for UdfArrowConvert {
40 fn decimal_to_arrow(
41 &self,
42 _data_type: &arrow_schema::DataType,
43 array: &DecimalArray,
44 ) -> Result<arrow_array::ArrayRef, ArrayError> {
45 if self.legacy {
46 Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
48 } else {
49 Ok(Arc::new(arrow_array::StringArray::from(array)))
50 }
51 }
52
53 fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
54 if self.legacy {
55 Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
57 } else {
58 Ok(Arc::new(arrow_array::StringArray::from(array)))
59 }
60 }
61
62 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
63 if self.legacy {
64 arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
65 } else {
66 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
67 .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
68 }
69 }
70
71 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
72 if self.legacy {
73 arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
74 } else {
75 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
76 .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
77 }
78 }
79}
80
81impl FromArrow for UdfArrowConvert {
82 fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
83 if self.legacy {
84 Ok(DataType::Jsonb)
85 } else {
86 Ok(DataType::Varchar)
87 }
88 }
89
90 fn from_large_binary(&self) -> Result<DataType, ArrayError> {
91 if self.legacy {
92 Ok(DataType::Decimal)
93 } else {
94 Ok(DataType::Bytea)
95 }
96 }
97
98 fn from_large_utf8_array(
99 &self,
100 array: &arrow_array::LargeStringArray,
101 ) -> Result<ArrayImpl, ArrayError> {
102 if self.legacy {
103 Ok(ArrayImpl::Jsonb(array.try_into()?))
104 } else {
105 Ok(ArrayImpl::Utf8(array.into()))
106 }
107 }
108
109 fn from_large_binary_array(
110 &self,
111 array: &arrow_array::LargeBinaryArray,
112 ) -> Result<ArrayImpl, ArrayError> {
113 if self.legacy {
114 Ok(ArrayImpl::Decimal(array.try_into()?))
115 } else {
116 Ok(ArrayImpl::Bytea(array.into()))
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::*;

    /// Converts struct arrays between RisingWave and Arrow with the default (non-legacy)
    /// converter, covering empty arrays in both directions and a populated Arrow array.
    #[test]
    fn struct_array() {
        // Empty RisingWave struct array -> Arrow: the converted array must also be empty.
        let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
        assert_eq!(
            UdfArrowConvert::default()
                .struct_to_arrow(
                    &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
                    &test_arr
                )
                .unwrap()
                .len(),
            0
        );

        // Empty Arrow struct array (no fields, zero rows) -> RisingWave: also empty.
        let test_arr_2 = arrow_array::StructArray::new_empty_fields(0, None);
        assert_eq!(
            UdfArrowConvert::default()
                .from_struct_array(&test_arr_2)
                .unwrap()
                .len(),
            0
        );

        // Populated Arrow struct array with a boolean column "a" and an int32 column "b".
        // In both children the last row is null.
        let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![
            (
                "a",
                Arc::new(arrow_array::BooleanArray::from(vec![
                    Some(false),
                    Some(false),
                    Some(true),
                    None,
                ])) as arrow_array::ArrayRef,
            ),
            (
                "b",
                Arc::new(arrow_array::Int32Array::from(vec![
                    Some(42),
                    Some(28),
                    Some(19),
                    None,
                ])) as arrow_array::ArrayRef,
            ),
        ])
        .unwrap();
        let actual_risingwave_struct_array = UdfArrowConvert::default()
            .from_struct_array(&test_arrow_struct_array)
            .unwrap()
            .into_struct();
        // The nulls stay in the child columns. The struct-level bitmap is expected to be
        // all-valid.
        let expected_risingwave_struct_array = StructArray::new(
            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
            vec![
                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
            ],
            [true, true, true, true].into_iter().collect(),
        );
        assert_eq!(
            expected_risingwave_struct_array,
            actual_risingwave_struct_array
        );
    }

    /// Round-trips a list array through Arrow and back.
    ///
    /// The array contains a null row, a non-empty list and an empty list (`vec![0; 0]`).
    #[test]
    fn list() {
        let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
        let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);
        let arrow = UdfArrowConvert::default()
            .list_to_arrow(&data_type, &array)
            .unwrap();
        // `from_list_array` takes the concrete Arrow list type, hence the downcast.
        let rw_array = UdfArrowConvert::default()
            .from_list_array(arrow.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(rw_array.as_list(), &array);
    }

    /// Round-trips a `Varchar -> Int32` map type and a 3-row map array (middle row null)
    /// through Arrow. The intermediate Arrow type and array are pinned with snapshots.
    #[test]
    fn map() {
        let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
        let rw_map_type = DataType::Map(map_type.clone());
        let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
        builder.append_owned(Some(
            MapValue::try_from_kv(
                ListValue::from_str("{a,b,c}", &DataType::Varchar.list()).unwrap(),
                ListValue::from_str("{1,2,3}", &DataType::Int32.list()).unwrap(),
            )
            .unwrap(),
        ));
        builder.append_owned(None);
        builder.append_owned(Some(
            MapValue::try_from_kv(
                ListValue::from_str("{a,c}", &DataType::Varchar.list()).unwrap(),
                ListValue::from_str("{1,3}", &DataType::Int32.list()).unwrap(),
            )
            .unwrap(),
        ));
        let rw_array = builder.finish();

        // Type conversion: RisingWave map type -> Arrow `Map` type (snapshot of its Debug).
        let arrow_map_type = UdfArrowConvert::default()
            .map_type_to_arrow(&map_type)
            .unwrap();
        expect_test::expect![[r#"
            Map(
                Field {
                    name: "entries",
                    data_type: Struct(
                        [
                            Field {
                                name: "key",
                                data_type: Utf8,
                            },
                            Field {
                                name: "value",
                                data_type: Int32,
                                nullable: true,
                            },
                        ],
                    ),
                },
                false,
            )
        "#]]
        .assert_debug_eq(&arrow_map_type);
        // ...and back: reading the Arrow field must yield the original RisingWave type.
        let rw_map_type_new = UdfArrowConvert::default()
            .from_field(&arrow_schema::Field::new(
                "map",
                arrow_map_type.clone(),
                true,
            ))
            .unwrap();
        assert_eq!(rw_map_type, rw_map_type_new);
        // Array conversion: RisingWave map array -> Arrow map array.
        let arrow = UdfArrowConvert::default()
            .map_to_arrow(&arrow_map_type, &rw_array)
            .unwrap();
        // Snapshot of arrow's Debug rendering. Trailing whitespace is trimmed from every
        // line before comparison, so only the visible text has to match.
        expect_test::expect![[r#"
            MapArray
            [
              StructArray
            -- validity:
            [
              valid,
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "b",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              2,
              3,
            ]
            ],
              null,
              StructArray
            -- validity:
            [
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              3,
            ]
            ],
            ]"#]]
        .assert_eq(
            &format!("{:#?}", arrow)
                .lines()
                .map(|s| s.trim_end())
                .collect::<Vec<_>>()
                .join("\n"),
        );

        // ...and back: the Arrow map array must convert to an equal RisingWave map array.
        let rw_array_new = UdfArrowConvert::default()
            .from_map_array(arrow.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(&rw_array, rw_array_new.as_map());
    }
}
324}