risingwave_common/array/arrow/
arrow_udf.rs1use std::sync::Arc;
22
23pub use super::arrow_54::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25};
26use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
27
/// Converter between RisingWave arrays and the Arrow representation exchanged with UDFs.
///
/// The default value uses the non-legacy encoding (`legacy == false`).
#[derive(Debug, Default)]
pub struct UdfArrowConvert {
    /// Selects the wire encoding used for decimal and JSONB values.
    ///
    /// * `true`  - decimals travel as Arrow `LargeBinary`, JSONB as Arrow `LargeUtf8`.
    /// * `false` - both travel as Arrow `Utf8`, and the field carries an
    ///   `ARROW:extension:name` metadata entry (`arrowudf.decimal` / `arrowudf.json`).
    pub legacy: bool,
}
38
39impl ToArrow for UdfArrowConvert {
40 fn decimal_to_arrow(
41 &self,
42 _data_type: &arrow_schema::DataType,
43 array: &DecimalArray,
44 ) -> Result<arrow_array::ArrayRef, ArrayError> {
45 if self.legacy {
46 Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
48 } else {
49 Ok(Arc::new(arrow_array::StringArray::from(array)))
50 }
51 }
52
53 fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
54 if self.legacy {
55 Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
57 } else {
58 Ok(Arc::new(arrow_array::StringArray::from(array)))
59 }
60 }
61
62 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
63 if self.legacy {
64 arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
65 } else {
66 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
67 .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
68 }
69 }
70
71 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
72 if self.legacy {
73 arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
74 } else {
75 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
76 .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
77 }
78 }
79}
80
81impl FromArrow for UdfArrowConvert {
82 fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
83 if self.legacy {
84 Ok(DataType::Jsonb)
85 } else {
86 Ok(DataType::Varchar)
87 }
88 }
89
90 fn from_large_binary(&self) -> Result<DataType, ArrayError> {
91 if self.legacy {
92 Ok(DataType::Decimal)
93 } else {
94 Ok(DataType::Bytea)
95 }
96 }
97
98 fn from_large_utf8_array(
99 &self,
100 array: &arrow_array::LargeStringArray,
101 ) -> Result<ArrayImpl, ArrayError> {
102 if self.legacy {
103 Ok(ArrayImpl::Jsonb(array.try_into()?))
104 } else {
105 Ok(ArrayImpl::Utf8(array.into()))
106 }
107 }
108
109 fn from_large_binary_array(
110 &self,
111 array: &arrow_array::LargeBinaryArray,
112 ) -> Result<ArrayImpl, ArrayError> {
113 if self.legacy {
114 Ok(ArrayImpl::Decimal(array.try_into()?))
115 } else {
116 Ok(ArrayImpl::Bytea(array.into()))
117 }
118 }
119}
120
#[cfg(test)]
mod tests {

    use super::*;
    use crate::array::*;

    /// Converts empty struct arrays in both directions, then decodes a
    /// two-column Arrow struct whose last row is null in every column.
    #[test]
    fn struct_array() {
        let converter = UdfArrowConvert::default();

        // RisingWave -> Arrow: no fields, no rows.
        let empty_rw = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
        let empty_struct_type = arrow_schema::DataType::Struct(arrow_schema::Fields::empty());
        let to_arrow = converter
            .struct_to_arrow(&empty_struct_type, &empty_rw)
            .unwrap();
        assert_eq!(to_arrow.len(), 0);

        // Arrow -> RisingWave: no fields, no rows.
        let empty_arrow = arrow_array::StructArray::from(vec![]);
        let from_arrow = converter.from_struct_array(&empty_arrow).unwrap();
        assert_eq!(from_arrow.len(), 0);

        // Arrow -> RisingWave with data.
        let column_a: arrow_array::ArrayRef = Arc::new(arrow_array::BooleanArray::from(vec![
            Some(false),
            Some(false),
            Some(true),
            None,
        ]));
        let column_b: arrow_array::ArrayRef = Arc::new(arrow_array::Int32Array::from(vec![
            Some(42),
            Some(28),
            Some(19),
            None,
        ]));
        let arrow_input =
            arrow_array::StructArray::try_from(vec![("a", column_a), ("b", column_b)]).unwrap();
        let actual = converter
            .from_struct_array(&arrow_input)
            .unwrap()
            .into_struct();

        let expected = StructArray::new(
            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
            vec![
                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
            ],
            // Struct-level validity: all four rows are present (nulls live in the children).
            [true; 4].into_iter().collect(),
        );
        assert_eq!(expected, actual);
    }

