risingwave_common/array/arrow/
arrow_iceberg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::{Arc, LazyLock};
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_56::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25    is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28    Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Converter between RisingWave arrays/types and Arrow arrays/types, used to read from and
/// write to an existing Iceberg table.
///
/// The conversions themselves live in the [`ToArrow`] and [`FromArrow`] implementations for
/// this type; the inherent methods are convenience wrappers around them.
pub struct IcebergArrowConvert;
33
/// Precision of the Arrow `Decimal128` type that RisingWave decimals are mapped to by the
/// converters in this file.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale of the Arrow `Decimal128` type that RisingWave decimals are mapped to by the
/// converters in this file.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
36
37impl IcebergArrowConvert {
38    pub fn to_record_batch(
39        &self,
40        schema: arrow_schema::SchemaRef,
41        chunk: &DataChunk,
42    ) -> Result<arrow_array::RecordBatch, ArrayError> {
43        ToArrow::to_record_batch(self, schema, chunk)
44    }
45
46    pub fn chunk_from_record_batch(
47        &self,
48        batch: &arrow_array::RecordBatch,
49    ) -> Result<DataChunk, ArrayError> {
50        FromArrow::from_record_batch(self, batch)
51    }
52
53    pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
54        FromArrow::from_field(self, field)
55    }
56
57    pub fn to_arrow_field(
58        &self,
59        name: &str,
60        data_type: &DataType,
61    ) -> Result<arrow_schema::Field, ArrayError> {
62        ToArrow::to_arrow_field(self, name, data_type)
63    }
64
65    pub fn struct_from_fields(
66        &self,
67        fields: &arrow_schema::Fields,
68    ) -> Result<StructType, ArrayError> {
69        FromArrow::from_fields(self, fields)
70    }
71
72    pub fn to_arrow_array(
73        &self,
74        data_type: &arrow_schema::DataType,
75        array: &ArrayImpl,
76    ) -> Result<arrow_array::ArrayRef, ArrayError> {
77        ToArrow::to_array(self, data_type, array)
78    }
79
80    pub fn array_from_arrow_array(
81        &self,
82        field: &arrow_schema::Field,
83        array: &arrow_array::ArrayRef,
84    ) -> Result<ArrayImpl, ArrayError> {
85        FromArrow::from_array(self, field, array)
86    }
87
88    /// A helper function to convert an Arrow array to RisingWave array without knowing the field.
89    /// It will use the datatype from arrow array to infer the RisingWave data type.
90    ///
91    /// The difference between this function and `array_from_arrow_array` is that `array_from_arrow_array` will try using `ARROW:extension:name` field metadata to determine the RisingWave data type for extension types.
92    pub fn array_from_arrow_array_raw(
93        &self,
94        array: &arrow_array::ArrayRef,
95    ) -> Result<ArrayImpl, ArrayError> {
96        static FIELD_DUMMY: LazyLock<arrow_schema::Field> =
97            LazyLock::new(|| arrow_schema::Field::new("dummy", arrow_schema::DataType::Null, true));
98        FromArrow::from_array(self, &FIELD_DUMMY, array)
99    }
100}
101
102impl ToArrow for IcebergArrowConvert {
103    fn to_arrow_field(
104        &self,
105        name: &str,
106        data_type: &DataType,
107    ) -> Result<arrow_schema::Field, ArrayError> {
108        let data_type = match data_type {
109            DataType::Boolean => self.bool_type_to_arrow(),
110            DataType::Int16 => self.int32_type_to_arrow(),
111            DataType::Int32 => self.int32_type_to_arrow(),
112            DataType::Int64 => self.int64_type_to_arrow(),
113            DataType::Int256 => self.int256_type_to_arrow(),
114            DataType::Float32 => self.float32_type_to_arrow(),
115            DataType::Float64 => self.float64_type_to_arrow(),
116            DataType::Date => self.date_type_to_arrow(),
117            DataType::Time => self.time_type_to_arrow(),
118            DataType::Timestamp => self.timestamp_type_to_arrow(),
119            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
120            DataType::Interval => self.interval_type_to_arrow(),
121            DataType::Varchar => self.varchar_type_to_arrow(),
122            DataType::Bytea => self.bytea_type_to_arrow(),
123            DataType::Serial => self.serial_type_to_arrow(),
124            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
125            DataType::Jsonb => self.varchar_type_to_arrow(),
126            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
127            DataType::List(list) => self.list_type_to_arrow(list)?,
128            DataType::Map(map) => self.map_type_to_arrow(map)?,
129            DataType::Vector(_) => self.vector_type_to_arrow()?,
130        };
131        Ok(arrow_schema::Field::new(name, data_type, true))
132    }
133
134    #[inline]
135    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
136        arrow_schema::DataType::Utf8
137    }
138
139    #[inline]
140    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
141        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
142        let data_type =
143            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
144        arrow_schema::Field::new(name, data_type, true)
145    }
146
147    fn decimal_to_arrow(
148        &self,
149        data_type: &arrow_schema::DataType,
150        array: &DecimalArray,
151    ) -> Result<arrow_array::ArrayRef, ArrayError> {
152        let (precision, max_scale) = match data_type {
153            arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
154            _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
155        };
156
157        // Convert Decimal to i128:
158        let values: Vec<Option<i128>> = array
159            .iter()
160            .map(|e| {
161                e.and_then(|e| match e {
162                    crate::array::Decimal::Normalized(e) => {
163                        let value = e.mantissa();
164                        let scale = e.scale() as i8;
165                        let diff_scale = abs(max_scale - scale);
166                        let value = match scale {
167                            _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
168                            _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
169                            _ => value,
170                        };
171                        Some(value)
172                    }
173                    // For Inf, we replace them with the max/min value within the precision.
174                    crate::array::Decimal::PositiveInf => {
175                        let max_value = 10_i128.pow(precision as u32) - 1;
176                        Some(max_value)
177                    }
178                    crate::array::Decimal::NegativeInf => {
179                        let max_value = 10_i128.pow(precision as u32) - 1;
180                        Some(-max_value)
181                    }
182                    crate::array::Decimal::NaN => None,
183                })
184            })
185            .collect();
186
187        let array = arrow_array::Decimal128Array::from(values)
188            .with_precision_and_scale(precision, max_scale)
189            .map_err(ArrayError::from_arrow)?;
190        Ok(Arc::new(array) as ArrayRef)
191    }
192
193    fn interval_to_arrow(
194        &self,
195        array: &IntervalArray,
196    ) -> Result<arrow_array::ArrayRef, ArrayError> {
197        Ok(Arc::new(arrow_array::StringArray::from(array)))
198    }
199}
200
// No overrides: every Arrow -> RisingWave conversion uses the default methods of `FromArrow`.
impl FromArrow for IcebergArrowConvert {}
202
/// An Iceberg sink with the `create_table_if_not_exists` option uses this struct to convert
/// RisingWave data types to Arrow fields.
///
/// Specifically, it adds a field id to the metadata of each Arrow field it builds, because
/// iceberg-rust needs the field id to be set.
///
/// Note: this is different from [`IcebergArrowConvert`], which is used to read from/write to
/// an _existing_ Iceberg table. In that case, we just need to make sure the data is compatible
/// with the existing schema. But to _create a new table_, we need to meet more requirements of
/// Iceberg.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    // Most recently assigned field id. `add_field_id` increments it before use, so a
    // default-constructed converter hands out ids starting at 1. Kept in a `RefCell` because
    // the conversion methods only take `&self`.
    next_field_id: RefCell<u32>,
}
215
216impl IcebergCreateTableArrowConvert {
217    pub fn to_arrow_field(
218        &self,
219        name: &str,
220        data_type: &DataType,
221    ) -> Result<arrow_schema::Field, ArrayError> {
222        ToArrow::to_arrow_field(self, name, data_type)
223    }
224
225    fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
226        *self.next_field_id.borrow_mut() += 1;
227        let field_id = *self.next_field_id.borrow();
228
229        let mut metadata = HashMap::new();
230        // for iceberg-rust
231        metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
232        arrow_field.set_metadata(metadata);
233    }
234}
235
236impl ToArrow for IcebergCreateTableArrowConvert {
237    #[inline]
238    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
239        // To create a iceberg table, we need a decimal type with precision and scale to be set
240        // We choose 28 here
241        // The decimal type finally will be converted to an iceberg decimal type.
242        // Iceberg decimal(P,S)
243        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
244        let data_type =
245            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
246
247        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
248        self.add_field_id(&mut arrow_field);
249        arrow_field
250    }
251
252    #[inline]
253    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
254        arrow_schema::DataType::Utf8
255    }
256
257    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
258        let data_type = arrow_schema::DataType::Utf8;
259
260        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
261        self.add_field_id(&mut arrow_field);
262        arrow_field
263    }
264
265    /// Convert RisingWave data type to Arrow data type.
266    ///
267    /// This function returns a `Field` instead of `DataType` because some may be converted to
268    /// extension types which require additional metadata in the field.
269    fn to_arrow_field(
270        &self,
271        name: &str,
272        value: &DataType,
273    ) -> Result<arrow_schema::Field, ArrayError> {
274        let data_type = match value {
275            // using the inline function
276            DataType::Boolean => self.bool_type_to_arrow(),
277            DataType::Int16 => self.int32_type_to_arrow(),
278            DataType::Int32 => self.int32_type_to_arrow(),
279            DataType::Int64 => self.int64_type_to_arrow(),
280            DataType::Int256 => self.varchar_type_to_arrow(),
281            DataType::Float32 => self.float32_type_to_arrow(),
282            DataType::Float64 => self.float64_type_to_arrow(),
283            DataType::Date => self.date_type_to_arrow(),
284            DataType::Time => self.time_type_to_arrow(),
285            DataType::Timestamp => self.timestamp_type_to_arrow(),
286            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
287            DataType::Interval => self.interval_type_to_arrow(),
288            DataType::Varchar => self.varchar_type_to_arrow(),
289            DataType::Bytea => self.bytea_type_to_arrow(),
290            DataType::Serial => self.serial_type_to_arrow(),
291            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
292            DataType::Jsonb => self.varchar_type_to_arrow(),
293            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
294            DataType::List(list) => self.list_type_to_arrow(list)?,
295            DataType::Map(map) => self.map_type_to_arrow(map)?,
296            DataType::Vector(_) => self.vector_type_to_arrow()?,
297        };
298
299        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
300        self.add_field_id(&mut arrow_field);
301        Ok(arrow_field)
302    }
303}
304
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Input shared by the tests: null, NaN, both infinities and two finite values whose
    /// scales differ.
    fn sample_decimals() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ])
    }

    /// Converts the shared input to the Arrow decimal type `ty` and checks that the result
    /// equals a `Decimal128Array` of that type holding the `expected` raw values.
    fn assert_decimal_conversion(ty: DataType, expected: Vec<Option<i128>>) {
        let converted = IcebergArrowConvert
            .decimal_to_arrow(&ty, &sample_decimals())
            .unwrap();
        let expected = Arc::new(Decimal128Array::from(expected).with_data_type(ty)) as ArrayRef;
        assert_eq!(&converted, &expected);
    }

    #[test]
    fn decimal() {
        assert_decimal_conversion(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        assert_decimal_conversion(
            DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}
365}