risingwave_common/array/arrow/
arrow_iceberg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::Arc;
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_54::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25    is_parquet_schema_match_source_schema,
26};
27use crate::array::{Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray};
28use crate::types::StructType;
29
30pub struct IcebergArrowConvert;
31
32impl IcebergArrowConvert {
33    pub fn to_record_batch(
34        &self,
35        schema: arrow_schema::SchemaRef,
36        chunk: &DataChunk,
37    ) -> Result<arrow_array::RecordBatch, ArrayError> {
38        ToArrow::to_record_batch(self, schema, chunk)
39    }
40
41    pub fn chunk_from_record_batch(
42        &self,
43        batch: &arrow_array::RecordBatch,
44    ) -> Result<DataChunk, ArrayError> {
45        FromArrow::from_record_batch(self, batch)
46    }
47
48    pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
49        FromArrow::from_field(self, field)
50    }
51
52    pub fn to_arrow_field(
53        &self,
54        name: &str,
55        data_type: &DataType,
56    ) -> Result<arrow_schema::Field, ArrayError> {
57        ToArrow::to_arrow_field(self, name, data_type)
58    }
59
60    pub fn struct_from_fields(
61        &self,
62        fields: &arrow_schema::Fields,
63    ) -> Result<StructType, ArrayError> {
64        FromArrow::from_fields(self, fields)
65    }
66
67    pub fn to_arrow_array(
68        &self,
69        data_type: &arrow_schema::DataType,
70        array: &ArrayImpl,
71    ) -> Result<arrow_array::ArrayRef, ArrayError> {
72        ToArrow::to_array(self, data_type, array)
73    }
74
75    pub fn array_from_arrow_array(
76        &self,
77        field: &arrow_schema::Field,
78        array: &arrow_array::ArrayRef,
79    ) -> Result<ArrayImpl, ArrayError> {
80        FromArrow::from_array(self, field, array)
81    }
82}
83
84impl ToArrow for IcebergArrowConvert {
85    fn to_arrow_field(
86        &self,
87        name: &str,
88        data_type: &DataType,
89    ) -> Result<arrow_schema::Field, ArrayError> {
90        let data_type = match data_type {
91            DataType::Boolean => self.bool_type_to_arrow(),
92            DataType::Int16 => self.int32_type_to_arrow(),
93            DataType::Int32 => self.int32_type_to_arrow(),
94            DataType::Int64 => self.int64_type_to_arrow(),
95            DataType::Int256 => self.int256_type_to_arrow(),
96            DataType::Float32 => self.float32_type_to_arrow(),
97            DataType::Float64 => self.float64_type_to_arrow(),
98            DataType::Date => self.date_type_to_arrow(),
99            DataType::Time => self.time_type_to_arrow(),
100            DataType::Timestamp => self.timestamp_type_to_arrow(),
101            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
102            DataType::Interval => self.interval_type_to_arrow(),
103            DataType::Varchar => self.varchar_type_to_arrow(),
104            DataType::Bytea => self.bytea_type_to_arrow(),
105            DataType::Serial => self.serial_type_to_arrow(),
106            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
107            DataType::Jsonb => self.varchar_type_to_arrow(),
108            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
109            DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
110            DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
111        };
112        Ok(arrow_schema::Field::new(name, data_type, true))
113    }
114
115    #[inline]
116    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
117        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
118        let data_type = arrow_schema::DataType::Decimal128(28, 10);
119        arrow_schema::Field::new(name, data_type, true)
120    }
121
122    fn decimal_to_arrow(
123        &self,
124        data_type: &arrow_schema::DataType,
125        array: &DecimalArray,
126    ) -> Result<arrow_array::ArrayRef, ArrayError> {
127        let (precision, max_scale) = match data_type {
128            arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
129            _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
130        };
131
132        // Convert Decimal to i128:
133        let values: Vec<Option<i128>> = array
134            .iter()
135            .map(|e| {
136                e.and_then(|e| match e {
137                    crate::array::Decimal::Normalized(e) => {
138                        let value = e.mantissa();
139                        let scale = e.scale() as i8;
140                        let diff_scale = abs(max_scale - scale);
141                        let value = match scale {
142                            _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
143                            _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
144                            _ => value,
145                        };
146                        Some(value)
147                    }
148                    // For Inf, we replace them with the max/min value within the precision.
149                    crate::array::Decimal::PositiveInf => {
150                        let max_value = 10_i128.pow(precision as u32) - 1;
151                        Some(max_value)
152                    }
153                    crate::array::Decimal::NegativeInf => {
154                        let max_value = 10_i128.pow(precision as u32) - 1;
155                        Some(-max_value)
156                    }
157                    crate::array::Decimal::NaN => None,
158                })
159            })
160            .collect();
161
162        let array = arrow_array::Decimal128Array::from(values)
163            .with_precision_and_scale(precision, max_scale)
164            .map_err(ArrayError::from_arrow)?;
165        Ok(Arc::new(array) as ArrayRef)
166    }
167}
168
// Arrow -> RisingWave conversion for Iceberg overrides nothing: every `FromArrow` method uses
// the trait's provided implementation.
impl FromArrow for IcebergArrowConvert {}
170
/// The Iceberg sink with the `create_table_if_not_exists` option uses this struct to convert
/// RisingWave data types into Arrow fields for the table it creates.
///
/// Specifically, it adds a field id to each Arrow field's metadata, because iceberg-rust needs
/// the field id to be set.
///
/// Note: this is different from [`IcebergArrowConvert`], which is used to read from/write to
/// an _existing_ iceberg table. In that case, we just need to make sure the data is compatible
/// with the existing schema. But to _create a new table_, we need to meet more requirements of
/// iceberg.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    // Despite the name, this holds the most recently assigned field id: `add_field_id`
    // increments it and then reads it. It starts at 0 under `Default`, so the first id handed
    // out is 1. The `RefCell` lets `&self` conversion methods update it.
    next_field_id: RefCell<u32>,
}
183
184impl IcebergCreateTableArrowConvert {
185    pub fn to_arrow_field(
186        &self,
187        name: &str,
188        data_type: &DataType,
189    ) -> Result<arrow_schema::Field, ArrayError> {
190        ToArrow::to_arrow_field(self, name, data_type)
191    }
192
193    fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
194        *self.next_field_id.borrow_mut() += 1;
195        let field_id = *self.next_field_id.borrow();
196
197        let mut metadata = HashMap::new();
198        // for iceberg-rust
199        metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
200        arrow_field.set_metadata(metadata);
201    }
202}
203
204impl ToArrow for IcebergCreateTableArrowConvert {
205    #[inline]
206    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
207        // To create a iceberg table, we need a decimal type with precision and scale to be set
208        // We choose 28 here
209        // The decimal type finally will be converted to an iceberg decimal type.
210        // Iceberg decimal(P,S)
211        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
212        let data_type = arrow_schema::DataType::Decimal128(28, 10);
213
214        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
215        self.add_field_id(&mut arrow_field);
216        arrow_field
217    }
218
219    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
220        let data_type = arrow_schema::DataType::Utf8;
221
222        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
223        self.add_field_id(&mut arrow_field);
224        arrow_field
225    }
226
227    /// Convert RisingWave data type to Arrow data type.
228    ///
229    /// This function returns a `Field` instead of `DataType` because some may be converted to
230    /// extension types which require additional metadata in the field.
231    fn to_arrow_field(
232        &self,
233        name: &str,
234        value: &DataType,
235    ) -> Result<arrow_schema::Field, ArrayError> {
236        let data_type = match value {
237            // using the inline function
238            DataType::Boolean => self.bool_type_to_arrow(),
239            DataType::Int16 => self.int32_type_to_arrow(),
240            DataType::Int32 => self.int32_type_to_arrow(),
241            DataType::Int64 => self.int64_type_to_arrow(),
242            DataType::Int256 => self.varchar_type_to_arrow(),
243            DataType::Float32 => self.float32_type_to_arrow(),
244            DataType::Float64 => self.float64_type_to_arrow(),
245            DataType::Date => self.date_type_to_arrow(),
246            DataType::Time => self.time_type_to_arrow(),
247            DataType::Timestamp => self.timestamp_type_to_arrow(),
248            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
249            DataType::Interval => self.interval_type_to_arrow(),
250            DataType::Varchar => self.varchar_type_to_arrow(),
251            DataType::Bytea => self.bytea_type_to_arrow(),
252            DataType::Serial => self.serial_type_to_arrow(),
253            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
254            DataType::Jsonb => self.varchar_type_to_arrow(),
255            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
256            DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
257            DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
258        };
259
260        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
261        self.add_field_id(&mut arrow_field);
262        Ok(arrow_field)
263    }
264}
265
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Null, NaN, both infinities, and two finite values with different scales.
    fn sample_decimals() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ])
    }

    /// Converts [`sample_decimals`] to `ty` and checks the resulting mantissas.
    fn assert_converts_to(ty: DataType, expected: Vec<Option<i128>>) {
        let actual = IcebergArrowConvert
            .decimal_to_arrow(&ty, &sample_decimals())
            .unwrap();
        let expected = Arc::new(Decimal128Array::from(expected).with_data_type(ty)) as ArrayRef;
        assert_eq!(&actual, &expected);
    }

    #[test]
    fn decimal() {
        assert_converts_to(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        assert_converts_to(
            DataType::Decimal128(28, 10),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}