    /// A list array containing a null, a non-empty list and an empty list
    /// survives a RisingWave -> Arrow -> RisingWave round trip.
    #[test]
    fn list() {
        let converter = UdfArrowConvert::default();
        let original = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![])]);
        let arrow_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);

        let arrow_list = converter.list_to_arrow(&arrow_type, &original).unwrap();
        let downcast: &arrow_array::ListArray = arrow_list.as_any().downcast_ref().unwrap();
        let round_tripped = converter.from_list_array(downcast).unwrap();

        assert_eq!(round_tripped.as_list(), &original);
    }

    /// Checks the Arrow type produced for `map(varchar, int)`, the Arrow array
    /// produced for a small map column, and that both convert back unchanged.
    #[test]
    fn map() {
        let converter = UdfArrowConvert::default();
        let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
        let rw_map_type = DataType::Map(map_type.clone());

        let key_list_type = DataType::List(Box::new(DataType::Varchar));
        let value_list_type = DataType::List(Box::new(DataType::Int32));
        // Builds one map value from textual key and value lists.
        let make_map = |keys: &str, values: &str| {
            MapValue::try_from_kv(
                ListValue::from_str(keys, &key_list_type).unwrap(),
                ListValue::from_str(values, &value_list_type).unwrap(),
            )
            .unwrap()
        };

        let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
        builder.append_owned(Some(make_map("{a,b,c}", "{1,2,3}")));
        builder.append_owned(None);
        builder.append_owned(Some(make_map("{a,c}", "{1,3}")));
        let rw_array = builder.finish();

        // Type conversion: RisingWave -> Arrow.
        let arrow_map_type = converter.map_type_to_arrow(&map_type).unwrap();
        expect_test::expect![[r#"
            Map(
                Field {
                    name: "entries",
                    data_type: Struct(
                        [
                            Field {
                                name: "key",
                                data_type: Utf8,
                                nullable: false,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                            Field {
                                name: "value",
                                data_type: Int32,
                                nullable: true,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                        ],
                    ),
                    nullable: false,
                    dict_id: 0,
                    dict_is_ordered: false,
                    metadata: {},
                },
                false,
            )
        "#]]
        .assert_debug_eq(&arrow_map_type);

        // Type conversion: Arrow -> RisingWave.
        let map_field = arrow_schema::Field::new("map", arrow_map_type.clone(), true);
        let rw_map_type_new = converter.from_field(&map_field).unwrap();
        assert_eq!(rw_map_type, rw_map_type_new);

        // Array conversion: RisingWave -> Arrow. Trailing whitespace is stripped
        // from each rendered line before comparing against the snapshot.
        let arrow = converter.map_to_arrow(&arrow_map_type, &rw_array).unwrap();
        let rendered = format!("{arrow:#?}")
            .lines()
            .map(str::trim_end)
            .collect::<Vec<_>>()
            .join("\n");
        expect_test::expect![[r#"
            MapArray
            [
              StructArray
            -- validity:
            [
              valid,
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "b",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              2,
              3,
            ]
            ],
              null,
              StructArray
            -- validity:
            [
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              3,
            ]
            ],
            ]"#]]
        .assert_eq(&rendered);

        // Array conversion: Arrow -> RisingWave.
        let arrow_map: &arrow_array::MapArray = arrow.as_any().downcast_ref().unwrap();
        let rw_array_new = converter.from_map_array(arrow_map).unwrap();
        assert_eq!(&rw_array, rw_array_new.as_map());
    }
}