risingwave_common/array/arrow/
arrow_iceberg.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::Div;
18use std::sync::{Arc, LazyLock};
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_56::{
24    FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25    is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28    Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Converter between RisingWave arrays/types and their Arrow counterparts for Iceberg.
///
/// It implements [`ToArrow`] and [`FromArrow`]; its inherent methods simply forward to
/// those trait implementations.
pub struct IcebergArrowConvert;

/// Precision of the Arrow `Decimal128` type that `decimal_type_to_arrow` produces.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale of the Arrow `Decimal128` type that `decimal_type_to_arrow` produces.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
36
37impl IcebergArrowConvert {
38    pub fn to_record_batch(
39        &self,
40        schema: arrow_schema::SchemaRef,
41        chunk: &DataChunk,
42    ) -> Result<arrow_array::RecordBatch, ArrayError> {
43        ToArrow::to_record_batch(self, schema, chunk)
44    }
45
46    pub fn chunk_from_record_batch(
47        &self,
48        batch: &arrow_array::RecordBatch,
49    ) -> Result<DataChunk, ArrayError> {
50        FromArrow::from_record_batch(self, batch)
51    }
52
53    pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
54        FromArrow::from_field(self, field)
55    }
56
57    pub fn to_arrow_field(
58        &self,
59        name: &str,
60        data_type: &DataType,
61    ) -> Result<arrow_schema::Field, ArrayError> {
62        ToArrow::to_arrow_field(self, name, data_type)
63    }
64
65    pub fn struct_from_fields(
66        &self,
67        fields: &arrow_schema::Fields,
68    ) -> Result<StructType, ArrayError> {
69        FromArrow::from_fields(self, fields)
70    }
71
72    pub fn to_arrow_array(
73        &self,
74        data_type: &arrow_schema::DataType,
75        array: &ArrayImpl,
76    ) -> Result<arrow_array::ArrayRef, ArrayError> {
77        ToArrow::to_array(self, data_type, array)
78    }
79
80    pub fn array_from_arrow_array(
81        &self,
82        field: &arrow_schema::Field,
83        array: &arrow_array::ArrayRef,
84    ) -> Result<ArrayImpl, ArrayError> {
85        FromArrow::from_array(self, field, array)
86    }
87
88    /// A helper function to convert an Arrow array to RisingWave array without knowing the field.
89    /// It will use the datatype from arrow array to infer the RisingWave data type.
90    ///
91    /// The difference between this function and `array_from_arrow_array` is that `array_from_arrow_array` will try using `ARROW:extension:name` field metadata to determine the RisingWave data type for extension types.
92    pub fn array_from_arrow_array_raw(
93        &self,
94        array: &arrow_array::ArrayRef,
95    ) -> Result<ArrayImpl, ArrayError> {
96        static FIELD_DUMMY: LazyLock<arrow_schema::Field> =
97            LazyLock::new(|| arrow_schema::Field::new("dummy", arrow_schema::DataType::Null, true));
98        FromArrow::from_array(self, &FIELD_DUMMY, array)
99    }
100}
101
102impl ToArrow for IcebergArrowConvert {
103    fn to_arrow_field(
104        &self,
105        name: &str,
106        data_type: &DataType,
107    ) -> Result<arrow_schema::Field, ArrayError> {
108        let data_type = match data_type {
109            DataType::Boolean => self.bool_type_to_arrow(),
110            DataType::Int16 => self.int32_type_to_arrow(),
111            DataType::Int32 => self.int32_type_to_arrow(),
112            DataType::Int64 => self.int64_type_to_arrow(),
113            DataType::Int256 => self.int256_type_to_arrow(),
114            DataType::Float32 => self.float32_type_to_arrow(),
115            DataType::Float64 => self.float64_type_to_arrow(),
116            DataType::Date => self.date_type_to_arrow(),
117            DataType::Time => self.time_type_to_arrow(),
118            DataType::Timestamp => self.timestamp_type_to_arrow(),
119            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
120            DataType::Interval => self.interval_type_to_arrow(),
121            DataType::Varchar => self.varchar_type_to_arrow(),
122            DataType::Bytea => self.bytea_type_to_arrow(),
123            DataType::Serial => self.serial_type_to_arrow(),
124            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
125            DataType::Jsonb => self.varchar_type_to_arrow(),
126            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
127            DataType::List(list) => self.list_type_to_arrow(list)?,
128            DataType::Map(map) => self.map_type_to_arrow(map)?,
129            DataType::Vector(_) => self.vector_type_to_arrow()?,
130        };
131        Ok(arrow_schema::Field::new(name, data_type, true))
132    }
133
134    #[inline]
135    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
136        arrow_schema::DataType::Utf8
137    }
138
139    #[inline]
140    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
141        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
142        let data_type =
143            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
144        arrow_schema::Field::new(name, data_type, true)
145    }
146
147    fn decimal_to_arrow(
148        &self,
149        data_type: &arrow_schema::DataType,
150        array: &DecimalArray,
151    ) -> Result<arrow_array::ArrayRef, ArrayError> {
152        let (precision, max_scale) = match data_type {
153            arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
154            _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
155        };
156
157        // Convert Decimal to i128:
158        let max_value = 10_i128.pow(precision as u32) - 1;
159        let values: Vec<Option<i128>> = array
160            .iter()
161            .map(|e| {
162                e.and_then(|e| match e {
163                    crate::array::Decimal::Normalized(e) => {
164                        let value = e.mantissa();
165                        let scale = e.scale() as i8;
166                        let diff_scale = abs(max_scale - scale);
167                        let value = match scale {
168                            _ if scale < max_scale => value
169                                .checked_mul(10_i128.pow(diff_scale as u32))
170                                .and_then(|v| if abs(v) <= max_value { Some(v) } else { None })
171                                .unwrap_or_else(|| {
172                                    tracing::warn!(
173                                        "Decimal overflow when converting to arrow decimal with precision {} and scale {}. It will be replaced with inf/-inf.",
174                                        precision, max_scale
175                                    );
176                                    if value >= 0 { max_value } else { -max_value }
177                                }),
178                            _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
179                            _ => value,
180                        };
181                        Some(value)
182                    }
183                    // For Inf, we replace them with the max/min value within the precision.
184                    crate::array::Decimal::PositiveInf => {
185                        Some(max_value)
186                    }
187                    crate::array::Decimal::NegativeInf => {
188                        Some(-max_value)
189                    }
190                    crate::array::Decimal::NaN => None,
191                })
192            })
193            .collect();
194
195        let array = arrow_array::Decimal128Array::from(values)
196            .with_precision_and_scale(precision, max_scale)
197            .map_err(ArrayError::from_arrow)?;
198        Ok(Arc::new(array) as ArrayRef)
199    }
200
201    fn interval_to_arrow(
202        &self,
203        array: &IntervalArray,
204    ) -> Result<arrow_array::ArrayRef, ArrayError> {
205        Ok(Arc::new(arrow_array::StringArray::from(array)))
206    }
207}
208
// No overrides: every `FromArrow` method falls back to the trait's provided implementation.
impl FromArrow for IcebergArrowConvert {}
210
/// Iceberg sink with the `create_table_if_not_exists` option uses this struct to convert
/// RisingWave data types to arrow fields.
///
/// Specifically, it adds the field id to the arrow field metadata, because iceberg-rust needs
/// the field id to be set.
///
/// Note: this is different from [`IcebergArrowConvert`], which is used to read from/write to
/// an _existing_ iceberg table. In that case, we just need to make sure the data is compatible
/// with the existing schema.
/// But to _create a new table_, we need to meet more requirements of iceberg.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    /// The most recently assigned field id. `add_field_id` increments it before using it,
    /// so a default-constructed converter hands out ids starting at 1.
    next_field_id: RefCell<u32>,
}
223
224impl IcebergCreateTableArrowConvert {
225    pub fn to_arrow_field(
226        &self,
227        name: &str,
228        data_type: &DataType,
229    ) -> Result<arrow_schema::Field, ArrayError> {
230        ToArrow::to_arrow_field(self, name, data_type)
231    }
232
233    fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
234        *self.next_field_id.borrow_mut() += 1;
235        let field_id = *self.next_field_id.borrow();
236
237        let mut metadata = HashMap::new();
238        // for iceberg-rust
239        metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
240        arrow_field.set_metadata(metadata);
241    }
242}
243
244impl ToArrow for IcebergCreateTableArrowConvert {
245    #[inline]
246    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
247        // To create a iceberg table, we need a decimal type with precision and scale to be set
248        // We choose 28 here
249        // The decimal type finally will be converted to an iceberg decimal type.
250        // Iceberg decimal(P,S)
251        // Fixed-point decimal; precision P, scale S Scale is fixed, precision must be less than 38.
252        let data_type =
253            arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
254
255        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
256        self.add_field_id(&mut arrow_field);
257        arrow_field
258    }
259
260    #[inline]
261    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
262        arrow_schema::DataType::Utf8
263    }
264
265    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
266        let data_type = arrow_schema::DataType::Utf8;
267
268        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
269        self.add_field_id(&mut arrow_field);
270        arrow_field
271    }
272
273    /// Convert RisingWave data type to Arrow data type.
274    ///
275    /// This function returns a `Field` instead of `DataType` because some may be converted to
276    /// extension types which require additional metadata in the field.
277    fn to_arrow_field(
278        &self,
279        name: &str,
280        value: &DataType,
281    ) -> Result<arrow_schema::Field, ArrayError> {
282        let data_type = match value {
283            // using the inline function
284            DataType::Boolean => self.bool_type_to_arrow(),
285            DataType::Int16 => self.int32_type_to_arrow(),
286            DataType::Int32 => self.int32_type_to_arrow(),
287            DataType::Int64 => self.int64_type_to_arrow(),
288            DataType::Int256 => self.varchar_type_to_arrow(),
289            DataType::Float32 => self.float32_type_to_arrow(),
290            DataType::Float64 => self.float64_type_to_arrow(),
291            DataType::Date => self.date_type_to_arrow(),
292            DataType::Time => self.time_type_to_arrow(),
293            DataType::Timestamp => self.timestamp_type_to_arrow(),
294            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
295            DataType::Interval => self.interval_type_to_arrow(),
296            DataType::Varchar => self.varchar_type_to_arrow(),
297            DataType::Bytea => self.bytea_type_to_arrow(),
298            DataType::Serial => self.serial_type_to_arrow(),
299            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
300            DataType::Jsonb => self.varchar_type_to_arrow(),
301            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
302            DataType::List(list) => self.list_type_to_arrow(list)?,
303            DataType::Map(map) => self.map_type_to_arrow(map)?,
304            DataType::Vector(_) => self.vector_type_to_arrow()?,
305        };
306
307        let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
308        self.add_field_id(&mut arrow_field);
309        Ok(arrow_field)
310    }
311}
312
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Input shared by both tests: null, NaN, both infinities and two finite values.
    fn sample_decimals() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ])
    }

    /// Converts the sample input to the Arrow decimal type `ty` and compares the result with
    /// `expected`, given as raw `i128` values at the target scale.
    #[track_caller]
    fn assert_converts_to(ty: DataType, expected: Vec<Option<i128>>) {
        let input = sample_decimals();
        let actual = IcebergArrowConvert.decimal_to_arrow(&ty, &input).unwrap();
        let wanted = Arc::new(Decimal128Array::from(expected).with_data_type(ty)) as ArrayRef;
        assert_eq!(&actual, &wanted);
    }

    #[test]
    fn decimal() {
        assert_converts_to(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        assert_converts_to(
            DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}
373}