risingwave_common/array/arrow/
arrow_iceberg.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::Div;
18use std::sync::{Arc, LazyLock};
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_56::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25    is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28    Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Arrow converter used when reading from or writing to an existing Iceberg table.
///
/// It implements [`ToArrow`] and [`FromArrow`]; the inherent methods are convenience wrappers
/// around those trait methods.
pub struct IcebergArrowConvert;
33
// Arrow Decimal128 supports up to 38 decimal digits. We use precision=38, scale=10:
// - Integer range: up to 10^28 - 1 (28 digits)
// - Fractional precision: 10 digits
// - Covers the integer part of all RisingWave decimal values (MAX_PRECISION=28); fractional
//   digits beyond the 10th are truncated when writing.
//
// Note: When reading Arrow decimals that exceed RisingWave's 96-bit / 28-digit
// storage limit, the conversion code in arrow_impl.rs will reduce scale and
// truncate the mantissa (via truncated_i128_and_scale) to make them fit.

/// Precision (total number of decimal digits) of the Arrow `Decimal128` type used for
/// RisingWave decimals.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 38;
/// Scale (number of fractional digits) of the Arrow `Decimal128` type used for RisingWave
/// decimals.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
44
45impl IcebergArrowConvert {
46    pub fn to_record_batch(
47        &self,
48        schema: arrow_schema::SchemaRef,
49        chunk: &DataChunk,
50    ) -> Result<arrow_array::RecordBatch, ArrayError> {
51        ToArrow::to_record_batch(self, schema, chunk)
52    }
53
54    pub fn chunk_from_record_batch(
55        &self,
56        batch: &arrow_array::RecordBatch,
57    ) -> Result<DataChunk, ArrayError> {
58        FromArrow::from_record_batch(self, batch)
59    }
60
61    pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
62        FromArrow::from_field(self, field)
63    }
64
65    pub fn to_arrow_field(
66        &self,
67        name: &str,
68        data_type: &DataType,
69    ) -> Result<arrow_schema::Field, ArrayError> {
70        ToArrow::to_arrow_field(self, name, data_type)
71    }
72
73    pub fn struct_from_fields(
74        &self,
75        fields: &arrow_schema::Fields,
76    ) -> Result<StructType, ArrayError> {
77        FromArrow::from_fields(self, fields)
78    }
79
80    pub fn to_arrow_array(
81        &self,
82        data_type: &arrow_schema::DataType,
83        array: &ArrayImpl,
84    ) -> Result<arrow_array::ArrayRef, ArrayError> {
85        ToArrow::to_array(self, data_type, array)
86    }
87
88    pub fn array_from_arrow_array(
89        &self,
90        field: &arrow_schema::Field,
91        array: &arrow_array::ArrayRef,
92    ) -> Result<ArrayImpl, ArrayError> {
93        FromArrow::from_array(self, field, array)
94    }
95
96    /// A helper function to convert an Arrow array to RisingWave array without knowing the field.
97    /// It will use the datatype from arrow array to infer the RisingWave data type.
98    ///
99    /// The difference between this function and `array_from_arrow_array` is that `array_from_arrow_array` will try using `ARROW:extension:name` field metadata to determine the RisingWave data type for extension types.
100    pub fn array_from_arrow_array_raw(
101        &self,
102        array: &arrow_array::ArrayRef,
103    ) -> Result<ArrayImpl, ArrayError> {
104        static FIELD_DUMMY: LazyLock<arrow_schema::Field> =
105            LazyLock::new(|| arrow_schema::Field::new("dummy", arrow_schema::DataType::Null, true));
106        FromArrow::from_array(self, &FIELD_DUMMY, array)
107    }
108}
109
110impl ToArrow for IcebergArrowConvert {
111    fn to_arrow_field(
112        &self,
113        name: &str,
114        data_type: &DataType,
115    ) -> Result<arrow_schema::Field, ArrayError> {
116        let data_type = match data_type {
117            DataType::Boolean => self.bool_type_to_arrow(),
118            DataType::Int16 => self.int32_type_to_arrow(),
119            DataType::Int32 => self.int32_type_to_arrow(),
120            DataType::Int64 => self.int64_type_to_arrow(),
121            DataType::Int256 => self.int256_type_to_arrow(),
122            DataType::Float32 => self.float32_type_to_arrow(),
123            DataType::Float64 => self.float64_type_to_arrow(),
124            DataType::Date => self.date_type_to_arrow(),
125            DataType::Time => self.time_type_to_arrow(),
126            DataType::Timestamp => self.timestamp_type_to_arrow(),
127            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
128            DataType::Interval => self.interval_type_to_arrow(),
129            DataType::Varchar => self.varchar_type_to_arrow(),
130            DataType::Bytea => self.bytea_type_to_arrow(),
131            DataType::Serial => self.serial_type_to_arrow(),
132            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
133            DataType::Jsonb => self.varchar_type_to_arrow(),
134            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
135            DataType::List(list) => self.list_type_to_arrow(list)?,
136            DataType::Map(map) => self.map_type_to_arrow(map)?,
137            DataType::Vector(_) => self.vector_type_to_arrow()?,
138        };
139        Ok(arrow_schema::Field::new(name, data_type, true))
140    }
141
142    #[inline]
143    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
144        arrow_schema::DataType::Utf8
145    }
146
147    #[inline]
148    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
149        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
150        let data_type =
151            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
152        arrow_schema::Field::new(name, data_type, true)
153    }
154
155    fn decimal_to_arrow(
156        &self,
157        data_type: &arrow_schema::DataType,
158        array: &DecimalArray,
159    ) -> Result<arrow_array::ArrayRef, ArrayError> {
160        let (precision, max_scale) = match data_type {
161            arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
162            _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
163        };
164
165        // Convert Decimal to i128:
166        let max_value = 10_i128.pow(precision as u32) - 1;
167        let values: Vec<Option<i128>> = array
168            .iter()
169            .map(|e| {
170                e.and_then(|e| match e {
171                    crate::array::Decimal::Normalized(e) => {
172                        let value = e.mantissa();
173                        let scale = e.scale() as i8;
174                        let diff_scale = abs(max_scale - scale);
175                        let value = match scale {
176                            _ if scale < max_scale => value
177                                .checked_mul(10_i128.pow(diff_scale as u32))
178                                .and_then(|v| if abs(v) <= max_value { Some(v) } else { None })
179                                .unwrap_or_else(|| {
180                                    tracing::warn!(
181                                        "Decimal overflow when converting to arrow decimal with precision {} and scale {}. It will be replaced with inf/-inf.",
182                                        precision, max_scale
183                                    );
184                                    if value >= 0 { max_value } else { -max_value }
185                                }),
186                            _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
187                            _ => value,
188                        };
189                        Some(value)
190                    }
191                    // For Inf, we replace them with the max/min value within the precision.
192                    crate::array::Decimal::PositiveInf => {
193                        Some(max_value)
194                    }
195                    crate::array::Decimal::NegativeInf => {
196                        Some(-max_value)
197                    }
198                    crate::array::Decimal::NaN => None,
199                })
200            })
201            .collect();
202
203        let array = arrow_array::Decimal128Array::from(values)
204            .with_precision_and_scale(precision, max_scale)
205            .map_err(ArrayError::from_arrow)?;
206        Ok(Arc::new(array) as ArrayRef)
207    }
208
209    fn interval_to_arrow(
210        &self,
211        array: &IntervalArray,
212    ) -> Result<arrow_array::ArrayRef, ArrayError> {
213        Ok(Arc::new(arrow_array::StringArray::from(array)))
214    }
215}
216
// No overrides: reading from Arrow uses the default `FromArrow` conversions.
impl FromArrow for IcebergArrowConvert {}
218
/// Iceberg sink with `create_table_if_not_exists` option will use this struct to convert the
/// RisingWave data type to the arrow data type of the table to be created.
///
/// Specifically, it will add the field id to the arrow field metadata, because iceberg-rust needs the field id to be set.
///
/// Note: this is different from [`IcebergArrowConvert`], which is used to read from/write to
/// an _existing_ iceberg table. In that case, we just need to make sure the data is compatible with the existing schema.
/// But to _create a new table_, we need to meet more requirements of iceberg.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    /// Most recently assigned field id. `add_field_id` increments it before use, so with the
    /// derived `Default` (0) the first field receives id 1.
    next_field_id: RefCell<u32>,
}
231
232impl IcebergCreateTableArrowConvert {
233    pub fn to_arrow_field(
234        &self,
235        name: &str,
236        data_type: &DataType,
237    ) -> Result<arrow_schema::Field, ArrayError> {
238        ToArrow::to_arrow_field(self, name, data_type)
239    }
240
241    fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
242        *self.next_field_id.borrow_mut() += 1;
243        let field_id = *self.next_field_id.borrow();
244
245        let mut metadata = HashMap::new();
246        // for iceberg-rust
247        metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
248        arrow_field.set_metadata(metadata);
249    }
250}
251
252impl ToArrow for IcebergCreateTableArrowConvert {
253    #[inline]
254    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
255        // To create a iceberg table, we need a decimal type with precision and scale to be set
256        // We choose 28 here
257        // The decimal type finally will be converted to an iceberg decimal type.
258        // Iceberg decimal(P,S)
259        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
260        let data_type =
261            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
262
263        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
264        self.add_field_id(&mut arrow_field);
265        arrow_field
266    }
267
268    #[inline]
269    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
270        arrow_schema::DataType::Utf8
271    }
272
273    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
274        let data_type = arrow_schema::DataType::Utf8;
275
276        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
277        self.add_field_id(&mut arrow_field);
278        arrow_field
279    }
280
281    /// Convert RisingWave data type to Arrow data type.
282    ///
283    /// This function returns a `Field` instead of `DataType` because some may be converted to
284    /// extension types which require additional metadata in the field.
285    fn to_arrow_field(
286        &self,
287        name: &str,
288        value: &DataType,
289    ) -> Result<arrow_schema::Field, ArrayError> {
290        let data_type = match value {
291            // using the inline function
292            DataType::Boolean => self.bool_type_to_arrow(),
293            DataType::Int16 => self.int32_type_to_arrow(),
294            DataType::Int32 => self.int32_type_to_arrow(),
295            DataType::Int64 => self.int64_type_to_arrow(),
296            DataType::Int256 => self.varchar_type_to_arrow(),
297            DataType::Float32 => self.float32_type_to_arrow(),
298            DataType::Float64 => self.float64_type_to_arrow(),
299            DataType::Date => self.date_type_to_arrow(),
300            DataType::Time => self.time_type_to_arrow(),
301            DataType::Timestamp => self.timestamp_type_to_arrow(),
302            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
303            DataType::Interval => self.interval_type_to_arrow(),
304            DataType::Varchar => self.varchar_type_to_arrow(),
305            DataType::Bytea => self.bytea_type_to_arrow(),
306            DataType::Serial => self.serial_type_to_arrow(),
307            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
308            DataType::Jsonb => self.varchar_type_to_arrow(),
309            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
310            DataType::List(list) => self.list_type_to_arrow(list)?,
311            DataType::Map(map) => self.map_type_to_arrow(map)?,
312            DataType::Vector(_) => self.vector_type_to_arrow()?,
313        };
314
315        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
316        self.add_field_id(&mut arrow_field);
317        Ok(arrow_field)
318    }
319}
320
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Converts NULL, NaN, +/-Inf and finite values to `Decimal128(6, 3)`: NULL and NaN become
    /// NULL, the infinities become +/-999999, and finite values are rescaled to scale 3.
    #[test]
    fn decimal() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ]);
        let ty = DataType::Decimal128(6, 3);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ])
            .with_data_type(ty),
        ) as ArrayRef;
        assert_eq!(&arrow_array, &expect_array);
    }

    /// Same inputs as `decimal`, but targeting the default Iceberg decimal type
    /// (`ICEBERG_DECIMAL_PRECISION`, `ICEBERG_DECIMAL_SCALE`).
    #[test]
    fn decimal_with_large_scale() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ]);
        let ty = DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                None,
                None,
                // With precision=38, max value is 10^38 - 1
                Some(99999999999999999999999999999999999999),
                Some(-99999999999999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ])
            .with_data_type(ty),
        ) as ArrayRef;
        assert_eq!(&arrow_array, &expect_array);
    }

    /// Checks large-integer, maximum-fraction, negative and zero values against the default
    /// Iceberg decimal type, including truncation of fractional digits beyond scale 10.
    #[test]
    fn decimal_edge_cases_risingwave_precision() {
        // Test edge cases between RisingWave decimal precision (28 digits) and Arrow Decimal128(38,10)
        let array = DecimalArray::from_iter([
            // Large 27-digit integer (previously would overflow with precision=28, scale=10)
            Some(Decimal::Normalized(
                "999999999999999999999999999".parse().unwrap(),
            )),
            // RisingWave MAX_PRECISION: 28-digit integer
            Some(Decimal::Normalized(
                "9999999999999999999999999999".parse().unwrap(),
            )),
            // Large integer with fractional part
            Some(Decimal::Normalized(
                "999999999999999999.9999999999".parse().unwrap(),
            )),
            // Small value with maximum fractional digits
            Some(Decimal::Normalized(
                "0.9999999999999999999999999999".parse().unwrap(),
            )),
            // Negative large integer
            Some(Decimal::Normalized(
                "-999999999999999999999999999".parse().unwrap(),
            )),
            // Edge case: exactly 10^18 (19 digits) - boundary for old precision=28,scale=10
            Some(Decimal::Normalized("1000000000000000000".parse().unwrap())),
            // Very small decimal
            Some(Decimal::Normalized("0.0000000001".parse().unwrap())),
            // Zero with fractional representation
            Some(Decimal::Normalized("0.0000000000".parse().unwrap())),
        ]);

        let ty = DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();

        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                // 999999999999999999999999999 * 10^10 (scale 0 → 10)
                Some(9999999999999999999999999990000000000),
                // 9999999999999999999999999999 * 10^10
                Some(99999999999999999999999999990000000000),
                // 999999999999999999.9999999999 already at scale 10
                Some(9999999999999999999999999999),
                // 0.9999999999999999999999999999: scale 28 → 10, truncates to 0.9999999999
                Some(9999999999),
                // -999999999999999999999999999 * 10^10
                Some(-9999999999999999999999999990000000000),
                // 1000000000000000000 * 10^10
                Some(10000000000000000000000000000),
                // 0.0000000001 already at scale 10
                Some(1),
                // 0.0000000000 (scale 10)
                Some(0),
            ])
            .with_data_type(ty),
        ) as ArrayRef;

        assert_eq!(&arrow_array, &expect_array);
    }

    /// Writes special values to Arrow and reads them back: the infinities are expected to
    /// survive the roundtrip, while NaN comes back as NULL.
    #[test]
    fn decimal_special_values_roundtrip() {
        // Test that special decimal values (inf, -inf, nan) can be written and read back correctly
        use crate::array::Array;

        let original_array = DecimalArray::from_iter([
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::NaN),
            Some(Decimal::Normalized("123.45".parse().unwrap())),
            None,
        ]);

        // Convert to Arrow
        let ty = DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
        let arrow_array = IcebergArrowConvert
            .decimal_to_arrow(&ty, &original_array)
            .unwrap();

        // Convert back to RisingWave
        let arrow_decimal: &arrow_array::Decimal128Array = arrow_array
            .as_any()
            .downcast_ref()
            .expect("should be Decimal128Array");

        let roundtrip_array: DecimalArray = arrow_decimal.try_into().unwrap();

        // Verify special values roundtrip correctly
        assert_eq!(original_array.len(), roundtrip_array.len());

        // PositiveInf -> max value -> PositiveInf
        assert_eq!(roundtrip_array.value_at(0), Some(Decimal::PositiveInf));

        // NegativeInf -> min value -> NegativeInf
        assert_eq!(roundtrip_array.value_at(1), Some(Decimal::NegativeInf));

        // NaN -> NULL -> None (NaN cannot roundtrip, becomes NULL in Arrow)
        assert_eq!(roundtrip_array.value_at(2), None);

        // Normal value roundtrips correctly (scale may be adjusted)
        assert!(matches!(
            roundtrip_array.value_at(3),
            Some(Decimal::Normalized(_))
        ));

        // NULL -> NULL -> None
        assert_eq!(roundtrip_array.value_at(4), None);
    }
}