1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::Arc;
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_54::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25 is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28 Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Converter between RisingWave arrays/chunks and arrow data with Iceberg-specific type
/// mappings.
///
/// Its [`ToArrow`] implementation overrides decimals (always
/// `Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE)`) and intervals (written as
/// `Utf8` strings). Its [`FromArrow`] implementation uses the trait defaults.
pub struct IcebergArrowConvert;
33
/// Precision of the arrow `Decimal128` type that the Iceberg converters in this module emit
/// for every RisingWave `Decimal` column.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale (number of fractional digits) paired with [`ICEBERG_DECIMAL_PRECISION`].
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
36
37impl IcebergArrowConvert {
38 pub fn to_record_batch(
39 &self,
40 schema: arrow_schema::SchemaRef,
41 chunk: &DataChunk,
42 ) -> Result<arrow_array::RecordBatch, ArrayError> {
43 ToArrow::to_record_batch(self, schema, chunk)
44 }
45
46 pub fn chunk_from_record_batch(
47 &self,
48 batch: &arrow_array::RecordBatch,
49 ) -> Result<DataChunk, ArrayError> {
50 FromArrow::from_record_batch(self, batch)
51 }
52
53 pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
54 FromArrow::from_field(self, field)
55 }
56
57 pub fn to_arrow_field(
58 &self,
59 name: &str,
60 data_type: &DataType,
61 ) -> Result<arrow_schema::Field, ArrayError> {
62 ToArrow::to_arrow_field(self, name, data_type)
63 }
64
65 pub fn struct_from_fields(
66 &self,
67 fields: &arrow_schema::Fields,
68 ) -> Result<StructType, ArrayError> {
69 FromArrow::from_fields(self, fields)
70 }
71
72 pub fn to_arrow_array(
73 &self,
74 data_type: &arrow_schema::DataType,
75 array: &ArrayImpl,
76 ) -> Result<arrow_array::ArrayRef, ArrayError> {
77 ToArrow::to_array(self, data_type, array)
78 }
79
80 pub fn array_from_arrow_array(
81 &self,
82 field: &arrow_schema::Field,
83 array: &arrow_array::ArrayRef,
84 ) -> Result<ArrayImpl, ArrayError> {
85 FromArrow::from_array(self, field, array)
86 }
87}
88
89impl ToArrow for IcebergArrowConvert {
90 fn to_arrow_field(
91 &self,
92 name: &str,
93 data_type: &DataType,
94 ) -> Result<arrow_schema::Field, ArrayError> {
95 let data_type = match data_type {
96 DataType::Boolean => self.bool_type_to_arrow(),
97 DataType::Int16 => self.int32_type_to_arrow(),
98 DataType::Int32 => self.int32_type_to_arrow(),
99 DataType::Int64 => self.int64_type_to_arrow(),
100 DataType::Int256 => self.int256_type_to_arrow(),
101 DataType::Float32 => self.float32_type_to_arrow(),
102 DataType::Float64 => self.float64_type_to_arrow(),
103 DataType::Date => self.date_type_to_arrow(),
104 DataType::Time => self.time_type_to_arrow(),
105 DataType::Timestamp => self.timestamp_type_to_arrow(),
106 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
107 DataType::Interval => self.interval_type_to_arrow(),
108 DataType::Varchar => self.varchar_type_to_arrow(),
109 DataType::Bytea => self.bytea_type_to_arrow(),
110 DataType::Serial => self.serial_type_to_arrow(),
111 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
112 DataType::Jsonb => self.varchar_type_to_arrow(),
113 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
114 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
115 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
116 };
117 Ok(arrow_schema::Field::new(name, data_type, true))
118 }
119
120 #[inline]
121 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
122 arrow_schema::DataType::Utf8
123 }
124
125 #[inline]
126 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
127 let data_type =
129 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
130 arrow_schema::Field::new(name, data_type, true)
131 }
132
133 fn decimal_to_arrow(
134 &self,
135 data_type: &arrow_schema::DataType,
136 array: &DecimalArray,
137 ) -> Result<arrow_array::ArrayRef, ArrayError> {
138 let (precision, max_scale) = match data_type {
139 arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
140 _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
141 };
142
143 let values: Vec<Option<i128>> = array
145 .iter()
146 .map(|e| {
147 e.and_then(|e| match e {
148 crate::array::Decimal::Normalized(e) => {
149 let value = e.mantissa();
150 let scale = e.scale() as i8;
151 let diff_scale = abs(max_scale - scale);
152 let value = match scale {
153 _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
154 _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
155 _ => value,
156 };
157 Some(value)
158 }
159 crate::array::Decimal::PositiveInf => {
161 let max_value = 10_i128.pow(precision as u32) - 1;
162 Some(max_value)
163 }
164 crate::array::Decimal::NegativeInf => {
165 let max_value = 10_i128.pow(precision as u32) - 1;
166 Some(-max_value)
167 }
168 crate::array::Decimal::NaN => None,
169 })
170 })
171 .collect();
172
173 let array = arrow_array::Decimal128Array::from(values)
174 .with_precision_and_scale(precision, max_scale)
175 .map_err(ArrayError::from_arrow)?;
176 Ok(Arc::new(array) as ArrayRef)
177 }
178
179 fn interval_to_arrow(
180 &self,
181 array: &IntervalArray,
182 ) -> Result<arrow_array::ArrayRef, ArrayError> {
183 Ok(Arc::new(arrow_array::StringArray::from(array)))
184 }
185}
186
// Reading arrow data back uses the default `FromArrow` behaviour; no method is overridden.
impl FromArrow for IcebergArrowConvert {}
188
/// Arrow converter that attaches a `PARQUET:field_id` metadata entry to every field it
/// produces.
///
/// Field ids come from a per-instance counter. Create a fresh value (e.g. via `Default`)
/// for each independent schema.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    // Last field id handed out. It starts at 0 (`Default`), and `add_field_id` increments
    // before reading, so the first id assigned is 1. `RefCell` is used because the counter is
    // bumped from `&self` methods.
    next_field_id: RefCell<u32>,
}
201
202impl IcebergCreateTableArrowConvert {
203 pub fn to_arrow_field(
204 &self,
205 name: &str,
206 data_type: &DataType,
207 ) -> Result<arrow_schema::Field, ArrayError> {
208 ToArrow::to_arrow_field(self, name, data_type)
209 }
210
211 fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
212 *self.next_field_id.borrow_mut() += 1;
213 let field_id = *self.next_field_id.borrow();
214
215 let mut metadata = HashMap::new();
216 metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
218 arrow_field.set_metadata(metadata);
219 }
220}
221
222impl ToArrow for IcebergCreateTableArrowConvert {
223 #[inline]
224 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
225 let data_type =
231 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
232
233 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
234 self.add_field_id(&mut arrow_field);
235 arrow_field
236 }
237
238 #[inline]
239 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
240 arrow_schema::DataType::Utf8
241 }
242
243 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
244 let data_type = arrow_schema::DataType::Utf8;
245
246 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
247 self.add_field_id(&mut arrow_field);
248 arrow_field
249 }
250
251 fn to_arrow_field(
256 &self,
257 name: &str,
258 value: &DataType,
259 ) -> Result<arrow_schema::Field, ArrayError> {
260 let data_type = match value {
261 DataType::Boolean => self.bool_type_to_arrow(),
263 DataType::Int16 => self.int32_type_to_arrow(),
264 DataType::Int32 => self.int32_type_to_arrow(),
265 DataType::Int64 => self.int64_type_to_arrow(),
266 DataType::Int256 => self.varchar_type_to_arrow(),
267 DataType::Float32 => self.float32_type_to_arrow(),
268 DataType::Float64 => self.float64_type_to_arrow(),
269 DataType::Date => self.date_type_to_arrow(),
270 DataType::Time => self.time_type_to_arrow(),
271 DataType::Timestamp => self.timestamp_type_to_arrow(),
272 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
273 DataType::Interval => self.interval_type_to_arrow(),
274 DataType::Varchar => self.varchar_type_to_arrow(),
275 DataType::Bytea => self.bytea_type_to_arrow(),
276 DataType::Serial => self.serial_type_to_arrow(),
277 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
278 DataType::Jsonb => self.varchar_type_to_arrow(),
279 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
280 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
281 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
282 };
283
284 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
285 self.add_field_id(&mut arrow_field);
286 Ok(arrow_field)
287 }
288}
289
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Values with a smaller scale than the target are scaled up; Inf maps to ±(10^p - 1).
    #[test]
    fn decimal() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ]);
        let ty = DataType::Decimal128(6, 3);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ])
            .with_data_type(ty),
        ) as ArrayRef;
        assert_eq!(&arrow_array, &expect_array);
    }

    /// Same as `decimal`, but with the fixed Iceberg precision/scale.
    #[test]
    fn decimal_with_large_scale() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ]);
        let ty = DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ])
            .with_data_type(ty),
        ) as ArrayRef;
        assert_eq!(&arrow_array, &expect_array);
    }

    /// Values with a larger scale than the target are scaled down, truncating toward zero
    /// (e.g. `123.456` -> `123.4` and `-123.456` -> `-123.4` at scale 1).
    #[test]
    fn decimal_with_smaller_scale() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
            Some(Decimal::Normalized("-123.456".parse().unwrap())),
        ]);
        let ty = DataType::Decimal128(6, 1);
        let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
        let expect_array = Arc::new(
            Decimal128Array::from(vec![
                None,
                None,
                Some(999999),
                Some(1234),
                Some(1234),
                Some(-1234),
            ])
            .with_data_type(ty),
        ) as ArrayRef;
        assert_eq!(&arrow_array, &expect_array);
    }
}