risingwave_common/array/arrow/
arrow_udf.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This is for arrow dependency named `arrow-xxx` such as `arrow-array` in the cargo workspace.
16//!
//! This should be the default arrow version to be used in our system.
18//!
19//! The corresponding version of arrow is currently used by `udf` and `iceberg` sink.
20
21use std::sync::Arc;
22
23pub use super::arrow_54::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25};
26use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
27
/// Converter used to translate arrays to and from Arrow for UDF calls.
///
/// Its only setting, [`UdfArrowConvert::legacy`], selects how decimal and jsonb values are
/// represented on the Arrow side.
#[derive(Debug, Default)]
pub struct UdfArrowConvert {
    /// Selects the legacy representation when talking to the UDF.
    ///
    /// * `true`: decimal is mapped to Arrow `LargeBinary` and jsonb to Arrow `LargeUtf8`.
    /// * `false` (the value produced by `Default`): both are mapped to Arrow extension types.
    ///
    /// See <https://github.com/risingwavelabs/arrow-udf/tree/main#extension-types>.
    pub legacy: bool,
}
38
39impl ToArrow for UdfArrowConvert {
40    fn decimal_to_arrow(
41        &self,
42        _data_type: &arrow_schema::DataType,
43        array: &DecimalArray,
44    ) -> Result<arrow_array::ArrayRef, ArrayError> {
45        if self.legacy {
46            // Decimal values are stored as ASCII text representation in a large binary array.
47            Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
48        } else {
49            Ok(Arc::new(arrow_array::StringArray::from(array)))
50        }
51    }
52
53    fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
54        if self.legacy {
55            // JSON values are stored as text representation in a large string array.
56            Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
57        } else {
58            Ok(Arc::new(arrow_array::StringArray::from(array)))
59        }
60    }
61
62    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
63        if self.legacy {
64            arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
65        } else {
66            arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
67                .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
68        }
69    }
70
71    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
72        if self.legacy {
73            arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
74        } else {
75            arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
76                .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
77        }
78    }
79}
80
81impl FromArrow for UdfArrowConvert {
82    fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
83        if self.legacy {
84            Ok(DataType::Jsonb)
85        } else {
86            Ok(DataType::Varchar)
87        }
88    }
89
90    fn from_large_binary(&self) -> Result<DataType, ArrayError> {
91        if self.legacy {
92            Ok(DataType::Decimal)
93        } else {
94            Ok(DataType::Bytea)
95        }
96    }
97
98    fn from_large_utf8_array(
99        &self,
100        array: &arrow_array::LargeStringArray,
101    ) -> Result<ArrayImpl, ArrayError> {
102        if self.legacy {
103            Ok(ArrayImpl::Jsonb(array.try_into()?))
104        } else {
105            Ok(ArrayImpl::Utf8(array.into()))
106        }
107    }
108
109    fn from_large_binary_array(
110        &self,
111        array: &arrow_array::LargeBinaryArray,
112    ) -> Result<ArrayImpl, ArrayError> {
113        if self.legacy {
114            Ok(ArrayImpl::Decimal(array.try_into()?))
115        } else {
116            Ok(ArrayImpl::Bytea(array.into()))
117        }
118    }
119}
120
#[cfg(test)]
mod tests {

    use super::*;
    use crate::array::*;

    /// Struct arrays: empty arrays in both directions, then a two-column Arrow struct array
    /// converted to a RisingWave struct array.
    #[test]
    fn struct_array() {
        let convert = UdfArrowConvert::default();

        // RisingWave -> Arrow: an empty struct array converts to an empty array.
        let rw_empty = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
        let arrow_empty_type = arrow_schema::DataType::Struct(arrow_schema::Fields::empty());
        let arrow_from_rw = convert
            .struct_to_arrow(&arrow_empty_type, &rw_empty)
            .unwrap();
        assert_eq!(arrow_from_rw.len(), 0);

        // Arrow -> RisingWave: an empty struct array converts to an empty array.
        let arrow_empty = arrow_array::StructArray::from(vec![]);
        let rw_from_arrow = convert.from_struct_array(&arrow_empty).unwrap();
        assert_eq!(rw_from_arrow.len(), 0);

        // Arrow -> RisingWave: struct array with primitive children, each ending in a null.
        let bool_child = Arc::new(arrow_array::BooleanArray::from(vec![
            Some(false),
            Some(false),
            Some(true),
            None,
        ])) as arrow_array::ArrayRef;
        let int_child = Arc::new(arrow_array::Int32Array::from(vec![
            Some(42),
            Some(28),
            Some(19),
            None,
        ])) as arrow_array::ArrayRef;
        let arrow_struct =
            arrow_array::StructArray::try_from(vec![("a", bool_child), ("b", int_child)]).unwrap();

        let actual = convert
            .from_struct_array(&arrow_struct)
            .unwrap()
            .into_struct();
        let expected = StructArray::new(
            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
            vec![
                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
            ],
            // Every struct row itself is non-null.
            [true; 4].into_iter().collect(),
        );
        assert_eq!(expected, actual);
    }

