risingwave_common/array/arrow/
arrow_udf.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This is for arrow dependency named `arrow-xxx` such as `arrow-array` in the cargo workspace.
16//!
//! This should be the default arrow version to be used in our system.
18//!
19//! The corresponding version of arrow is currently used by `udf` and `iceberg` sink.
20
21use std::sync::Arc;
22
23pub use super::arrow_54::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25};
26use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
27
/// Converter between RisingWave arrays and the Arrow arrays exchanged with UDFs.
#[derive(Debug, Default)]
pub struct UdfArrowConvert {
    /// Selects the legacy wire representation for decimal and jsonb values.
    ///
    /// When `true`, decimal is carried as Arrow `LargeBinary` and jsonb as Arrow `LargeUtf8`.
    /// When `false` (the default), both are carried as Arrow `Utf8` fields tagged with
    /// extension-type metadata (`arrowudf.decimal` / `arrowudf.json`).
    /// See <https://github.com/risingwavelabs/arrow-udf/tree/main#extension-types>.
    pub legacy: bool,
}
38
39impl ToArrow for UdfArrowConvert {
40    fn decimal_to_arrow(
41        &self,
42        _data_type: &arrow_schema::DataType,
43        array: &DecimalArray,
44    ) -> Result<arrow_array::ArrayRef, ArrayError> {
45        if self.legacy {
46            // Decimal values are stored as ASCII text representation in a large binary array.
47            Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
48        } else {
49            Ok(Arc::new(arrow_array::StringArray::from(array)))
50        }
51    }
52
53    fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
54        if self.legacy {
55            // JSON values are stored as text representation in a large string array.
56            Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
57        } else {
58            Ok(Arc::new(arrow_array::StringArray::from(array)))
59        }
60    }
61
62    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
63        if self.legacy {
64            arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
65        } else {
66            arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
67                .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
68        }
69    }
70
71    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
72        if self.legacy {
73            arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
74        } else {
75            arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
76                .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
77        }
78    }
79}
80
81impl FromArrow for UdfArrowConvert {
82    fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
83        if self.legacy {
84            Ok(DataType::Jsonb)
85        } else {
86            Ok(DataType::Varchar)
87        }
88    }
89
90    fn from_large_binary(&self) -> Result<DataType, ArrayError> {
91        if self.legacy {
92            Ok(DataType::Decimal)
93        } else {
94            Ok(DataType::Bytea)
95        }
96    }
97
98    fn from_large_utf8_array(
99        &self,
100        array: &arrow_array::LargeStringArray,
101    ) -> Result<ArrayImpl, ArrayError> {
102        if self.legacy {
103            Ok(ArrayImpl::Jsonb(array.try_into()?))
104        } else {
105            Ok(ArrayImpl::Utf8(array.into()))
106        }
107    }
108
109    fn from_large_binary_array(
110        &self,
111        array: &arrow_array::LargeBinaryArray,
112    ) -> Result<ArrayImpl, ArrayError> {
113        if self.legacy {
114            Ok(ArrayImpl::Decimal(array.try_into()?))
115        } else {
116            Ok(ArrayImpl::Bytea(array.into()))
117        }
118    }
119}
120
121#[cfg(test)]
122mod tests {
123
124    use super::*;
125    use crate::array::*;
126
127    #[test]
128    fn struct_array() {
129        // Empty array - risingwave to arrow conversion.
130        let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
131        assert_eq!(
132            UdfArrowConvert::default()
133                .struct_to_arrow(
134                    &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
135                    &test_arr
136                )
137                .unwrap()
138                .len(),
139            0
140        );
141
142        // Empty array - arrow to risingwave conversion.
143        let test_arr_2 = arrow_array::StructArray::from(vec![]);
144        assert_eq!(
145            UdfArrowConvert::default()
146                .from_struct_array(&test_arr_2)
147                .unwrap()
148                .len(),
149            0
150        );
151
152        // Struct array with primitive types. arrow to risingwave conversion.
153        let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![
154            (
155                "a",
156                Arc::new(arrow_array::BooleanArray::from(vec![
157                    Some(false),
158                    Some(false),
159                    Some(true),
160                    None,
161                ])) as arrow_array::ArrayRef,
162            ),
163            (
164                "b",
165                Arc::new(arrow_array::Int32Array::from(vec![
166                    Some(42),
167                    Some(28),
168                    Some(19),
169                    None,
170                ])) as arrow_array::ArrayRef,
171            ),
172        ])
173        .unwrap();
174        let actual_risingwave_struct_array = UdfArrowConvert::default()
175            .from_struct_array(&test_arrow_struct_array)
176            .unwrap()
177            .into_struct();
178        let expected_risingwave_struct_array = StructArray::new(
179            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
180            vec![
181                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
182                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
183            ],
184            [true, true, true, true].into_iter().collect(),
185        );
186        assert_eq!(
187            expected_risingwave_struct_array,
188            actual_risingwave_struct_array
189        );
190    }
191
192    #[test]
193    fn list() {
194        let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
195        let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);
196        let arrow = UdfArrowConvert::default()
197            .list_to_arrow(&data_type, &array)
198            .unwrap();
199        let rw_array = UdfArrowConvert::default()
200            .from_list_array(arrow.as_any().downcast_ref().unwrap())
201            .unwrap();
202        assert_eq!(rw_array.as_list(), &array);
203    }
204
205    #[test]
206    fn map() {
207        let map_type = MapType::from_kv(DataType::Varchar, DataType::Int32);
208        let rw_map_type = DataType::Map(map_type.clone());
209        let mut builder = MapArrayBuilder::with_type(3, rw_map_type.clone());
210        builder.append_owned(Some(
211            MapValue::try_from_kv(
212                ListValue::from_str("{a,b,c}", &DataType::List(Box::new(DataType::Varchar)))
213                    .unwrap(),
214                ListValue::from_str("{1,2,3}", &DataType::List(Box::new(DataType::Int32))).unwrap(),
215            )
216            .unwrap(),
217        ));
218        builder.append_owned(None);
219        builder.append_owned(Some(
220            MapValue::try_from_kv(
221                ListValue::from_str("{a,c}", &DataType::List(Box::new(DataType::Varchar))).unwrap(),
222                ListValue::from_str("{1,3}", &DataType::List(Box::new(DataType::Int32))).unwrap(),
223            )
224            .unwrap(),
225        ));
226        let rw_array = builder.finish();
227
228        let arrow_map_type = UdfArrowConvert::default()
229            .map_type_to_arrow(&map_type)
230            .unwrap();
231        expect_test::expect![[r#"
232            Map(
233                Field {
234                    name: "entries",
235                    data_type: Struct(
236                        [
237                            Field {
238                                name: "key",
239                                data_type: Utf8,
240                                nullable: false,
241                                dict_id: 0,
242                                dict_is_ordered: false,
243                                metadata: {},
244                            },
245                            Field {
246                                name: "value",
247                                data_type: Int32,
248                                nullable: true,
249                                dict_id: 0,
250                                dict_is_ordered: false,
251                                metadata: {},
252                            },
253                        ],
254                    ),
255                    nullable: false,
256                    dict_id: 0,
257                    dict_is_ordered: false,
258                    metadata: {},
259                },
260                false,
261            )
262        "#]]
263        .assert_debug_eq(&arrow_map_type);
264        let rw_map_type_new = UdfArrowConvert::default()
265            .from_field(&arrow_schema::Field::new(
266                "map",
267                arrow_map_type.clone(),
268                true,
269            ))
270            .unwrap();
271        assert_eq!(rw_map_type, rw_map_type_new);
272        let arrow = UdfArrowConvert::default()
273            .map_to_arrow(&arrow_map_type, &rw_array)
274            .unwrap();
275        expect_test::expect![[r#"
276            MapArray
277            [
278              StructArray
279            -- validity:
280            [
281              valid,
282              valid,
283              valid,
284            ]
285            [
286            -- child 0: "key" (Utf8)
287            StringArray
288            [
289              "a",
290              "b",
291              "c",
292            ]
293            -- child 1: "value" (Int32)
294            PrimitiveArray<Int32>
295            [
296              1,
297              2,
298              3,
299            ]
300            ],
301              null,
302              StructArray
303            -- validity:
304            [
305              valid,
306              valid,
307            ]
308            [
309            -- child 0: "key" (Utf8)
310            StringArray
311            [
312              "a",
313              "c",
314            ]
315            -- child 1: "value" (Int32)
316            PrimitiveArray<Int32>
317            [
318              1,
319              3,
320            ]
321            ],
322            ]"#]]
323        .assert_eq(
324            &format!("{:#?}", arrow)
325                .lines()
326                .map(|s| s.trim_end())
327                .collect::<Vec<_>>()
328                .join("\n"),
329        );
330
331        let rw_array_new = UdfArrowConvert::default()
332            .from_map_array(arrow.as_any().downcast_ref().unwrap())
333            .unwrap();
334        assert_eq!(&rw_array, rw_array_new.as_map());
335    }
336}