1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::Div;
18use std::sync::{Arc, LazyLock};
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_56::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25 is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28 Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Stateless converter between RisingWave arrays/types and Arrow ones.
///
/// It implements `ToArrow` (with overrides for field mapping, decimals and intervals)
/// and `FromArrow` (defaults only) further down in this file.
pub struct IcebergArrowConvert;
33
/// Precision of the Arrow `Decimal128` type that the `decimal_type_to_arrow`
/// implementations in this file declare for RisingWave `Decimal` columns.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 38;
/// Scale of the Arrow `Decimal128` type that the `decimal_type_to_arrow`
/// implementations in this file declare for RisingWave `Decimal` columns.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
44
45impl IcebergArrowConvert {
46 pub fn to_record_batch(
47 &self,
48 schema: arrow_schema::SchemaRef,
49 chunk: &DataChunk,
50 ) -> Result<arrow_array::RecordBatch, ArrayError> {
51 ToArrow::to_record_batch(self, schema, chunk)
52 }
53
54 pub fn chunk_from_record_batch(
55 &self,
56 batch: &arrow_array::RecordBatch,
57 ) -> Result<DataChunk, ArrayError> {
58 FromArrow::from_record_batch(self, batch)
59 }
60
61 pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
62 FromArrow::from_field(self, field)
63 }
64
65 pub fn to_arrow_field(
66 &self,
67 name: &str,
68 data_type: &DataType,
69 ) -> Result<arrow_schema::Field, ArrayError> {
70 ToArrow::to_arrow_field(self, name, data_type)
71 }
72
73 pub fn struct_from_fields(
74 &self,
75 fields: &arrow_schema::Fields,
76 ) -> Result<StructType, ArrayError> {
77 FromArrow::from_fields(self, fields)
78 }
79
80 pub fn to_arrow_array(
81 &self,
82 data_type: &arrow_schema::DataType,
83 array: &ArrayImpl,
84 ) -> Result<arrow_array::ArrayRef, ArrayError> {
85 ToArrow::to_array(self, data_type, array)
86 }
87
88 pub fn array_from_arrow_array(
89 &self,
90 field: &arrow_schema::Field,
91 array: &arrow_array::ArrayRef,
92 ) -> Result<ArrayImpl, ArrayError> {
93 FromArrow::from_array(self, field, array)
94 }
95
96 pub fn array_from_arrow_array_raw(
101 &self,
102 array: &arrow_array::ArrayRef,
103 ) -> Result<ArrayImpl, ArrayError> {
104 static FIELD_DUMMY: LazyLock<arrow_schema::Field> =
105 LazyLock::new(|| arrow_schema::Field::new("dummy", arrow_schema::DataType::Null, true));
106 FromArrow::from_array(self, &FIELD_DUMMY, array)
107 }
108}
109
110impl ToArrow for IcebergArrowConvert {
111 fn to_arrow_field(
112 &self,
113 name: &str,
114 data_type: &DataType,
115 ) -> Result<arrow_schema::Field, ArrayError> {
116 let data_type = match data_type {
117 DataType::Boolean => self.bool_type_to_arrow(),
118 DataType::Int16 => self.int32_type_to_arrow(),
119 DataType::Int32 => self.int32_type_to_arrow(),
120 DataType::Int64 => self.int64_type_to_arrow(),
121 DataType::Int256 => self.int256_type_to_arrow(),
122 DataType::Float32 => self.float32_type_to_arrow(),
123 DataType::Float64 => self.float64_type_to_arrow(),
124 DataType::Date => self.date_type_to_arrow(),
125 DataType::Time => self.time_type_to_arrow(),
126 DataType::Timestamp => self.timestamp_type_to_arrow(),
127 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
128 DataType::Interval => self.interval_type_to_arrow(),
129 DataType::Varchar => self.varchar_type_to_arrow(),
130 DataType::Bytea => self.bytea_type_to_arrow(),
131 DataType::Serial => self.serial_type_to_arrow(),
132 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
133 DataType::Jsonb => self.varchar_type_to_arrow(),
134 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
135 DataType::List(list) => self.list_type_to_arrow(list)?,
136 DataType::Map(map) => self.map_type_to_arrow(map)?,
137 DataType::Vector(_) => self.vector_type_to_arrow()?,
138 };
139 Ok(arrow_schema::Field::new(name, data_type, true))
140 }
141
142 #[inline]
143 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
144 arrow_schema::DataType::Utf8
145 }
146
147 #[inline]
148 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
149 let data_type =
151 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
152 arrow_schema::Field::new(name, data_type, true)
153 }
154
155 fn decimal_to_arrow(
156 &self,
157 data_type: &arrow_schema::DataType,
158 array: &DecimalArray,
159 ) -> Result<arrow_array::ArrayRef, ArrayError> {
160 let (precision, max_scale) = match data_type {
161 arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
162 _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
163 };
164
165 let max_value = 10_i128.pow(precision as u32) - 1;
167 let values: Vec<Option<i128>> = array
168 .iter()
169 .map(|e| {
170 e.and_then(|e| match e {
171 crate::array::Decimal::Normalized(e) => {
172 let value = e.mantissa();
173 let scale = e.scale() as i8;
174 let diff_scale = abs(max_scale - scale);
175 let value = match scale {
176 _ if scale < max_scale => value
177 .checked_mul(10_i128.pow(diff_scale as u32))
178 .and_then(|v| if abs(v) <= max_value { Some(v) } else { None })
179 .unwrap_or_else(|| {
180 tracing::warn!(
181 "Decimal overflow when converting to arrow decimal with precision {} and scale {}. It will be replaced with inf/-inf.",
182 precision, max_scale
183 );
184 if value >= 0 { max_value } else { -max_value }
185 }),
186 _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
187 _ => value,
188 };
189 Some(value)
190 }
191 crate::array::Decimal::PositiveInf => {
193 Some(max_value)
194 }
195 crate::array::Decimal::NegativeInf => {
196 Some(-max_value)
197 }
198 crate::array::Decimal::NaN => None,
199 })
200 })
201 .collect();
202
203 let array = arrow_array::Decimal128Array::from(values)
204 .with_precision_and_scale(precision, max_scale)
205 .map_err(ArrayError::from_arrow)?;
206 Ok(Arc::new(array) as ArrayRef)
207 }
208
209 fn interval_to_arrow(
210 &self,
211 array: &IntervalArray,
212 ) -> Result<arrow_array::ArrayRef, ArrayError> {
213 Ok(Arc::new(arrow_array::StringArray::from(array)))
214 }
215}
216
// No overrides: the Arrow -> RisingWave direction uses the default methods of `FromArrow`.
impl FromArrow for IcebergArrowConvert {}
218
/// `ToArrow` converter whose overridden methods tag each Arrow field they build with
/// a `PARQUET:field_id` metadata entry (see `add_field_id`).
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    /// Most recently assigned field ID. `add_field_id` increments it before use, so a
    /// default-constructed converter hands out IDs starting at 1.
    next_field_id: RefCell<u32>,
}
231
232impl IcebergCreateTableArrowConvert {
233 pub fn to_arrow_field(
234 &self,
235 name: &str,
236 data_type: &DataType,
237 ) -> Result<arrow_schema::Field, ArrayError> {
238 ToArrow::to_arrow_field(self, name, data_type)
239 }
240
241 fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
242 *self.next_field_id.borrow_mut() += 1;
243 let field_id = *self.next_field_id.borrow();
244
245 let mut metadata = HashMap::new();
246 metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
248 arrow_field.set_metadata(metadata);
249 }
250}
251
252impl ToArrow for IcebergCreateTableArrowConvert {
253 #[inline]
254 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
255 let data_type =
261 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
262
263 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
264 self.add_field_id(&mut arrow_field);
265 arrow_field
266 }
267
268 #[inline]
269 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
270 arrow_schema::DataType::Utf8
271 }
272
273 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
274 let data_type = arrow_schema::DataType::Utf8;
275
276 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
277 self.add_field_id(&mut arrow_field);
278 arrow_field
279 }
280
281 fn to_arrow_field(
286 &self,
287 name: &str,
288 value: &DataType,
289 ) -> Result<arrow_schema::Field, ArrayError> {
290 let data_type = match value {
291 DataType::Boolean => self.bool_type_to_arrow(),
293 DataType::Int16 => self.int32_type_to_arrow(),
294 DataType::Int32 => self.int32_type_to_arrow(),
295 DataType::Int64 => self.int64_type_to_arrow(),
296 DataType::Int256 => self.varchar_type_to_arrow(),
297 DataType::Float32 => self.float32_type_to_arrow(),
298 DataType::Float64 => self.float64_type_to_arrow(),
299 DataType::Date => self.date_type_to_arrow(),
300 DataType::Time => self.time_type_to_arrow(),
301 DataType::Timestamp => self.timestamp_type_to_arrow(),
302 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
303 DataType::Interval => self.interval_type_to_arrow(),
304 DataType::Varchar => self.varchar_type_to_arrow(),
305 DataType::Bytea => self.bytea_type_to_arrow(),
306 DataType::Serial => self.serial_type_to_arrow(),
307 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
308 DataType::Jsonb => self.varchar_type_to_arrow(),
309 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
310 DataType::List(list) => self.list_type_to_arrow(list)?,
311 DataType::Map(map) => self.map_type_to_arrow(map)?,
312 DataType::Vector(_) => self.vector_type_to_arrow()?,
313 };
314
315 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
316 self.add_field_id(&mut arrow_field);
317 Ok(arrow_field)
318 }
319}
320
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Parses `text` into a finite RisingWave decimal.
    fn normalized(text: &str) -> Decimal {
        Decimal::Normalized(text.parse().unwrap())
    }

    /// The `Decimal128` type with the fixed Iceberg precision and scale.
    fn iceberg_decimal_type() -> DataType {
        DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE)
    }

    /// Null, NaN, both infinities and two finite values with different scales.
    fn mixed_input() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(normalized("123.4")),
            Some(normalized("123.456")),
        ])
    }

    /// Runs the conversion under test and unwraps the result.
    fn convert(ty: &DataType, input: &DecimalArray) -> ArrayRef {
        IcebergArrowConvert.decimal_to_arrow(ty, input).unwrap()
    }

    /// Builds the expected Arrow array from raw unscaled values.
    fn expected(values: Vec<Option<i128>>, ty: DataType) -> ArrayRef {
        Arc::new(Decimal128Array::from(values).with_data_type(ty)) as ArrayRef
    }

    #[test]
    fn decimal() {
        let input = mixed_input();
        let ty = DataType::Decimal128(6, 3);
        let actual = convert(&ty, &input);
        let want = expected(
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
            ty,
        );
        assert_eq!(&actual, &want);
    }

    #[test]
    fn decimal_with_large_scale() {
        let input = mixed_input();
        let ty = iceberg_decimal_type();
        let actual = convert(&ty, &input);
        let want = expected(
            vec![
                None,
                None,
                Some(99999999999999999999999999999999999999),
                Some(-99999999999999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
            ty,
        );
        assert_eq!(&actual, &want);
    }

    #[test]
    fn decimal_edge_cases_risingwave_precision() {
        let input = DecimalArray::from_iter([
            Some(normalized("999999999999999999999999999")),
            Some(normalized("9999999999999999999999999999")),
            Some(normalized("999999999999999999.9999999999")),
            Some(normalized("0.9999999999999999999999999999")),
            Some(normalized("-999999999999999999999999999")),
            Some(normalized("1000000000000000000")),
            Some(normalized("0.0000000001")),
            Some(normalized("0.0000000000")),
        ]);

        let ty = iceberg_decimal_type();
        let actual = convert(&ty, &input);

        // One expected unscaled value per input above, in the same order.
        let want = expected(
            vec![
                Some(9999999999999999999999999990000000000),
                Some(99999999999999999999999999990000000000),
                Some(9999999999999999999999999999),
                Some(9999999999),
                Some(-9999999999999999999999999990000000000),
                Some(10000000000000000000000000000),
                Some(1),
                Some(0),
            ],
            ty,
        );

        assert_eq!(&actual, &want);
    }

    #[test]
    fn decimal_special_values_roundtrip() {
        use crate::array::Array;

        let original = DecimalArray::from_iter([
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::NaN),
            Some(normalized("123.45")),
            None,
        ]);

        let ty = iceberg_decimal_type();
        let converted = convert(&ty, &original);

        let arrow_decimal: &Decimal128Array = converted
            .as_any()
            .downcast_ref()
            .expect("should be Decimal128Array");

        let roundtrip: DecimalArray = arrow_decimal.try_into().unwrap();

        assert_eq!(original.len(), roundtrip.len());

        // Infinities survive the round trip.
        assert_eq!(roundtrip.value_at(0), Some(Decimal::PositiveInf));
        assert_eq!(roundtrip.value_at(1), Some(Decimal::NegativeInf));

        // NaN was written as null, so it reads back as null.
        assert_eq!(roundtrip.value_at(2), None);

        // Finite values stay finite.
        assert!(matches!(
            roundtrip.value_at(3),
            Some(Decimal::Normalized(_))
        ));

        // Null stays null.
        assert_eq!(roundtrip.value_at(4), None);
    }
}