    /// List arrays: RisingWave -> Arrow -> RisingWave must give back the original array.
    #[test]
    fn list() {
        let convert = UdfArrowConvert::default();

        // Covers a null list, a populated list and an empty list.
        let rw_list = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
        let arrow_list_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);

        let arrow_list = convert.list_to_arrow(&arrow_list_type, &rw_list).unwrap();
        let round_tripped = convert
            .from_list_array(arrow_list.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(round_tripped.as_list(), &rw_list);
    }

    /// Map arrays: checks the Arrow type and array produced for a `varchar -> int32` map against
    /// snapshots, and checks that both the type and the array survive a round trip.
    #[test]
    fn map() {
        let convert = UdfArrowConvert::default();

        let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
        let rw_map_type = DataType::Map(map_type.clone());

        // Parses the textual key list and value list into one map value.
        let map_value = |keys: &str, values: &str| {
            MapValue::try_from_kv(
                ListValue::from_str(keys, &DataType::Varchar.list()).unwrap(),
                ListValue::from_str(values, &DataType::Int32.list()).unwrap(),
            )
            .unwrap()
        };

        // Three rows: a three-entry map, a null, and a two-entry map.
        let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
        for row in [
            Some(map_value("{a,b,c}", "{1,2,3}")),
            None,
            Some(map_value("{a,c}", "{1,3}")),
        ] {
            builder.append_owned(row);
        }
        let rw_array = builder.finish();

        // Type conversion: RisingWave -> Arrow.
        let arrow_map_type = convert.map_type_to_arrow(&map_type).unwrap();
        expect_test::expect![[r#"
            Map(
                Field {
                    name: "entries",
                    data_type: Struct(
                        [
                            Field {
                                name: "key",
                                data_type: Utf8,
                                nullable: false,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                            Field {
                                name: "value",
                                data_type: Int32,
                                nullable: true,
                                dict_id: 0,
                                dict_is_ordered: false,
                                metadata: {},
                            },
                        ],
                    ),
                    nullable: false,
                    dict_id: 0,
                    dict_is_ordered: false,
                    metadata: {},
                },
                false,
            )
        "#]]
        .assert_debug_eq(&arrow_map_type);

        // Type conversion: Arrow -> RisingWave gives back the original map type.
        let arrow_map_field = arrow_schema::Field::new("map", arrow_map_type.clone(), true);
        let rw_map_type_new = convert.from_field(&arrow_map_field).unwrap();
        assert_eq!(rw_map_type, rw_map_type_new);

        // Array conversion: RisingWave -> Arrow.
        let arrow = convert.map_to_arrow(&arrow_map_type, &rw_array).unwrap();
        // Trailing whitespace is trimmed from every line of the pretty debug output before it is
        // compared with the snapshot.
        let rendered = format!("{:#?}", arrow)
            .lines()
            .map(str::trim_end)
            .collect::<Vec<_>>()
            .join("\n");
        expect_test::expect![[r#"
            MapArray
            [
              StructArray
            -- validity:
            [
              valid,
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "b",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              2,
              3,
            ]
            ],
              null,
              StructArray
            -- validity:
            [
              valid,
              valid,
            ]
            [
            -- child 0: "key" (Utf8)
            StringArray
            [
              "a",
              "c",
            ]
            -- child 1: "value" (Int32)
            PrimitiveArray<Int32>
            [
              1,
              3,
            ]
            ],
            ]"#]]
        .assert_eq(&rendered);

        // Array conversion: Arrow -> RisingWave gives back the original array.
        let rw_array_new = convert
            .from_map_array(arrow.as_any().downcast_ref().unwrap())
            .unwrap();
        assert_eq!(&rw_array, rw_array_new.as_map());
    }
}