risingwave_common/array/arrow/
arrow_udf.rs1use std::sync::Arc;
22
23pub use super::arrow_54::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25};
26use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
27
/// Conversion settings for exchanging arrays with Arrow.
///
/// The type carries no data besides the [`UdfArrowConvert::legacy`] flag, so it is
/// cheap to copy and compare; the actual conversions live in its `ToArrow` and
/// `FromArrow` implementations.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdfArrowConvert {
    /// Selects the legacy mapping for decimal and jsonb values.
    ///
    /// * `true`: decimal is exchanged as Arrow `LargeBinary` and jsonb as Arrow
    ///   `LargeUtf8`; when reading, those two Arrow types are interpreted as
    ///   decimal and jsonb respectively.
    /// * `false` (the `Default`): decimal and jsonb are both exchanged as Arrow
    ///   `Utf8` whose field carries `ARROW:extension:name` metadata
    ///   (`arrowudf.decimal` / `arrowudf.json`); when reading, `LargeUtf8` is
    ///   treated as varchar and `LargeBinary` as bytea.
    pub legacy: bool,
}
38
39impl ToArrow for UdfArrowConvert {
40 fn decimal_to_arrow(
41 &self,
42 _data_type: &arrow_schema::DataType,
43 array: &DecimalArray,
44 ) -> Result<arrow_array::ArrayRef, ArrayError> {
45 if self.legacy {
46 Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
48 } else {
49 Ok(Arc::new(arrow_array::StringArray::from(array)))
50 }
51 }
52
53 fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
54 if self.legacy {
55 Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
57 } else {
58 Ok(Arc::new(arrow_array::StringArray::from(array)))
59 }
60 }
61
62 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
63 if self.legacy {
64 arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
65 } else {
66 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
67 .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
68 }
69 }
70
71 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
72 if self.legacy {
73 arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
74 } else {
75 arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
76 .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
77 }
78 }
79}
80
81impl FromArrow for UdfArrowConvert {
82 fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
83 if self.legacy {
84 Ok(DataType::Jsonb)
85 } else {
86 Ok(DataType::Varchar)
87 }
88 }
89
90 fn from_large_binary(&self) -> Result<DataType, ArrayError> {
91 if self.legacy {
92 Ok(DataType::Decimal)
93 } else {
94 Ok(DataType::Bytea)
95 }
96 }
97
98 fn from_large_utf8_array(
99 &self,
100 array: &arrow_array::LargeStringArray,
101 ) -> Result<ArrayImpl, ArrayError> {
102 if self.legacy {
103 Ok(ArrayImpl::Jsonb(array.try_into()?))
104 } else {
105 Ok(ArrayImpl::Utf8(array.into()))
106 }
107 }
108
109 fn from_large_binary_array(
110 &self,
111 array: &arrow_array::LargeBinaryArray,
112 ) -> Result<ArrayImpl, ArrayError> {
113 if self.legacy {
114 Ok(ArrayImpl::Decimal(array.try_into()?))
115 } else {
116 Ok(ArrayImpl::Bytea(array.into()))
117 }
118 }
119}
120
#[cfg(test)]
mod tests {

    use super::*;
    use crate::array::*;

    /// Struct arrays convert in both directions, including the empty case.
    #[test]
    fn struct_array() {
        let converter = UdfArrowConvert::default();

        // RisingWave -> Arrow: a struct array with no fields and no rows stays empty.
        let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
        let empty_struct_type = arrow_schema::DataType::Struct(arrow_schema::Fields::empty());
        let arrow_len = converter
            .struct_to_arrow(&empty_struct_type, &test_arr)
            .unwrap()
            .len();
        assert_eq!(arrow_len, 0);

        // Arrow -> RisingWave: same for an empty Arrow struct array.
        let test_arr_2 = arrow_array::StructArray::from(vec![]);
        let rw_len = converter.from_struct_array(&test_arr_2).unwrap().len();
        assert_eq!(rw_len, 0);

        // Arrow -> RisingWave with two nullable children, each ending in a null.
        let bool_child = Arc::new(arrow_array::BooleanArray::from(vec![
            Some(false),
            Some(false),
            Some(true),
            None,
        ])) as arrow_array::ArrayRef;
        let int_child = Arc::new(arrow_array::Int32Array::from(vec![
            Some(42),
            Some(28),
            Some(19),
            None,
        ])) as arrow_array::ArrayRef;
        let test_arrow_struct_array =
            arrow_array::StructArray::try_from(vec![("a", bool_child), ("b", int_child)]).unwrap();

        let actual_risingwave_struct_array = converter
            .from_struct_array(&test_arrow_struct_array)
            .unwrap()
            .into_struct();
        let expected_risingwave_struct_array = StructArray::new(
            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
            vec![
                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
            ],
            // Every struct row is valid; only the child values are null.
            [true, true, true, true].into_iter().collect(),
        );
        assert_eq!(
            expected_risingwave_struct_array,
            actual_risingwave_struct_array
        );
    }

