1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::Arc;
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_54::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25 is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28 Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29 VECTOR_ITEM_TYPE,
30};
31use crate::types::StructType;
32
/// Arrow converter used for Iceberg: maps RisingWave chunks, arrays and types to and from
/// their Arrow representation.
pub struct IcebergArrowConvert;

/// Precision of the Arrow `Decimal128` type that RisingWave `DataType::Decimal` is mapped to.
// NOTE(review): presumably 28 mirrors the digit limit of the underlying decimal type — confirm.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale of the Arrow `Decimal128` type that RisingWave `DataType::Decimal` is mapped to.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;
37
38impl IcebergArrowConvert {
39 pub fn to_record_batch(
40 &self,
41 schema: arrow_schema::SchemaRef,
42 chunk: &DataChunk,
43 ) -> Result<arrow_array::RecordBatch, ArrayError> {
44 ToArrow::to_record_batch(self, schema, chunk)
45 }
46
47 pub fn chunk_from_record_batch(
48 &self,
49 batch: &arrow_array::RecordBatch,
50 ) -> Result<DataChunk, ArrayError> {
51 FromArrow::from_record_batch(self, batch)
52 }
53
54 pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
55 FromArrow::from_field(self, field)
56 }
57
58 pub fn to_arrow_field(
59 &self,
60 name: &str,
61 data_type: &DataType,
62 ) -> Result<arrow_schema::Field, ArrayError> {
63 ToArrow::to_arrow_field(self, name, data_type)
64 }
65
66 pub fn struct_from_fields(
67 &self,
68 fields: &arrow_schema::Fields,
69 ) -> Result<StructType, ArrayError> {
70 FromArrow::from_fields(self, fields)
71 }
72
73 pub fn to_arrow_array(
74 &self,
75 data_type: &arrow_schema::DataType,
76 array: &ArrayImpl,
77 ) -> Result<arrow_array::ArrayRef, ArrayError> {
78 ToArrow::to_array(self, data_type, array)
79 }
80
81 pub fn array_from_arrow_array(
82 &self,
83 field: &arrow_schema::Field,
84 array: &arrow_array::ArrayRef,
85 ) -> Result<ArrayImpl, ArrayError> {
86 FromArrow::from_array(self, field, array)
87 }
88}
89
90impl ToArrow for IcebergArrowConvert {
91 fn to_arrow_field(
92 &self,
93 name: &str,
94 data_type: &DataType,
95 ) -> Result<arrow_schema::Field, ArrayError> {
96 let data_type = match data_type {
97 DataType::Boolean => self.bool_type_to_arrow(),
98 DataType::Int16 => self.int32_type_to_arrow(),
99 DataType::Int32 => self.int32_type_to_arrow(),
100 DataType::Int64 => self.int64_type_to_arrow(),
101 DataType::Int256 => self.int256_type_to_arrow(),
102 DataType::Float32 => self.float32_type_to_arrow(),
103 DataType::Float64 => self.float64_type_to_arrow(),
104 DataType::Date => self.date_type_to_arrow(),
105 DataType::Time => self.time_type_to_arrow(),
106 DataType::Timestamp => self.timestamp_type_to_arrow(),
107 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
108 DataType::Interval => self.interval_type_to_arrow(),
109 DataType::Varchar => self.varchar_type_to_arrow(),
110 DataType::Bytea => self.bytea_type_to_arrow(),
111 DataType::Serial => self.serial_type_to_arrow(),
112 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
113 DataType::Jsonb => self.varchar_type_to_arrow(),
114 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
115 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
116 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
117 DataType::Vector(_) => self.list_type_to_arrow(&VECTOR_ITEM_TYPE)?,
118 };
119 Ok(arrow_schema::Field::new(name, data_type, true))
120 }
121
122 #[inline]
123 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
124 arrow_schema::DataType::Utf8
125 }
126
127 #[inline]
128 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
129 let data_type =
131 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
132 arrow_schema::Field::new(name, data_type, true)
133 }
134
135 fn decimal_to_arrow(
136 &self,
137 data_type: &arrow_schema::DataType,
138 array: &DecimalArray,
139 ) -> Result<arrow_array::ArrayRef, ArrayError> {
140 let (precision, max_scale) = match data_type {
141 arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
142 _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
143 };
144
145 let values: Vec<Option<i128>> = array
147 .iter()
148 .map(|e| {
149 e.and_then(|e| match e {
150 crate::array::Decimal::Normalized(e) => {
151 let value = e.mantissa();
152 let scale = e.scale() as i8;
153 let diff_scale = abs(max_scale - scale);
154 let value = match scale {
155 _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
156 _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
157 _ => value,
158 };
159 Some(value)
160 }
161 crate::array::Decimal::PositiveInf => {
163 let max_value = 10_i128.pow(precision as u32) - 1;
164 Some(max_value)
165 }
166 crate::array::Decimal::NegativeInf => {
167 let max_value = 10_i128.pow(precision as u32) - 1;
168 Some(-max_value)
169 }
170 crate::array::Decimal::NaN => None,
171 })
172 })
173 .collect();
174
175 let array = arrow_array::Decimal128Array::from(values)
176 .with_precision_and_scale(precision, max_scale)
177 .map_err(ArrayError::from_arrow)?;
178 Ok(Arc::new(array) as ArrayRef)
179 }
180
181 fn interval_to_arrow(
182 &self,
183 array: &IntervalArray,
184 ) -> Result<arrow_array::ArrayRef, ArrayError> {
185 Ok(Arc::new(arrow_array::StringArray::from(array)))
186 }
187}
188
// Reading from Arrow uses the default `FromArrow` conversions unchanged.
impl FromArrow for IcebergArrowConvert {}
190
/// Arrow converter used when creating Iceberg tables.
///
/// Every field produced through its `ToArrow` implementation is tagged with a
/// `PARQUET:field_id` metadata entry. The ids come from an internal counter, so one instance
/// hands out distinct ids across all the fields it converts.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    // Despite the name, this holds the id most recently assigned. It is incremented before
    // each use, so a default-constructed converter hands out 1 first. It is a `RefCell` so the
    // counter can advance through `&self`, which also makes the type `!Sync`.
    next_field_id: RefCell<u32>,
}
203
204impl IcebergCreateTableArrowConvert {
205 pub fn to_arrow_field(
206 &self,
207 name: &str,
208 data_type: &DataType,
209 ) -> Result<arrow_schema::Field, ArrayError> {
210 ToArrow::to_arrow_field(self, name, data_type)
211 }
212
213 fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
214 *self.next_field_id.borrow_mut() += 1;
215 let field_id = *self.next_field_id.borrow();
216
217 let mut metadata = HashMap::new();
218 metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
220 arrow_field.set_metadata(metadata);
221 }
222}
223
224impl ToArrow for IcebergCreateTableArrowConvert {
225 #[inline]
226 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
227 let data_type =
233 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
234
235 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
236 self.add_field_id(&mut arrow_field);
237 arrow_field
238 }
239
240 #[inline]
241 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
242 arrow_schema::DataType::Utf8
243 }
244
245 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
246 let data_type = arrow_schema::DataType::Utf8;
247
248 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
249 self.add_field_id(&mut arrow_field);
250 arrow_field
251 }
252
253 fn to_arrow_field(
258 &self,
259 name: &str,
260 value: &DataType,
261 ) -> Result<arrow_schema::Field, ArrayError> {
262 let data_type = match value {
263 DataType::Boolean => self.bool_type_to_arrow(),
265 DataType::Int16 => self.int32_type_to_arrow(),
266 DataType::Int32 => self.int32_type_to_arrow(),
267 DataType::Int64 => self.int64_type_to_arrow(),
268 DataType::Int256 => self.varchar_type_to_arrow(),
269 DataType::Float32 => self.float32_type_to_arrow(),
270 DataType::Float64 => self.float64_type_to_arrow(),
271 DataType::Date => self.date_type_to_arrow(),
272 DataType::Time => self.time_type_to_arrow(),
273 DataType::Timestamp => self.timestamp_type_to_arrow(),
274 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
275 DataType::Interval => self.interval_type_to_arrow(),
276 DataType::Varchar => self.varchar_type_to_arrow(),
277 DataType::Bytea => self.bytea_type_to_arrow(),
278 DataType::Serial => self.serial_type_to_arrow(),
279 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
280 DataType::Jsonb => self.varchar_type_to_arrow(),
281 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
282 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
283 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
284 DataType::Vector(_) => self.list_type_to_arrow(&VECTOR_ITEM_TYPE)?,
285 };
286
287 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
288 self.add_field_id(&mut arrow_field);
289 Ok(arrow_field)
290 }
291}
292
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Shared input: a null, the three special values, and two finite decimals whose scales
    /// (1 and 3) differ from the targets used below.
    fn sample_decimals() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ])
    }

    /// Converts `sample_decimals()` to `ty` and checks the unscaled values match `expected`.
    fn assert_converts_to(ty: DataType, expected: Vec<Option<i128>>) {
        let actual = IcebergArrowConvert
            .decimal_to_arrow(&ty, &sample_decimals())
            .unwrap();
        let expected: ArrayRef = Arc::new(Decimal128Array::from(expected).with_data_type(ty));
        assert_eq!(&actual, &expected);
    }

    #[test]
    fn decimal() {
        assert_converts_to(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        assert_converts_to(
            DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}
353}