risingwave_common/array/arrow/arrow_iceberg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::Arc;
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_54::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25    is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28    Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29    VECTOR_ITEM_TYPE,
30};
31use crate::types::StructType;
32
/// Precision (total number of decimal digits) of the Arrow `Decimal128` type used for Iceberg
/// decimal columns.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale (number of digits after the decimal point) of the Arrow `Decimal128` type used for
/// Iceberg decimal columns.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;

/// Stateless converter between RisingWave arrays/data types and the Arrow representation used
/// when reading from or writing to an existing Iceberg table.
pub struct IcebergArrowConvert;
37
38impl IcebergArrowConvert {
39    pub fn to_record_batch(
40        &self,
41        schema: arrow_schema::SchemaRef,
42        chunk: &DataChunk,
43    ) -> Result<arrow_array::RecordBatch, ArrayError> {
44        ToArrow::to_record_batch(self, schema, chunk)
45    }
46
47    pub fn chunk_from_record_batch(
48        &self,
49        batch: &arrow_array::RecordBatch,
50    ) -> Result<DataChunk, ArrayError> {
51        FromArrow::from_record_batch(self, batch)
52    }
53
54    pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
55        FromArrow::from_field(self, field)
56    }
57
58    pub fn to_arrow_field(
59        &self,
60        name: &str,
61        data_type: &DataType,
62    ) -> Result<arrow_schema::Field, ArrayError> {
63        ToArrow::to_arrow_field(self, name, data_type)
64    }
65
66    pub fn struct_from_fields(
67        &self,
68        fields: &arrow_schema::Fields,
69    ) -> Result<StructType, ArrayError> {
70        FromArrow::from_fields(self, fields)
71    }
72
73    pub fn to_arrow_array(
74        &self,
75        data_type: &arrow_schema::DataType,
76        array: &ArrayImpl,
77    ) -> Result<arrow_array::ArrayRef, ArrayError> {
78        ToArrow::to_array(self, data_type, array)
79    }
80
81    pub fn array_from_arrow_array(
82        &self,
83        field: &arrow_schema::Field,
84        array: &arrow_array::ArrayRef,
85    ) -> Result<ArrayImpl, ArrayError> {
86        FromArrow::from_array(self, field, array)
87    }
88}
89
90impl ToArrow for IcebergArrowConvert {
91    fn to_arrow_field(
92        &self,
93        name: &str,
94        data_type: &DataType,
95    ) -> Result<arrow_schema::Field, ArrayError> {
96        let data_type = match data_type {
97            DataType::Boolean => self.bool_type_to_arrow(),
98            DataType::Int16 => self.int32_type_to_arrow(),
99            DataType::Int32 => self.int32_type_to_arrow(),
100            DataType::Int64 => self.int64_type_to_arrow(),
101            DataType::Int256 => self.int256_type_to_arrow(),
102            DataType::Float32 => self.float32_type_to_arrow(),
103            DataType::Float64 => self.float64_type_to_arrow(),
104            DataType::Date => self.date_type_to_arrow(),
105            DataType::Time => self.time_type_to_arrow(),
106            DataType::Timestamp => self.timestamp_type_to_arrow(),
107            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
108            DataType::Interval => self.interval_type_to_arrow(),
109            DataType::Varchar => self.varchar_type_to_arrow(),
110            DataType::Bytea => self.bytea_type_to_arrow(),
111            DataType::Serial => self.serial_type_to_arrow(),
112            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
113            DataType::Jsonb => self.varchar_type_to_arrow(),
114            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
115            DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
116            DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
117            DataType::Vector(_) => self.list_type_to_arrow(&VECTOR_ITEM_TYPE)?,
118        };
119        Ok(arrow_schema::Field::new(name, data_type, true))
120    }
121
122    #[inline]
123    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
124        arrow_schema::DataType::Utf8
125    }
126
127    #[inline]
128    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
129        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
130        let data_type =
131            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
132        arrow_schema::Field::new(name, data_type, true)
133    }
134
135    fn decimal_to_arrow(
136        &self,
137        data_type: &arrow_schema::DataType,
138        array: &DecimalArray,
139    ) -> Result<arrow_array::ArrayRef, ArrayError> {
140        let (precision, max_scale) = match data_type {
141            arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
142            _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
143        };
144
145        // Convert Decimal to i128:
146        let values: Vec<Option<i128>> = array
147            .iter()
148            .map(|e| {
149                e.and_then(|e| match e {
150                    crate::array::Decimal::Normalized(e) => {
151                        let value = e.mantissa();
152                        let scale = e.scale() as i8;
153                        let diff_scale = abs(max_scale - scale);
154                        let value = match scale {
155                            _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
156                            _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
157                            _ => value,
158                        };
159                        Some(value)
160                    }
161                    // For Inf, we replace them with the max/min value within the precision.
162                    crate::array::Decimal::PositiveInf => {
163                        let max_value = 10_i128.pow(precision as u32) - 1;
164                        Some(max_value)
165                    }
166                    crate::array::Decimal::NegativeInf => {
167                        let max_value = 10_i128.pow(precision as u32) - 1;
168                        Some(-max_value)
169                    }
170                    crate::array::Decimal::NaN => None,
171                })
172            })
173            .collect();
174
175        let array = arrow_array::Decimal128Array::from(values)
176            .with_precision_and_scale(precision, max_scale)
177            .map_err(ArrayError::from_arrow)?;
178        Ok(Arc::new(array) as ArrayRef)
179    }
180
181    fn interval_to_arrow(
182        &self,
183        array: &IntervalArray,
184    ) -> Result<arrow_array::ArrayRef, ArrayError> {
185        Ok(Arc::new(arrow_array::StringArray::from(array)))
186    }
187}
188
// Arrow -> RisingWave conversion for Iceberg uses the `FromArrow` trait's default
// implementations unchanged; no method is overridden.
impl FromArrow for IcebergArrowConvert {}
190
/// The Iceberg sink with the `create_table_if_not_exists` option uses this struct to convert
/// RisingWave data types to Arrow data types when it creates the Iceberg table.
///
/// Specifically, it adds the field id to the Arrow field metadata, because iceberg-rust needs
/// the field id to be set.
///
/// Note: this is different from [`IcebergArrowConvert`], which is used to read from/write to
/// an _existing_ iceberg table. In that case, we just need to make sure the data is compatible
/// with the existing schema. But to _create a new table_, we need to meet more requirements of
/// iceberg.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    // Last field id handed out. It starts at 0 (via `Default`) and is bumped before each use,
    // so the first assigned id is 1. `RefCell` lets the counter advance through `&self`.
    next_field_id: RefCell<u32>,
}
203
204impl IcebergCreateTableArrowConvert {
205    pub fn to_arrow_field(
206        &self,
207        name: &str,
208        data_type: &DataType,
209    ) -> Result<arrow_schema::Field, ArrayError> {
210        ToArrow::to_arrow_field(self, name, data_type)
211    }
212
213    fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
214        *self.next_field_id.borrow_mut() += 1;
215        let field_id = *self.next_field_id.borrow();
216
217        let mut metadata = HashMap::new();
218        // for iceberg-rust
219        metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
220        arrow_field.set_metadata(metadata);
221    }
222}
223
224impl ToArrow for IcebergCreateTableArrowConvert {
225    #[inline]
226    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
227        // To create a iceberg table, we need a decimal type with precision and scale to be set
228        // We choose 28 here
229        // The decimal type finally will be converted to an iceberg decimal type.
230        // Iceberg decimal(P,S)
231        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
232        let data_type =
233            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
234
235        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
236        self.add_field_id(&mut arrow_field);
237        arrow_field
238    }
239
240    #[inline]
241    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
242        arrow_schema::DataType::Utf8
243    }
244
245    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
246        let data_type = arrow_schema::DataType::Utf8;
247
248        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
249        self.add_field_id(&mut arrow_field);
250        arrow_field
251    }
252
253    /// Convert RisingWave data type to Arrow data type.
254    ///
255    /// This function returns a `Field` instead of `DataType` because some may be converted to
256    /// extension types which require additional metadata in the field.
257    fn to_arrow_field(
258        &self,
259        name: &str,
260        value: &DataType,
261    ) -> Result<arrow_schema::Field, ArrayError> {
262        let data_type = match value {
263            // using the inline function
264            DataType::Boolean => self.bool_type_to_arrow(),
265            DataType::Int16 => self.int32_type_to_arrow(),
266            DataType::Int32 => self.int32_type_to_arrow(),
267            DataType::Int64 => self.int64_type_to_arrow(),
268            DataType::Int256 => self.varchar_type_to_arrow(),
269            DataType::Float32 => self.float32_type_to_arrow(),
270            DataType::Float64 => self.float64_type_to_arrow(),
271            DataType::Date => self.date_type_to_arrow(),
272            DataType::Time => self.time_type_to_arrow(),
273            DataType::Timestamp => self.timestamp_type_to_arrow(),
274            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
275            DataType::Interval => self.interval_type_to_arrow(),
276            DataType::Varchar => self.varchar_type_to_arrow(),
277            DataType::Bytea => self.bytea_type_to_arrow(),
278            DataType::Serial => self.serial_type_to_arrow(),
279            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
280            DataType::Jsonb => self.varchar_type_to_arrow(),
281            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
282            DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
283            DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
284            DataType::Vector(_) => self.list_type_to_arrow(&VECTOR_ITEM_TYPE)?,
285        };
286
287        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
288        self.add_field_id(&mut arrow_field);
289        Ok(arrow_field)
290    }
291}
292
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// RisingWave decimal column shared by the tests: a NULL, the three special values
    /// (NaN, +Inf, -Inf) and two finite values with one and three fractional digits.
    fn sample_decimals() -> DecimalArray {
        let finite = |text: &str| Some(Decimal::Normalized(text.parse().unwrap()));
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            finite("123.4"),
            finite("123.456"),
        ])
    }

    /// Converts `sample_decimals()` to the Arrow type `ty` and compares the result with
    /// `expected`, read as raw `Decimal128` values of that same type.
    fn check_conversion(ty: DataType, expected: Vec<Option<i128>>) {
        let actual = IcebergArrowConvert
            .decimal_to_arrow(&ty, &sample_decimals())
            .unwrap();
        let expected: ArrayRef = Arc::new(Decimal128Array::from(expected).with_data_type(ty));
        assert_eq!(&actual, &expected);
    }

    #[test]
    fn decimal() {
        // NaN -> NULL, ±Inf -> ±(10^6 - 1), finite values rescaled to 3 fractional digits.
        check_conversion(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        check_conversion(
            DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}
353}