1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::ops::{Div, Mul};
18use std::sync::Arc;
19
20use arrow_array::ArrayRef;
21use num_traits::abs;
22
23pub use super::arrow_54::{
24 FromArrow, ToArrow, arrow_array, arrow_buffer, arrow_cast, arrow_schema,
25 is_parquet_schema_match_source_schema,
26};
27use crate::array::{
28 Array, ArrayError, ArrayImpl, DataChunk, DataType, DecimalArray, IntervalArray,
29};
30use crate::types::StructType;
31
/// Precision of the Arrow `Decimal128` type used for RisingWave `Decimal` columns by the
/// Iceberg converters in this module.
pub const ICEBERG_DECIMAL_PRECISION: u8 = 28;
/// Scale (number of fractional digits) of the Arrow `Decimal128` type used for RisingWave
/// `Decimal` columns by the Iceberg converters in this module.
pub const ICEBERG_DECIMAL_SCALE: i8 = 10;

/// Stateless Arrow converter applying Iceberg-specific type mappings through its
/// `ToArrow` / `FromArrow` implementations.
pub struct IcebergArrowConvert;
36
37impl IcebergArrowConvert {
38 pub fn to_record_batch(
39 &self,
40 schema: arrow_schema::SchemaRef,
41 chunk: &DataChunk,
42 ) -> Result<arrow_array::RecordBatch, ArrayError> {
43 ToArrow::to_record_batch(self, schema, chunk)
44 }
45
46 pub fn chunk_from_record_batch(
47 &self,
48 batch: &arrow_array::RecordBatch,
49 ) -> Result<DataChunk, ArrayError> {
50 FromArrow::from_record_batch(self, batch)
51 }
52
53 pub fn type_from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
54 FromArrow::from_field(self, field)
55 }
56
57 pub fn to_arrow_field(
58 &self,
59 name: &str,
60 data_type: &DataType,
61 ) -> Result<arrow_schema::Field, ArrayError> {
62 ToArrow::to_arrow_field(self, name, data_type)
63 }
64
65 pub fn struct_from_fields(
66 &self,
67 fields: &arrow_schema::Fields,
68 ) -> Result<StructType, ArrayError> {
69 FromArrow::from_fields(self, fields)
70 }
71
72 pub fn to_arrow_array(
73 &self,
74 data_type: &arrow_schema::DataType,
75 array: &ArrayImpl,
76 ) -> Result<arrow_array::ArrayRef, ArrayError> {
77 ToArrow::to_array(self, data_type, array)
78 }
79
80 pub fn array_from_arrow_array(
81 &self,
82 field: &arrow_schema::Field,
83 array: &arrow_array::ArrayRef,
84 ) -> Result<ArrayImpl, ArrayError> {
85 FromArrow::from_array(self, field, array)
86 }
87}
88
89impl ToArrow for IcebergArrowConvert {
90 fn to_arrow_field(
91 &self,
92 name: &str,
93 data_type: &DataType,
94 ) -> Result<arrow_schema::Field, ArrayError> {
95 let data_type = match data_type {
96 DataType::Boolean => self.bool_type_to_arrow(),
97 DataType::Int16 => self.int32_type_to_arrow(),
98 DataType::Int32 => self.int32_type_to_arrow(),
99 DataType::Int64 => self.int64_type_to_arrow(),
100 DataType::Int256 => self.int256_type_to_arrow(),
101 DataType::Float32 => self.float32_type_to_arrow(),
102 DataType::Float64 => self.float64_type_to_arrow(),
103 DataType::Date => self.date_type_to_arrow(),
104 DataType::Time => self.time_type_to_arrow(),
105 DataType::Timestamp => self.timestamp_type_to_arrow(),
106 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
107 DataType::Interval => self.interval_type_to_arrow(),
108 DataType::Varchar => self.varchar_type_to_arrow(),
109 DataType::Bytea => self.bytea_type_to_arrow(),
110 DataType::Serial => self.serial_type_to_arrow(),
111 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
112 DataType::Jsonb => self.varchar_type_to_arrow(),
113 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
114 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
115 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
116 DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
117 };
118 Ok(arrow_schema::Field::new(name, data_type, true))
119 }
120
121 #[inline]
122 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
123 arrow_schema::DataType::Utf8
124 }
125
126 #[inline]
127 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
128 let data_type =
130 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
131 arrow_schema::Field::new(name, data_type, true)
132 }
133
134 fn decimal_to_arrow(
135 &self,
136 data_type: &arrow_schema::DataType,
137 array: &DecimalArray,
138 ) -> Result<arrow_array::ArrayRef, ArrayError> {
139 let (precision, max_scale) = match data_type {
140 arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
141 _ => return Err(ArrayError::to_arrow("Invalid decimal type")),
142 };
143
144 let values: Vec<Option<i128>> = array
146 .iter()
147 .map(|e| {
148 e.and_then(|e| match e {
149 crate::array::Decimal::Normalized(e) => {
150 let value = e.mantissa();
151 let scale = e.scale() as i8;
152 let diff_scale = abs(max_scale - scale);
153 let value = match scale {
154 _ if scale < max_scale => value.mul(10_i128.pow(diff_scale as u32)),
155 _ if scale > max_scale => value.div(10_i128.pow(diff_scale as u32)),
156 _ => value,
157 };
158 Some(value)
159 }
160 crate::array::Decimal::PositiveInf => {
162 let max_value = 10_i128.pow(precision as u32) - 1;
163 Some(max_value)
164 }
165 crate::array::Decimal::NegativeInf => {
166 let max_value = 10_i128.pow(precision as u32) - 1;
167 Some(-max_value)
168 }
169 crate::array::Decimal::NaN => None,
170 })
171 })
172 .collect();
173
174 let array = arrow_array::Decimal128Array::from(values)
175 .with_precision_and_scale(precision, max_scale)
176 .map_err(ArrayError::from_arrow)?;
177 Ok(Arc::new(array) as ArrayRef)
178 }
179
180 fn interval_to_arrow(
181 &self,
182 array: &IntervalArray,
183 ) -> Result<arrow_array::ArrayRef, ArrayError> {
184 Ok(Arc::new(arrow_array::StringArray::from(array)))
185 }
186}
187
// `IcebergArrowConvert` uses the provided `FromArrow` conversions without any overrides.
impl FromArrow for IcebergArrowConvert {}
189
/// Arrow converter that, unlike [`IcebergArrowConvert`], tags the fields it builds with a
/// `PARQUET:field_id` metadata entry.
///
/// Ids come from an internal counter owned by each converter instance. With the derived
/// `Default` they start at 1 and increase by one per tagged field.
#[derive(Default)]
pub struct IcebergCreateTableArrowConvert {
    /// Last field id handed out (`0` before the first field is tagged). Despite the name,
    /// the counter is incremented *before* the id is used.
    next_field_id: RefCell<u32>,
}
202
203impl IcebergCreateTableArrowConvert {
204 pub fn to_arrow_field(
205 &self,
206 name: &str,
207 data_type: &DataType,
208 ) -> Result<arrow_schema::Field, ArrayError> {
209 ToArrow::to_arrow_field(self, name, data_type)
210 }
211
212 fn add_field_id(&self, arrow_field: &mut arrow_schema::Field) {
213 *self.next_field_id.borrow_mut() += 1;
214 let field_id = *self.next_field_id.borrow();
215
216 let mut metadata = HashMap::new();
217 metadata.insert("PARQUET:field_id".to_owned(), field_id.to_string());
219 arrow_field.set_metadata(metadata);
220 }
221}
222
223impl ToArrow for IcebergCreateTableArrowConvert {
224 #[inline]
225 fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
226 let data_type =
232 arrow_schema::DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE);
233
234 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
235 self.add_field_id(&mut arrow_field);
236 arrow_field
237 }
238
239 #[inline]
240 fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
241 arrow_schema::DataType::Utf8
242 }
243
244 fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
245 let data_type = arrow_schema::DataType::Utf8;
246
247 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
248 self.add_field_id(&mut arrow_field);
249 arrow_field
250 }
251
252 fn to_arrow_field(
257 &self,
258 name: &str,
259 value: &DataType,
260 ) -> Result<arrow_schema::Field, ArrayError> {
261 let data_type = match value {
262 DataType::Boolean => self.bool_type_to_arrow(),
264 DataType::Int16 => self.int32_type_to_arrow(),
265 DataType::Int32 => self.int32_type_to_arrow(),
266 DataType::Int64 => self.int64_type_to_arrow(),
267 DataType::Int256 => self.varchar_type_to_arrow(),
268 DataType::Float32 => self.float32_type_to_arrow(),
269 DataType::Float64 => self.float64_type_to_arrow(),
270 DataType::Date => self.date_type_to_arrow(),
271 DataType::Time => self.time_type_to_arrow(),
272 DataType::Timestamp => self.timestamp_type_to_arrow(),
273 DataType::Timestamptz => self.timestamptz_type_to_arrow(),
274 DataType::Interval => self.interval_type_to_arrow(),
275 DataType::Varchar => self.varchar_type_to_arrow(),
276 DataType::Bytea => self.bytea_type_to_arrow(),
277 DataType::Serial => self.serial_type_to_arrow(),
278 DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
279 DataType::Jsonb => self.varchar_type_to_arrow(),
280 DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
281 DataType::List(datatype) => self.list_type_to_arrow(datatype)?,
282 DataType::Map(datatype) => self.map_type_to_arrow(datatype)?,
283 DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
284 };
285
286 let mut arrow_field = arrow_schema::Field::new(name, data_type, true);
287 self.add_field_id(&mut arrow_field);
288 Ok(arrow_field)
289 }
290}
291
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::arrow_array::{ArrayRef, Decimal128Array};
    use super::arrow_schema::DataType;
    use super::*;
    use crate::array::{Decimal, DecimalArray};

    /// Shared input: a null, NaN, both infinities, and two finite values with different
    /// scales.
    fn sample_decimals() -> DecimalArray {
        DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ])
    }

    /// Converts `sample_decimals()` to `ty` and checks the unscaled results.
    fn assert_converts_to(ty: DataType, unscaled: Vec<Option<i128>>) {
        let actual = IcebergArrowConvert
            .decimal_to_arrow(&ty, &sample_decimals())
            .unwrap();
        let expected: ArrayRef = Arc::new(Decimal128Array::from(unscaled).with_data_type(ty));
        assert_eq!(&actual, &expected);
    }

    #[test]
    fn decimal() {
        assert_converts_to(
            DataType::Decimal128(6, 3),
            vec![
                None,
                None,
                Some(999999),
                Some(-999999),
                Some(123400),
                Some(123456),
            ],
        );
    }

    #[test]
    fn decimal_with_large_scale() {
        assert_converts_to(
            DataType::Decimal128(ICEBERG_DECIMAL_PRECISION, ICEBERG_DECIMAL_SCALE),
            vec![
                None,
                None,
                Some(9999999999999999999999999999),
                Some(-9999999999999999999999999999),
                Some(1234000000000),
                Some(1234560000000),
            ],
        );
    }
}