    /// A list array (null row, populated row, empty row) survives a round trip.
    #[test]
    fn list() {
        let converter = UdfArrowConvert::default();
        let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
        let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);

        let arrow = converter.list_to_arrow(&data_type, &array).unwrap();
        let rw_array = converter
            .from_list_array(arrow.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(rw_array.as_list(), &array);
    }

    /// A map array survives a round trip, and its Arrow type and Arrow debug
    /// rendering match the recorded snapshots.
    #[test]
    fn map() {
        let converter = UdfArrowConvert::default();

        // Three rows: {a:1,b:2,c:3}, NULL, {a:1,c:3}.
        let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
        let rw_map_type = DataType::Map(map_type.clone());
        let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
        builder.append_owned(Some(
            MapValue::try_from_kv(
                ListValue::from_str("{a,b,c}", &DataType::Varchar.list()).unwrap(),
                ListValue::from_str("{1,2,3}", &DataType::Int32.list()).unwrap(),
            )
            .unwrap(),
        ));
        builder.append_owned(None);
        builder.append_owned(Some(
            MapValue::try_from_kv(
                ListValue::from_str("{a,c}", &DataType::Varchar.list()).unwrap(),
                ListValue::from_str("{1,3}", &DataType::Int32.list()).unwrap(),
            )
            .unwrap(),
        ));
        let rw_array = builder.finish();

        // NOTE(review): the indentation inside both snapshots below was rebuilt from
        // the standard `{:#?}` / arrow `Debug` layout; if an assertion fails on
        // whitespace only, regenerate with `UPDATE_EXPECT=1`.
        let arrow_map_type = converter.map_type_to_arrow(&map_type).unwrap();
        expect_test::expect![[r#"
            Map(
                Field {
                    name: "entries",
                    data_type: Struct(
                        [
                            Field {
                                name: "key",
                                data_type: Utf8,
                                nullable: false,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                            Field {
                                name: "value",
                                data_type: Int32,
                                nullable: true,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                        ],
                    ),
                    nullable: false,
                    dict_id: 0,
                    dict_is_ordered: false,
                    metadata: {},
                },
                false,
            )
        "#]]
        .assert_debug_eq(&arrow_map_type);

        // The Arrow map type maps back to the original RisingWave map type.
        let rw_map_type_new = converter
            .from_field(&arrow_schema::Field::new(
                "map",
                arrow_map_type.clone(),
                true,
            ))
            .unwrap();
        assert_eq!(rw_map_type, rw_map_type_new);

        let arrow = converter.map_to_arrow(&arrow_map_type, &rw_array).unwrap();
        // Trailing whitespace is stripped from every rendered line before comparing.
        let rendered = format!("{arrow:#?}")
            .lines()
            .map(|s| s.trim_end())
            .collect::<Vec<_>>()
            .join("\n");
        expect_test::expect![[r#"
            MapArray
            [
              StructArray
            -- validity:
            [
              valid,
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "b",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              2,
              3,
            ]
            ],
              null,
              StructArray
            -- validity:
            [
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              3,
            ]
            ],
            ]"#]]
        .assert_eq(&rendered);

        // Arrow -> RisingWave restores the original map array.
        let rw_array_new = converter
            .from_map_array(arrow.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(&rw_array, rw_array_new.as_map());
    }
}