risingwave_common/array/arrow/
arrow_impl.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Converts between arrays and Apache Arrow arrays.
16//!
17//! This file acts as a template file for conversion code between
//! arrays and different versions of Apache Arrow.
19//!
20//! The conversion logic will be implemented for the arrow version specified in the outer mod by
21//! `super::arrow_xxx`, such as `super::arrow_array`.
22//!
23//! When we want to implement the conversion logic for an arrow version, we first
24//! create a new mod file, and rename the corresponding arrow package name to `arrow_xxx`
25//! using the `use` clause, and then declare a sub-mod and set its file path with attribute
26//! `#[path = "./arrow_impl.rs"]` so that the code in this template file can be embedded to
27//! the new mod file, and the conversion logic can be implemented for the corresponding arrow
28//! version.
29//!
//! An example can be seen in `arrow_default.rs`, which looks as follows:
31//! ```ignore
32//! use {arrow_array, arrow_buffer, arrow_cast, arrow_schema};
33//!
34//! #[allow(clippy::duplicate_mod)]
35//! #[path = "./arrow_impl.rs"]
36//! mod arrow_impl;
37//! ```
38
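// NOTE(review): these crate-level allows are presumably needed because this template is embedded
// into several modules via `#[path]`, and not every instantiation uses every import or helper.
// TODO confirm and narrow them if possible.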
40#![allow(unused_imports)]
41#![allow(dead_code)]
42
43use std::fmt::Write;
44
45use arrow_array::array;
46use arrow_array::cast::AsArray;
47use arrow_buffer::OffsetBuffer;
48use arrow_schema::TimeUnit;
49use chrono::{DateTime, NaiveDateTime, NaiveTime};
50use itertools::Itertools;
51
52use super::arrow_schema::IntervalUnit;
53// This is important because we want to use the arrow version specified by the outer mod.
54use super::{ArrowIntervalType, arrow_array, arrow_buffer, arrow_cast, arrow_schema};
// All other imports should always use absolute paths.
56use crate::array::*;
57use crate::types::{DataType as RwDataType, Scalar, *};
58use crate::util::iter_util::ZipEqFast;
59
60/// Defines how to convert RisingWave arrays to Arrow arrays.
61///
62/// This trait allows for customized conversion logic for different external systems using Arrow.
63/// The default implementation is based on the `From` implemented in this mod.
64pub trait ToArrow {
65    /// Converts RisingWave `DataChunk` to Arrow `RecordBatch` with specified schema.
66    ///
67    /// This function will try to convert the array if the type is not same with the schema.
68    fn to_record_batch(
69        &self,
70        schema: arrow_schema::SchemaRef,
71        chunk: &DataChunk,
72    ) -> Result<arrow_array::RecordBatch, ArrayError> {
73        // compact the chunk if it's not compacted
74        if !chunk.is_vis_compacted() {
75            let c = chunk.clone();
76            return self.to_record_batch(schema, &c.compact_vis());
77        }
78
79        // convert each column to arrow array
80        let columns: Vec<_> = chunk
81            .columns()
82            .iter()
83            .zip_eq_fast(schema.fields().iter())
84            .map(|(column, field)| self.to_array(field.data_type(), column))
85            .try_collect()?;
86
87        // create record batch
88        let opts =
89            arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
90        arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
91            .map_err(ArrayError::to_arrow)
92    }
93
94    /// Converts RisingWave array to Arrow array.
95    fn to_array(
96        &self,
97        data_type: &arrow_schema::DataType,
98        array: &ArrayImpl,
99    ) -> Result<arrow_array::ArrayRef, ArrayError> {
100        let arrow_array = match array {
101            ArrayImpl::Bool(array) => self.bool_to_arrow(array),
102            ArrayImpl::Int16(array) => self.int16_to_arrow(array),
103            ArrayImpl::Int32(array) => self.int32_to_arrow(array),
104            ArrayImpl::Int64(array) => self.int64_to_arrow(array),
105            ArrayImpl::Int256(array) => self.int256_to_arrow(array),
106            ArrayImpl::Float32(array) => self.float32_to_arrow(array),
107            ArrayImpl::Float64(array) => self.float64_to_arrow(array),
108            ArrayImpl::Date(array) => self.date_to_arrow(array),
109            ArrayImpl::Time(array) => self.time_to_arrow(array),
110            ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(array),
111            ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(array),
112            ArrayImpl::Interval(array) => self.interval_to_arrow(array),
113            ArrayImpl::Utf8(array) => self.utf8_to_arrow(array),
114            ArrayImpl::Bytea(array) => self.bytea_to_arrow(array),
115            ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array),
116            ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(array),
117            ArrayImpl::Serial(array) => self.serial_to_arrow(array),
118            ArrayImpl::List(array) => self.list_to_arrow(data_type, array),
119            ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array),
120            ArrayImpl::Map(array) => self.map_to_arrow(data_type, array),
121            ArrayImpl::Vector(inner) => self.vector_to_arrow(data_type, inner),
122        }?;
123        if arrow_array.data_type() != data_type {
124            arrow_cast::cast(&arrow_array, data_type).map_err(ArrayError::to_arrow)
125        } else {
126            Ok(arrow_array)
127        }
128    }
129
130    #[inline]
131    fn bool_to_arrow(&self, array: &BoolArray) -> Result<arrow_array::ArrayRef, ArrayError> {
132        Ok(Arc::new(arrow_array::BooleanArray::from(array)))
133    }
134
135    #[inline]
136    fn int16_to_arrow(&self, array: &I16Array) -> Result<arrow_array::ArrayRef, ArrayError> {
137        Ok(Arc::new(arrow_array::Int16Array::from(array)))
138    }
139
140    #[inline]
141    fn int32_to_arrow(&self, array: &I32Array) -> Result<arrow_array::ArrayRef, ArrayError> {
142        Ok(Arc::new(arrow_array::Int32Array::from(array)))
143    }
144
145    #[inline]
146    fn int64_to_arrow(&self, array: &I64Array) -> Result<arrow_array::ArrayRef, ArrayError> {
147        Ok(Arc::new(arrow_array::Int64Array::from(array)))
148    }
149
150    #[inline]
151    fn float32_to_arrow(&self, array: &F32Array) -> Result<arrow_array::ArrayRef, ArrayError> {
152        Ok(Arc::new(arrow_array::Float32Array::from(array)))
153    }
154
155    #[inline]
156    fn float64_to_arrow(&self, array: &F64Array) -> Result<arrow_array::ArrayRef, ArrayError> {
157        Ok(Arc::new(arrow_array::Float64Array::from(array)))
158    }
159
160    #[inline]
161    fn utf8_to_arrow(&self, array: &Utf8Array) -> Result<arrow_array::ArrayRef, ArrayError> {
162        Ok(Arc::new(arrow_array::StringArray::from(array)))
163    }
164
165    #[inline]
166    fn int256_to_arrow(&self, array: &Int256Array) -> Result<arrow_array::ArrayRef, ArrayError> {
167        Ok(Arc::new(arrow_array::Decimal256Array::from(array)))
168    }
169
170    #[inline]
171    fn date_to_arrow(&self, array: &DateArray) -> Result<arrow_array::ArrayRef, ArrayError> {
172        Ok(Arc::new(arrow_array::Date32Array::from(array)))
173    }
174
175    #[inline]
176    fn timestamp_to_arrow(
177        &self,
178        array: &TimestampArray,
179    ) -> Result<arrow_array::ArrayRef, ArrayError> {
180        Ok(Arc::new(arrow_array::TimestampMicrosecondArray::from(
181            array,
182        )))
183    }
184
185    #[inline]
186    fn timestamptz_to_arrow(
187        &self,
188        array: &TimestamptzArray,
189    ) -> Result<arrow_array::ArrayRef, ArrayError> {
190        Ok(Arc::new(
191            arrow_array::TimestampMicrosecondArray::from(array).with_timezone_utc(),
192        ))
193    }
194
195    #[inline]
196    fn time_to_arrow(&self, array: &TimeArray) -> Result<arrow_array::ArrayRef, ArrayError> {
197        Ok(Arc::new(arrow_array::Time64MicrosecondArray::from(array)))
198    }
199
200    #[inline]
201    fn interval_to_arrow(
202        &self,
203        array: &IntervalArray,
204    ) -> Result<arrow_array::ArrayRef, ArrayError> {
205        Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
206            array,
207        )))
208    }
209
210    #[inline]
211    fn bytea_to_arrow(&self, array: &BytesArray) -> Result<arrow_array::ArrayRef, ArrayError> {
212        Ok(Arc::new(arrow_array::BinaryArray::from(array)))
213    }
214
215    // Decimal values are stored as ASCII text representation in a string array.
216    #[inline]
217    fn decimal_to_arrow(
218        &self,
219        _data_type: &arrow_schema::DataType,
220        array: &DecimalArray,
221    ) -> Result<arrow_array::ArrayRef, ArrayError> {
222        Ok(Arc::new(arrow_array::StringArray::from(array)))
223    }
224
225    // JSON values are stored as text representation in a string array.
226    #[inline]
227    fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
228        Ok(Arc::new(arrow_array::StringArray::from(array)))
229    }
230
231    #[inline]
232    fn serial_to_arrow(&self, array: &SerialArray) -> Result<arrow_array::ArrayRef, ArrayError> {
233        Ok(Arc::new(arrow_array::Int64Array::from(array)))
234    }
235
236    #[inline]
237    fn list_to_arrow(
238        &self,
239        data_type: &arrow_schema::DataType,
240        array: &ListArray,
241    ) -> Result<arrow_array::ArrayRef, ArrayError> {
242        let arrow_schema::DataType::List(field) = data_type else {
243            return Err(ArrayError::to_arrow("Invalid list type"));
244        };
245        let values = self.to_array(field.data_type(), array.values())?;
246        let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect());
247        let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into());
248        Ok(Arc::new(arrow_array::ListArray::new(
249            field.clone(),
250            offsets,
251            values,
252            nulls,
253        )))
254    }
255
256    #[inline]
257    fn vector_to_arrow(
258        &self,
259        data_type: &arrow_schema::DataType,
260        array: &VectorArray,
261    ) -> Result<arrow_array::ArrayRef, ArrayError> {
262        let arrow_schema::DataType::List(field) = data_type else {
263            return Err(ArrayError::to_arrow("Invalid list type"));
264        };
265        if field.data_type() != &arrow_schema::DataType::Float32 {
266            return Err(ArrayError::to_arrow("Invalid list inner type for vector"));
267        }
268        let values = Arc::new(arrow_array::Float32Array::from(
269            array.as_raw_slice().to_vec(),
270        ));
271        let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect());
272        let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into());
273        Ok(Arc::new(arrow_array::ListArray::new(
274            field.clone(),
275            offsets,
276            values,
277            nulls,
278        )))
279    }
280
281    #[inline]
282    fn struct_to_arrow(
283        &self,
284        data_type: &arrow_schema::DataType,
285        array: &StructArray,
286    ) -> Result<arrow_array::ArrayRef, ArrayError> {
287        let arrow_schema::DataType::Struct(fields) = data_type else {
288            return Err(ArrayError::to_arrow("Invalid struct type"));
289        };
290        // Use `try_new_with_length` so that empty-field structs keep their row count;
291        // `StructArray::new` panics for empty `fields` because it derives length from
292        // the child arrays.
293        let len = array.len();
294        let child_arrays = array
295            .fields()
296            .zip_eq_fast(fields)
297            .map(|(arr, field)| self.to_array(field.data_type(), arr))
298            .try_collect::<_, _, ArrayError>()?;
299        let nulls = Some(array.null_bitmap().into());
300        Ok(Arc::new(
301            arrow_array::StructArray::try_new_with_length(fields.clone(), child_arrays, nulls, len)
302                .map_err(ArrayError::from_arrow)?,
303        ))
304    }
305
306    #[inline]
307    fn map_to_arrow(
308        &self,
309        data_type: &arrow_schema::DataType,
310        array: &MapArray,
311    ) -> Result<arrow_array::ArrayRef, ArrayError> {
312        let arrow_schema::DataType::Map(field, ordered) = data_type else {
313            return Err(ArrayError::to_arrow("Invalid map type"));
314        };
315        if *ordered {
316            return Err(ArrayError::to_arrow("Sorted map is not supported"));
317        }
318        let values = self
319            .struct_to_arrow(field.data_type(), array.as_struct())?
320            .as_struct()
321            .clone();
322        let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect());
323        let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into());
324        Ok(Arc::new(arrow_array::MapArray::new(
325            field.clone(),
326            offsets,
327            values,
328            nulls,
329            *ordered,
330        )))
331    }
332
333    /// Convert RisingWave data type to Arrow data type.
334    ///
335    /// This function returns a `Field` instead of `DataType` because some may be converted to
336    /// extension types which require additional metadata in the field.
337    fn to_arrow_field(
338        &self,
339        name: &str,
340        value: &DataType,
341    ) -> Result<arrow_schema::Field, ArrayError> {
342        let data_type = match value {
343            // using the inline function
344            DataType::Boolean => self.bool_type_to_arrow(),
345            DataType::Int16 => self.int16_type_to_arrow(),
346            DataType::Int32 => self.int32_type_to_arrow(),
347            DataType::Int64 => self.int64_type_to_arrow(),
348            DataType::Int256 => self.int256_type_to_arrow(),
349            DataType::Float32 => self.float32_type_to_arrow(),
350            DataType::Float64 => self.float64_type_to_arrow(),
351            DataType::Date => self.date_type_to_arrow(),
352            DataType::Time => self.time_type_to_arrow(),
353            DataType::Timestamp => self.timestamp_type_to_arrow(),
354            DataType::Timestamptz => self.timestamptz_type_to_arrow(),
355            DataType::Interval => self.interval_type_to_arrow(),
356            DataType::Varchar => self.varchar_type_to_arrow(),
357            DataType::Bytea => self.bytea_type_to_arrow(),
358            DataType::Serial => self.serial_type_to_arrow(),
359            DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)),
360            DataType::Jsonb => return Ok(self.jsonb_type_to_arrow(name)),
361            DataType::Struct(fields) => self.struct_type_to_arrow(fields)?,
362            DataType::List(list) => self.list_type_to_arrow(list)?,
363            DataType::Map(map) => self.map_type_to_arrow(map)?,
364            DataType::Vector(_) => self.vector_type_to_arrow()?,
365        };
366        Ok(arrow_schema::Field::new(name, data_type, true))
367    }
368
369    #[inline]
370    fn bool_type_to_arrow(&self) -> arrow_schema::DataType {
371        arrow_schema::DataType::Boolean
372    }
373
374    #[inline]
375    fn int16_type_to_arrow(&self) -> arrow_schema::DataType {
376        arrow_schema::DataType::Int16
377    }
378
379    #[inline]
380    fn int32_type_to_arrow(&self) -> arrow_schema::DataType {
381        arrow_schema::DataType::Int32
382    }
383
384    #[inline]
385    fn int64_type_to_arrow(&self) -> arrow_schema::DataType {
386        arrow_schema::DataType::Int64
387    }
388
389    #[inline]
390    fn int256_type_to_arrow(&self) -> arrow_schema::DataType {
391        arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0)
392    }
393
394    #[inline]
395    fn float32_type_to_arrow(&self) -> arrow_schema::DataType {
396        arrow_schema::DataType::Float32
397    }
398
399    #[inline]
400    fn float64_type_to_arrow(&self) -> arrow_schema::DataType {
401        arrow_schema::DataType::Float64
402    }
403
404    #[inline]
405    fn date_type_to_arrow(&self) -> arrow_schema::DataType {
406        arrow_schema::DataType::Date32
407    }
408
409    #[inline]
410    fn time_type_to_arrow(&self) -> arrow_schema::DataType {
411        arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond)
412    }
413
414    #[inline]
415    fn timestamp_type_to_arrow(&self) -> arrow_schema::DataType {
416        arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
417    }
418
419    #[inline]
420    fn timestamptz_type_to_arrow(&self) -> arrow_schema::DataType {
421        arrow_schema::DataType::Timestamp(
422            arrow_schema::TimeUnit::Microsecond,
423            Some("+00:00".into()),
424        )
425    }
426
427    #[inline]
428    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
429        arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano)
430    }
431
432    #[inline]
433    fn varchar_type_to_arrow(&self) -> arrow_schema::DataType {
434        arrow_schema::DataType::Utf8
435    }
436
437    #[inline]
438    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
439        arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
440            .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
441    }
442
443    #[inline]
444    fn bytea_type_to_arrow(&self) -> arrow_schema::DataType {
445        arrow_schema::DataType::Binary
446    }
447
448    #[inline]
449    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
450        arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
451            .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
452    }
453
454    #[inline]
455    fn serial_type_to_arrow(&self) -> arrow_schema::DataType {
456        arrow_schema::DataType::Int64
457    }
458
459    #[inline]
460    fn list_type_to_arrow(
461        &self,
462        list_type: &ListType,
463    ) -> Result<arrow_schema::DataType, ArrayError> {
464        Ok(arrow_schema::DataType::List(Arc::new(
465            self.to_arrow_field("item", list_type.elem())?,
466        )))
467    }
468
469    #[inline]
470    fn struct_type_to_arrow(
471        &self,
472        fields: &StructType,
473    ) -> Result<arrow_schema::DataType, ArrayError> {
474        Ok(arrow_schema::DataType::Struct(
475            fields
476                .iter()
477                .map(|(name, ty)| self.to_arrow_field(name, ty))
478                .try_collect::<_, _, ArrayError>()?,
479        ))
480    }
481
482    #[inline]
483    fn map_type_to_arrow(&self, map_type: &MapType) -> Result<arrow_schema::DataType, ArrayError> {
484        let sorted = false;
485        // "key" is always non-null
486        let key = self
487            .to_arrow_field("key", map_type.key())?
488            .with_nullable(false);
489        let value = self.to_arrow_field("value", map_type.value())?;
490        Ok(arrow_schema::DataType::Map(
491            Arc::new(arrow_schema::Field::new(
492                "entries",
493                arrow_schema::DataType::Struct([Arc::new(key), Arc::new(value)].into()),
494                // "entries" is always non-null
495                false,
496            )),
497            sorted,
498        ))
499    }
500
501    #[inline]
502    fn vector_type_to_arrow(&self) -> Result<arrow_schema::DataType, ArrayError> {
503        Ok(arrow_schema::DataType::List(Arc::new(
504            self.to_arrow_field("item", &VECTOR_ITEM_TYPE)?,
505        )))
506    }
507}
508
509/// Defines how to convert Arrow arrays to RisingWave arrays.
510#[allow(clippy::wrong_self_convention)]
511pub trait FromArrow {
512    /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`.
513    fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result<DataChunk, ArrayError> {
514        let mut columns = Vec::with_capacity(batch.num_columns());
515        for (array, field) in batch.columns().iter().zip_eq_fast(batch.schema().fields()) {
516            let column = Arc::new(self.from_array(field, array)?);
517            columns.push(column);
518        }
519        Ok(DataChunk::new(columns, batch.num_rows()))
520    }
521
522    /// Converts Arrow `Fields` to RisingWave `StructType`.
523    fn from_fields(&self, fields: &arrow_schema::Fields) -> Result<StructType, ArrayError> {
524        Ok(StructType::new(
525            fields
526                .iter()
527                .map(|f| Ok((f.name().clone(), self.from_field(f)?)))
528                .try_collect::<_, Vec<_>, ArrayError>()?,
529        ))
530    }
531
532    /// Converts Arrow `Field` to RisingWave `DataType`.
533    fn from_field(&self, field: &arrow_schema::Field) -> Result<DataType, ArrayError> {
534        use arrow_schema::DataType::*;
535        use arrow_schema::IntervalUnit::*;
536        use arrow_schema::TimeUnit::*;
537
538        // extension type
539        if let Some(type_name) = field.metadata().get("ARROW:extension:name") {
540            return self.from_extension_type(type_name, field.data_type());
541        }
542
543        Ok(match field.data_type() {
544            Boolean => DataType::Boolean,
545            Int16 => DataType::Int16,
546            Int32 => DataType::Int32,
547            Int64 => DataType::Int64,
548            Int8 => DataType::Int16,
549            UInt8 => DataType::Int16,
550            UInt16 => DataType::Int32,
551            UInt32 => DataType::Int64,
552            UInt64 => DataType::Decimal,
553            Float16 => DataType::Float32,
554            Float32 => DataType::Float32,
555            Float64 => DataType::Float64,
556            Decimal128(_, _) => DataType::Decimal,
557            Decimal256(_, _) => DataType::Int256,
558            Date32 => DataType::Date,
559            Time64(Microsecond) => DataType::Time,
560            Timestamp(Microsecond, None) => DataType::Timestamp,
561            Timestamp(Microsecond, Some(_)) => DataType::Timestamptz,
562            Timestamp(Second, None) => DataType::Timestamp,
563            Timestamp(Second, Some(_)) => DataType::Timestamptz,
564            Timestamp(Millisecond, None) => DataType::Timestamp,
565            Timestamp(Millisecond, Some(_)) => DataType::Timestamptz,
566            Timestamp(Nanosecond, None) => DataType::Timestamp,
567            Timestamp(Nanosecond, Some(_)) => DataType::Timestamptz,
568            Interval(MonthDayNano) => DataType::Interval,
569            Utf8 => DataType::Varchar,
570            Utf8View => DataType::Varchar,
571            Binary => DataType::Bytea,
572            LargeUtf8 => self.from_large_utf8()?,
573            LargeBinary => self.from_large_binary()?,
574            List(field) => DataType::list(self.from_field(field)?),
575            Struct(fields) => DataType::Struct(self.from_fields(fields)?),
576            Map(field, _is_sorted) => {
577                let entries = self.from_field(field)?;
578                DataType::Map(MapType::try_from_entries(entries).map_err(|e| {
579                    ArrayError::from_arrow(format!("invalid arrow map field: {field:?}, err: {e}"))
580                })?)
581            }
582            t => {
583                return Err(ArrayError::from_arrow(format!(
584                    "unsupported arrow data type: {t:?}"
585                )));
586            }
587        })
588    }
589
590    /// Converts Arrow `LargeUtf8` type to RisingWave data type.
591    fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
592        Ok(DataType::Varchar)
593    }
594
595    /// Converts Arrow `LargeBinary` type to RisingWave data type.
596    fn from_large_binary(&self) -> Result<DataType, ArrayError> {
597        Ok(DataType::Bytea)
598    }
599
600    /// Converts Arrow extension type to RisingWave `DataType`.
601    fn from_extension_type(
602        &self,
603        type_name: &str,
604        physical_type: &arrow_schema::DataType,
605    ) -> Result<DataType, ArrayError> {
606        match (type_name, physical_type) {
607            ("arrowudf.decimal", arrow_schema::DataType::Utf8) => Ok(DataType::Decimal),
608            ("arrowudf.json", arrow_schema::DataType::Utf8) => Ok(DataType::Jsonb),
609            _ => Err(ArrayError::from_arrow(format!(
610                "unsupported extension type: {type_name:?}"
611            ))),
612        }
613    }
614
615    /// Converts Arrow `Array` to RisingWave `ArrayImpl`.
616    fn from_array(
617        &self,
618        field: &arrow_schema::Field,
619        array: &arrow_array::ArrayRef,
620    ) -> Result<ArrayImpl, ArrayError> {
621        use arrow_schema::DataType::*;
622        use arrow_schema::IntervalUnit::*;
623        use arrow_schema::TimeUnit::*;
624
625        // extension type
626        if let Some(type_name) = field.metadata().get("ARROW:extension:name") {
627            return self.from_extension_array(type_name, array);
628        }
629
630        // Struct projection for file source (Parquet): allow Arrow struct to be a superset of the
631        // expected struct fields. We align fields by name and ignore extra fields.
632        //
633        // Only use projection when Arrow struct differs from expected (superset, reordered, or
634        // different field names). If they match exactly, fall through to the normal path to
635        // avoid unnecessary overhead and potential issues with UDF/other paths.
636        if let (
637            arrow_schema::DataType::Struct(expected_fields),
638            arrow_schema::DataType::Struct(actual_fields),
639        ) = (field.data_type(), array.data_type())
640        {
641            let dominated = Self::struct_fields_dominated(expected_fields, actual_fields);
642            if dominated {
643                let struct_array: &arrow_array::StructArray =
644                    array.as_any().downcast_ref().unwrap();
645                return self.from_struct_array_projected(expected_fields, struct_array);
646            }
647            // else: fields match exactly, fall through to normal Struct(_) path below
648        }
649        match array.data_type() {
650            Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()),
651            Int8 => self.from_int8_array(array.as_any().downcast_ref().unwrap()),
652            Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()),
653            Int32 => self.from_int32_array(array.as_any().downcast_ref().unwrap()),
654            Int64 => self.from_int64_array(array.as_any().downcast_ref().unwrap()),
655            UInt8 => self.from_uint8_array(array.as_any().downcast_ref().unwrap()),
656            UInt16 => self.from_uint16_array(array.as_any().downcast_ref().unwrap()),
657            UInt32 => self.from_uint32_array(array.as_any().downcast_ref().unwrap()),
658
659            UInt64 => self.from_uint64_array(array.as_any().downcast_ref().unwrap()),
660            Decimal128(_, _) => self.from_decimal128_array(array.as_any().downcast_ref().unwrap()),
661            Decimal256(_, _) => self.from_int256_array(array.as_any().downcast_ref().unwrap()),
662            Float16 => self.from_float16_array(array.as_any().downcast_ref().unwrap()),
663            Float32 => self.from_float32_array(array.as_any().downcast_ref().unwrap()),
664            Float64 => self.from_float64_array(array.as_any().downcast_ref().unwrap()),
665            Date32 => self.from_date32_array(array.as_any().downcast_ref().unwrap()),
666            Time64(Microsecond) => self.from_time64us_array(array.as_any().downcast_ref().unwrap()),
667            Timestamp(Second, None) => {
668                self.from_timestampsecond_array(array.as_any().downcast_ref().unwrap())
669            }
670            Timestamp(Second, Some(_)) => {
671                self.from_timestampsecond_some_array(array.as_any().downcast_ref().unwrap())
672            }
673            Timestamp(Millisecond, None) => {
674                self.from_timestampms_array(array.as_any().downcast_ref().unwrap())
675            }
676            Timestamp(Millisecond, Some(_)) => {
677                self.from_timestampms_some_array(array.as_any().downcast_ref().unwrap())
678            }
679            Timestamp(Microsecond, None) => {
680                self.from_timestampus_array(array.as_any().downcast_ref().unwrap())
681            }
682            Timestamp(Microsecond, Some(_)) => {
683                self.from_timestampus_some_array(array.as_any().downcast_ref().unwrap())
684            }
685            Timestamp(Nanosecond, None) => {
686                self.from_timestampns_array(array.as_any().downcast_ref().unwrap())
687            }
688            Timestamp(Nanosecond, Some(_)) => {
689                self.from_timestampns_some_array(array.as_any().downcast_ref().unwrap())
690            }
691            Interval(MonthDayNano) => {
692                self.from_interval_array(array.as_any().downcast_ref().unwrap())
693            }
694            Utf8 => self.from_utf8_array(array.as_any().downcast_ref().unwrap()),
695            Utf8View => self.from_utf8_view_array(array.as_any().downcast_ref().unwrap()),
696            Binary => self.from_binary_array(array.as_any().downcast_ref().unwrap()),
697            LargeUtf8 => self.from_large_utf8_array(array.as_any().downcast_ref().unwrap()),
698            LargeBinary => self.from_large_binary_array(array.as_any().downcast_ref().unwrap()),
699            List(_) => self.from_list_array(array.as_any().downcast_ref().unwrap()),
700            Struct(_) => self.from_struct_array(array.as_any().downcast_ref().unwrap()),
701            Map(_, _) => self.from_map_array(array.as_any().downcast_ref().unwrap()),
702            t => Err(ArrayError::from_arrow(format!(
703                "unsupported arrow data type: {t:?}",
704            ))),
705        }
706    }
707
708    /// Converts Arrow extension array to RisingWave `ArrayImpl`.
709    fn from_extension_array(
710        &self,
711        type_name: &str,
712        array: &arrow_array::ArrayRef,
713    ) -> Result<ArrayImpl, ArrayError> {
714        match type_name {
715            "arrowudf.decimal" => {
716                let array: &arrow_array::StringArray =
717                    array.as_any().downcast_ref().ok_or_else(|| {
718                        ArrayError::from_arrow(
719                            "expected string array for `arrowudf.decimal`".to_owned(),
720                        )
721                    })?;
722                Ok(ArrayImpl::Decimal(array.try_into()?))
723            }
724            "arrowudf.json" => {
725                let array: &arrow_array::StringArray =
726                    array.as_any().downcast_ref().ok_or_else(|| {
727                        ArrayError::from_arrow(
728                            "expected string array for `arrowudf.json`".to_owned(),
729                        )
730                    })?;
731                Ok(ArrayImpl::Jsonb(array.try_into()?))
732            }
733            _ => Err(ArrayError::from_arrow(format!(
734                "unsupported extension type: {type_name:?}"
735            ))),
736        }
737    }
738
739    fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result<ArrayImpl, ArrayError> {
740        Ok(ArrayImpl::Bool(array.into()))
741    }
742
743    fn from_int16_array(&self, array: &arrow_array::Int16Array) -> Result<ArrayImpl, ArrayError> {
744        Ok(ArrayImpl::Int16(array.into()))
745    }
746
747    fn from_int8_array(&self, array: &arrow_array::Int8Array) -> Result<ArrayImpl, ArrayError> {
748        Ok(ArrayImpl::Int16(array.into()))
749    }
750
751    fn from_uint8_array(&self, array: &arrow_array::UInt8Array) -> Result<ArrayImpl, ArrayError> {
752        Ok(ArrayImpl::Int16(array.into()))
753    }
754
755    fn from_uint16_array(&self, array: &arrow_array::UInt16Array) -> Result<ArrayImpl, ArrayError> {
756        Ok(ArrayImpl::Int32(array.into()))
757    }
758
759    fn from_uint32_array(&self, array: &arrow_array::UInt32Array) -> Result<ArrayImpl, ArrayError> {
760        Ok(ArrayImpl::Int64(array.into()))
761    }
762
763    fn from_int32_array(&self, array: &arrow_array::Int32Array) -> Result<ArrayImpl, ArrayError> {
764        Ok(ArrayImpl::Int32(array.into()))
765    }
766
767    fn from_int64_array(&self, array: &arrow_array::Int64Array) -> Result<ArrayImpl, ArrayError> {
768        Ok(ArrayImpl::Int64(array.into()))
769    }
770
771    fn from_int256_array(
772        &self,
773        array: &arrow_array::Decimal256Array,
774    ) -> Result<ArrayImpl, ArrayError> {
775        Ok(ArrayImpl::Int256(array.into()))
776    }
777
778    fn from_decimal128_array(
779        &self,
780        array: &arrow_array::Decimal128Array,
781    ) -> Result<ArrayImpl, ArrayError> {
782        Ok(ArrayImpl::Decimal(array.try_into()?))
783    }
784
785    fn from_uint64_array(&self, array: &arrow_array::UInt64Array) -> Result<ArrayImpl, ArrayError> {
786        Ok(ArrayImpl::Decimal(array.try_into()?))
787    }
788
789    fn from_float16_array(
790        &self,
791        array: &arrow_array::Float16Array,
792    ) -> Result<ArrayImpl, ArrayError> {
793        Ok(ArrayImpl::Float32(array.try_into()?))
794    }
795
796    fn from_float32_array(
797        &self,
798        array: &arrow_array::Float32Array,
799    ) -> Result<ArrayImpl, ArrayError> {
800        Ok(ArrayImpl::Float32(array.into()))
801    }
802
803    fn from_float64_array(
804        &self,
805        array: &arrow_array::Float64Array,
806    ) -> Result<ArrayImpl, ArrayError> {
807        Ok(ArrayImpl::Float64(array.into()))
808    }
809
810    fn from_date32_array(&self, array: &arrow_array::Date32Array) -> Result<ArrayImpl, ArrayError> {
811        Ok(ArrayImpl::Date(array.into()))
812    }
813
814    fn from_time64us_array(
815        &self,
816        array: &arrow_array::Time64MicrosecondArray,
817    ) -> Result<ArrayImpl, ArrayError> {
818        Ok(ArrayImpl::Time(array.into()))
819    }
820
821    fn from_timestampsecond_array(
822        &self,
823        array: &arrow_array::TimestampSecondArray,
824    ) -> Result<ArrayImpl, ArrayError> {
825        Ok(ArrayImpl::Timestamp(array.into()))
826    }
827    fn from_timestampsecond_some_array(
828        &self,
829        array: &arrow_array::TimestampSecondArray,
830    ) -> Result<ArrayImpl, ArrayError> {
831        Ok(ArrayImpl::Timestamptz(array.into()))
832    }
833
834    fn from_timestampms_array(
835        &self,
836        array: &arrow_array::TimestampMillisecondArray,
837    ) -> Result<ArrayImpl, ArrayError> {
838        Ok(ArrayImpl::Timestamp(array.into()))
839    }
840
841    fn from_timestampms_some_array(
842        &self,
843        array: &arrow_array::TimestampMillisecondArray,
844    ) -> Result<ArrayImpl, ArrayError> {
845        Ok(ArrayImpl::Timestamptz(array.into()))
846    }
847
848    fn from_timestampus_array(
849        &self,
850        array: &arrow_array::TimestampMicrosecondArray,
851    ) -> Result<ArrayImpl, ArrayError> {
852        Ok(ArrayImpl::Timestamp(array.into()))
853    }
854
855    fn from_timestampus_some_array(
856        &self,
857        array: &arrow_array::TimestampMicrosecondArray,
858    ) -> Result<ArrayImpl, ArrayError> {
859        Ok(ArrayImpl::Timestamptz(array.into()))
860    }
861
862    fn from_timestampns_array(
863        &self,
864        array: &arrow_array::TimestampNanosecondArray,
865    ) -> Result<ArrayImpl, ArrayError> {
866        Ok(ArrayImpl::Timestamp(array.into()))
867    }
868
869    fn from_timestampns_some_array(
870        &self,
871        array: &arrow_array::TimestampNanosecondArray,
872    ) -> Result<ArrayImpl, ArrayError> {
873        Ok(ArrayImpl::Timestamptz(array.into()))
874    }
875
876    fn from_interval_array(
877        &self,
878        array: &arrow_array::IntervalMonthDayNanoArray,
879    ) -> Result<ArrayImpl, ArrayError> {
880        Ok(ArrayImpl::Interval(array.into()))
881    }
882
883    fn from_utf8_array(&self, array: &arrow_array::StringArray) -> Result<ArrayImpl, ArrayError> {
884        Ok(ArrayImpl::Utf8(array.into()))
885    }
886
887    fn from_utf8_view_array(
888        &self,
889        array: &arrow_array::StringViewArray,
890    ) -> Result<ArrayImpl, ArrayError> {
891        Ok(ArrayImpl::Utf8(array.into()))
892    }
893
894    fn from_binary_array(&self, array: &arrow_array::BinaryArray) -> Result<ArrayImpl, ArrayError> {
895        Ok(ArrayImpl::Bytea(array.into()))
896    }
897
898    fn from_large_utf8_array(
899        &self,
900        array: &arrow_array::LargeStringArray,
901    ) -> Result<ArrayImpl, ArrayError> {
902        Ok(ArrayImpl::Utf8(array.into()))
903    }
904
905    fn from_large_binary_array(
906        &self,
907        array: &arrow_array::LargeBinaryArray,
908    ) -> Result<ArrayImpl, ArrayError> {
909        Ok(ArrayImpl::Bytea(array.into()))
910    }
911
912    fn from_list_array(&self, array: &arrow_array::ListArray) -> Result<ArrayImpl, ArrayError> {
913        use arrow_array::Array;
914        let arrow_schema::DataType::List(field) = array.data_type() else {
915            panic!("nested field types cannot be determined.");
916        };
917        Ok(ArrayImpl::List(ListArray {
918            value: Box::new(self.from_array(field, array.values())?),
919            bitmap: match array.nulls() {
920                Some(nulls) => nulls.iter().collect(),
921                None => Bitmap::ones(array.len()),
922            },
923            offsets: array.offsets().iter().map(|o| *o as u32).collect(),
924        }))
925    }
926
927    fn from_struct_array(&self, array: &arrow_array::StructArray) -> Result<ArrayImpl, ArrayError> {
928        use arrow_array::Array;
929        let arrow_schema::DataType::Struct(fields) = array.data_type() else {
930            panic!("nested field types cannot be determined.");
931        };
932        Ok(ArrayImpl::Struct(StructArray::new(
933            self.from_fields(fields)?,
934            array
935                .columns()
936                .iter()
937                .zip_eq_fast(fields)
938                .map(|(array, field)| self.from_array(field, array).map(Arc::new))
939                .try_collect()?,
940            (0..array.len()).map(|i| array.is_valid(i)).collect(),
941        )))
942    }
943
944    /// Returns `true` if all expected fields are present in `actual_fields`, and `actual_fields`
945    /// has more fields or has them in a different order.
946    ///
947    /// This is used to decide whether to use `from_struct_array_projected` (projection needed)
948    /// or fall back to the normal `from_struct_array` path (exact match).
949    fn struct_fields_dominated(
950        expected_fields: &arrow_schema::Fields,
951        actual_fields: &arrow_schema::Fields,
952    ) -> bool {
953        // Fast path: if lengths are equal and names match in order, no projection needed
954        if expected_fields.len() == actual_fields.len() {
955            let all_match = expected_fields
956                .iter()
957                .zip_eq_fast(actual_fields.iter())
958                .all(|(e, a)| e.name() == a.name());
959            if all_match {
960                return false; // exact match, use normal path
961            }
962        }
963        // Check that all expected fields exist in actual (by name)
964        let actual_names: std::collections::HashSet<&str> =
965            actual_fields.iter().map(|f| f.name().as_str()).collect();
966        expected_fields
967            .iter()
968            .all(|e| actual_names.contains(e.name().as_str()))
969    }
970
971    /// Converts Arrow `StructArray` to RisingWave `StructArray` according to the expected fields.
972    ///
973    /// This is mainly used for Parquet file source, where the upstream struct may contain extra
974    /// fields. The conversion aligns fields by name, ignores extra fields, and keeps the expected
975    /// field order.
976    fn from_struct_array_projected(
977        &self,
978        expected_fields: &arrow_schema::Fields,
979        array: &arrow_array::StructArray,
980    ) -> Result<ArrayImpl, ArrayError> {
981        use std::collections::HashMap;
982
983        use arrow_array::Array;
984
985        let arrow_schema::DataType::Struct(actual_fields) = array.data_type() else {
986            panic!("nested field types cannot be determined.");
987        };
988
989        let actual_name_to_index: HashMap<&str, usize> = actual_fields
990            .iter()
991            .enumerate()
992            .map(|(idx, f)| (f.name().as_str(), idx))
993            .collect();
994
995        let len = array.len();
996        let projected_columns = expected_fields
997            .iter()
998            .map(|expected_field| {
999                if let Some(&idx) = actual_name_to_index.get(expected_field.name().as_str()) {
1000                    let child = array.columns()[idx].clone();
1001                    self.from_array(expected_field, &child).map(Arc::new)
1002                } else {
1003                    // Field missing in Arrow struct. Fill SQL NULL with the expected RW type.
1004                    let rw_ty = self.from_field(expected_field)?;
1005                    let mut builder = ArrayBuilderImpl::with_type(len, rw_ty);
1006                    builder.append_n(len, Datum::None);
1007                    Ok(Arc::new(builder.finish()))
1008                }
1009            })
1010            .try_collect()?;
1011
1012        Ok(ArrayImpl::Struct(StructArray::new(
1013            self.from_fields(expected_fields)?,
1014            projected_columns,
1015            (0..len).map(|i| array.is_valid(i)).collect(),
1016        )))
1017    }
1018
1019    fn from_map_array(&self, array: &arrow_array::MapArray) -> Result<ArrayImpl, ArrayError> {
1020        use arrow_array::Array;
1021        let struct_array = self.from_struct_array(array.entries())?;
1022        let list_array = ListArray {
1023            value: Box::new(struct_array),
1024            bitmap: match array.nulls() {
1025                Some(nulls) => nulls.iter().collect(),
1026                None => Bitmap::ones(array.len()),
1027            },
1028            offsets: array.offsets().iter().map(|o| *o as u32).collect(),
1029        };
1030
1031        Ok(ArrayImpl::Map(MapArray { inner: list_array }))
1032    }
1033}
1034
1035impl From<&Bitmap> for arrow_buffer::NullBuffer {
1036    fn from(bitmap: &Bitmap) -> Self {
1037        bitmap.iter().collect()
1038    }
1039}
1040
/// Implement bi-directional `From` between concrete array types.
///
/// Two forms are supported:
/// - `converts!(RwArray, ArrowArray)`: the element types line up directly, so each direction
///   simply collects the other array's iterator.
/// - `converts!(RwArray, ArrowArray, @map)`: each non-null element is converted with
///   [`FromIntoArrow`] (`into_arrow` / `from_arrow`).
///
/// Both forms also implement `From<&[ArrowArray]>` for the RisingWave array, which concatenates
/// the Arrow chunks in slice order. Nulls are preserved as nulls in every direction.
macro_rules! converts {
    ($ArrayType:ty, $ArrowType:ty) => {
        impl From<&$ArrayType> for $ArrowType {
            fn from(array: &$ArrayType) -> Self {
                array.iter().collect()
            }
        }
        impl From<&$ArrowType> for $ArrayType {
            fn from(array: &$ArrowType) -> Self {
                array.iter().collect()
            }
        }
        impl From<&[$ArrowType]> for $ArrayType {
            fn from(arrays: &[$ArrowType]) -> Self {
                arrays.iter().flat_map(|a| a.iter()).collect()
            }
        }
    };
    // Convert each non-null element with `FromIntoArrow`, in both directions.
    ($ArrayType:ty, $ArrowType:ty, @map) => {
        impl From<&$ArrayType> for $ArrowType {
            fn from(array: &$ArrayType) -> Self {
                array.iter().map(|o| o.map(|v| v.into_arrow())).collect()
            }
        }
        impl From<&$ArrowType> for $ArrayType {
            fn from(array: &$ArrowType) -> Self {
                array
                    .iter()
                    .map(|o| {
                        o.map(|v| {
                            <<$ArrayType as Array>::RefItem<'_> as FromIntoArrow>::from_arrow(v)
                        })
                    })
                    .collect()
            }
        }
        impl From<&[$ArrowType]> for $ArrayType {
            fn from(arrays: &[$ArrowType]) -> Self {
                arrays
                    .iter()
                    .flat_map(|a| a.iter())
                    .map(|o| {
                        o.map(|v| {
                            <<$ArrayType as Array>::RefItem<'_> as FromIntoArrow>::from_arrow(v)
                        })
                    })
                    .collect()
            }
        }
    };
}
1094
/// Implements bi-directional `From` between a RisingWave primitive array and an Arrow primitive
/// array whose element types differ, converting every non-null element with an `as` cast.
///
/// - `$ArrayType` / `$FromType`: the RisingWave array and its element type.
/// - `$ArrowType` / `$ToType`: the Arrow array and its native element type.
///
/// Casting a narrower Arrow type into a wider RisingWave type (e.g. `u8 as i16`) is lossless.
/// NOTE: the RisingWave -> Arrow direction may narrow (e.g. `i16 as i8`). In that case `as`
/// silently truncates or wraps out-of-range values, because `From` cannot report failure.
///
/// `From<&[ArrowArray]>` is also implemented for the RisingWave array; it concatenates the
/// chunks in slice order.
macro_rules! converts_with_type {
    ($ArrayType:ty, $ArrowType:ty, $FromType:ty, $ToType:ty) => {
        impl From<&$ArrayType> for $ArrowType {
            fn from(array: &$ArrayType) -> Self {
                // Collect straight into the target array; no intermediate `Vec` is needed.
                array.iter().map(|x| x.map(|v| v as $ToType)).collect()
            }
        }

        impl From<&$ArrowType> for $ArrayType {
            fn from(array: &$ArrowType) -> Self {
                array.iter().map(|x| x.map(|v| v as $FromType)).collect()
            }
        }

        impl From<&[$ArrowType]> for $ArrayType {
            fn from(arrays: &[$ArrowType]) -> Self {
                arrays
                    .iter()
                    .flat_map(|a| a.iter().map(|x| x.map(|v| v as $FromType)))
                    .collect()
            }
        }
    };
}
1125
/// Implements bi-directional `From` between a RisingWave timestamp array and an Arrow timestamp
/// array of one fixed time unit.
///
/// Each non-null element is converted with [`FromIntoArrowWithUnit`] using `$time_unit`. The
/// caller must pass the unit that corresponds to `$ArrowType` (e.g. `TimeUnit::Millisecond` for
/// `TimestampMillisecondArray`). `From<&[ArrowArray]>` is also implemented for the RisingWave
/// array; it concatenates the chunks in slice order.
macro_rules! converts_with_timeunit {
    ($ArrayType:ty, $ArrowType:ty, $time_unit:expr, @map) => {
        impl From<&$ArrayType> for $ArrowType {
            fn from(array: &$ArrayType) -> Self {
                array
                    .iter()
                    .map(|o| o.map(|v| v.into_arrow_with_unit($time_unit)))
                    .collect()
            }
        }

        impl From<&$ArrowType> for $ArrayType {
            fn from(array: &$ArrowType) -> Self {
                array
                    .iter()
                    .map(|o| {
                        o.map(|v| {
                            <<$ArrayType as Array>::RefItem<'_> as FromIntoArrowWithUnit>::from_arrow_with_unit(v, $time_unit)
                        })
                    })
                    .collect()
            }
        }

        impl From<&[$ArrowType]> for $ArrayType {
            fn from(arrays: &[$ArrowType]) -> Self {
                arrays
                    .iter()
                    .flat_map(|a| a.iter())
                    .map(|o| {
                        o.map(|v| {
                            <<$ArrayType as Array>::RefItem<'_> as FromIntoArrowWithUnit>::from_arrow_with_unit(v, $time_unit)
                        })
                    })
                    .collect()
            }
        }
    };
}
1162
// Pairs whose element types line up directly, or are mapped via `FromIntoArrow` with `@map`.
converts!(BoolArray, arrow_array::BooleanArray);
converts!(I16Array, arrow_array::Int16Array);
converts!(I32Array, arrow_array::Int32Array);
converts!(I64Array, arrow_array::Int64Array);
converts!(F32Array, arrow_array::Float32Array, @map);
converts!(F64Array, arrow_array::Float64Array, @map);
converts!(BytesArray, arrow_array::BinaryArray);
converts!(BytesArray, arrow_array::LargeBinaryArray);
converts!(Utf8Array, arrow_array::StringArray);
converts!(Utf8Array, arrow_array::LargeStringArray);
converts!(Utf8Array, arrow_array::StringViewArray);
converts!(DateArray, arrow_array::Date32Array, @map);
converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map);
converts!(IntervalArray, arrow_array::IntervalMonthDayNanoArray, @map);
converts!(SerialArray, arrow_array::Int64Array, @map);

// Arrow integer types without an exact RisingWave counterpart, widened on the way in.
// Arguments: RisingWave array, Arrow array, RisingWave element type, Arrow element type.
converts_with_type!(I16Array, arrow_array::Int8Array, i16, i8);
converts_with_type!(I16Array, arrow_array::UInt8Array, i16, u8);
converts_with_type!(I32Array, arrow_array::UInt16Array, i32, u16);
converts_with_type!(I64Array, arrow_array::UInt32Array, i64, u32);

// Every Arrow timestamp unit converts to and from `TimestampArray`; the unit argument must
// match the Arrow array type.
converts_with_timeunit!(TimestampArray, arrow_array::TimestampSecondArray, TimeUnit::Second, @map);
converts_with_timeunit!(TimestampArray, arrow_array::TimestampMillisecondArray, TimeUnit::Millisecond, @map);
converts_with_timeunit!(TimestampArray, arrow_array::TimestampMicrosecondArray, TimeUnit::Microsecond, @map);
converts_with_timeunit!(TimestampArray, arrow_array::TimestampNanosecondArray, TimeUnit::Nanosecond, @map);

// The same four units also convert to and from `TimestamptzArray`.
converts_with_timeunit!(TimestamptzArray, arrow_array::TimestampSecondArray, TimeUnit::Second, @map);
converts_with_timeunit!(TimestamptzArray, arrow_array::TimestampMillisecondArray,TimeUnit::Millisecond, @map);
converts_with_timeunit!(TimestamptzArray, arrow_array::TimestampMicrosecondArray, TimeUnit::Microsecond, @map);
converts_with_timeunit!(TimestamptzArray, arrow_array::TimestampNanosecondArray, TimeUnit::Nanosecond, @map);
1193
/// Converts a single RisingWave value from and into its Arrow element representation.
trait FromIntoArrow {
    /// The corresponding element type in the Arrow array.
    type ArrowType;

    /// Converts this RisingWave value into the Arrow element type.
    fn into_arrow(self) -> Self::ArrowType;

    /// Builds the RisingWave value from an Arrow element.
    fn from_arrow(arrow_value: Self::ArrowType) -> Self;
}
1201
/// Converts a single RisingWave value from and into an Arrow element whose meaning depends on
/// a unit. Used for timestamp types, where the Arrow encoding depends on the time unit.
trait FromIntoArrowWithUnit {
    /// The corresponding element type in the Arrow array.
    type ArrowType;
    /// The timestamp type used to distinguish different time units, only utilized when the
    /// Arrow type is a timestamp.
    type TimestampType;

    /// Converts this value into the Arrow element type, expressed in `time_unit`.
    fn into_arrow_with_unit(self, time_unit: Self::TimestampType) -> Self::ArrowType;

    /// Builds the value from an Arrow element expressed in `time_unit`.
    fn from_arrow_with_unit(arrow_value: Self::ArrowType, time_unit: Self::TimestampType)
        -> Self;
}
1211
1212impl FromIntoArrow for Serial {
1213    type ArrowType = i64;
1214
1215    fn from_arrow(value: Self::ArrowType) -> Self {
1216        value.into()
1217    }
1218
1219    fn into_arrow(self) -> Self::ArrowType {
1220        self.into()
1221    }
1222}
1223
1224impl FromIntoArrow for F32 {
1225    type ArrowType = f32;
1226
1227    fn from_arrow(value: Self::ArrowType) -> Self {
1228        value.into()
1229    }
1230
1231    fn into_arrow(self) -> Self::ArrowType {
1232        self.into()
1233    }
1234}
1235
1236impl FromIntoArrow for F64 {
1237    type ArrowType = f64;
1238
1239    fn from_arrow(value: Self::ArrowType) -> Self {
1240        value.into()
1241    }
1242
1243    fn into_arrow(self) -> Self::ArrowType {
1244        self.into()
1245    }
1246}
1247
1248impl FromIntoArrow for Date {
1249    type ArrowType = i32;
1250
1251    #[allow(deprecated)]
1252    fn from_arrow(value: Self::ArrowType) -> Self {
1253        Date(arrow_array::types::Date32Type::to_naive_date(value))
1254    }
1255
1256    fn into_arrow(self) -> Self::ArrowType {
1257        arrow_array::types::Date32Type::from_naive_date(self.0)
1258    }
1259}
1260
1261impl FromIntoArrow for Time {
1262    type ArrowType = i64;
1263
1264    fn from_arrow(value: Self::ArrowType) -> Self {
1265        Time(
1266            NaiveTime::from_num_seconds_from_midnight_opt(
1267                (value / 1_000_000) as _,
1268                (value % 1_000_000 * 1000) as _,
1269            )
1270            .unwrap(),
1271        )
1272    }
1273
1274    fn into_arrow(self) -> Self::ArrowType {
1275        self.0
1276            .signed_duration_since(NaiveTime::default())
1277            .num_microseconds()
1278            .unwrap()
1279    }
1280}
1281
1282impl FromIntoArrowWithUnit for Timestamp {
1283    type ArrowType = i64;
1284    type TimestampType = TimeUnit;
1285
1286    fn from_arrow_with_unit(value: Self::ArrowType, time_unit: Self::TimestampType) -> Self {
1287        match time_unit {
1288            TimeUnit::Second => {
1289                Timestamp(DateTime::from_timestamp(value as _, 0).unwrap().naive_utc())
1290            }
1291            TimeUnit::Millisecond => {
1292                Timestamp(DateTime::from_timestamp_millis(value).unwrap().naive_utc())
1293            }
1294            TimeUnit::Microsecond => {
1295                Timestamp(DateTime::from_timestamp_micros(value).unwrap().naive_utc())
1296            }
1297            TimeUnit::Nanosecond => Timestamp(DateTime::from_timestamp_nanos(value).naive_utc()),
1298        }
1299    }
1300
1301    fn into_arrow_with_unit(self, time_unit: Self::TimestampType) -> Self::ArrowType {
1302        match time_unit {
1303            TimeUnit::Second => self.0.and_utc().timestamp(),
1304            TimeUnit::Millisecond => self.0.and_utc().timestamp_millis(),
1305            TimeUnit::Microsecond => self.0.and_utc().timestamp_micros(),
1306            TimeUnit::Nanosecond => self.0.and_utc().timestamp_nanos_opt().unwrap(),
1307        }
1308    }
1309}
1310
1311impl FromIntoArrowWithUnit for Timestamptz {
1312    type ArrowType = i64;
1313    type TimestampType = TimeUnit;
1314
1315    fn from_arrow_with_unit(value: Self::ArrowType, time_unit: Self::TimestampType) -> Self {
1316        match time_unit {
1317            TimeUnit::Second => Timestamptz::from_secs(value).unwrap_or_default(),
1318            TimeUnit::Millisecond => Timestamptz::from_millis(value).unwrap_or_default(),
1319            TimeUnit::Microsecond => Timestamptz::from_micros(value),
1320            TimeUnit::Nanosecond => Timestamptz::from_nanos(value).unwrap_or_default(),
1321        }
1322    }
1323
1324    fn into_arrow_with_unit(self, time_unit: Self::TimestampType) -> Self::ArrowType {
1325        match time_unit {
1326            TimeUnit::Second => self.timestamp(),
1327            TimeUnit::Millisecond => self.timestamp_millis(),
1328            TimeUnit::Microsecond => self.timestamp_micros(),
1329            TimeUnit::Nanosecond => self.timestamp_nanos().unwrap(),
1330        }
1331    }
1332}
1333
1334impl FromIntoArrow for Interval {
1335    type ArrowType = ArrowIntervalType;
1336
1337    fn from_arrow(value: Self::ArrowType) -> Self {
1338        Interval::from_month_day_usec(value.months, value.days, value.nanoseconds / 1000)
1339    }
1340
1341    fn into_arrow(self) -> Self::ArrowType {
1342        ArrowIntervalType {
1343            months: self.months(),
1344            days: self.days(),
1345            // TODO: this may overflow and we need `try_into`
1346            nanoseconds: self.usecs() * 1000,
1347        }
1348    }
1349}
1350
1351impl From<&DecimalArray> for arrow_array::LargeBinaryArray {
1352    fn from(array: &DecimalArray) -> Self {
1353        let mut builder =
1354            arrow_array::builder::LargeBinaryBuilder::with_capacity(array.len(), array.len() * 8);
1355        for value in array.iter() {
1356            builder.append_option(value.map(|d| d.to_string()));
1357        }
1358        builder.finish()
1359    }
1360}
1361
1362impl From<&DecimalArray> for arrow_array::StringArray {
1363    fn from(array: &DecimalArray) -> Self {
1364        let mut builder =
1365            arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 8);
1366        for value in array.iter() {
1367            builder.append_option(value.map(|d| d.to_string()));
1368        }
1369        builder.finish()
1370    }
1371}
1372
1373// This arrow decimal type is used by iceberg source to read iceberg decimal into RW decimal.
1374impl TryFrom<&arrow_array::Decimal128Array> for DecimalArray {
1375    type Error = ArrayError;
1376
1377    fn try_from(array: &arrow_array::Decimal128Array) -> Result<Self, Self::Error> {
1378        if array.scale() < 0 {
1379            bail!("support negative scale for arrow decimal")
1380        }
1381
1382        // Calculate the max value based on the Arrow decimal's precision
1383        // When writing Inf to Arrow Decimal128(precision, scale), we use 10^precision - 1
1384        let precision = array.precision();
1385        let max_value = 10_i128.pow(precision as u32) - 1;
1386
1387        let from_arrow = |value| {
1388            const NAN: i128 = i128::MIN + 1;
1389            let res = match value {
1390                // Check for special values using Arrow Decimal's max value, not i128::MAX
1391                NAN => Decimal::NaN,
1392                v if v == max_value => Decimal::PositiveInf,
1393                v if v == -max_value => Decimal::NegativeInf,
1394                i128::MAX => Decimal::PositiveInf, // Fallback for old data
1395                i128::MIN => Decimal::NegativeInf, // Fallback for old data
1396                _ => Decimal::truncated_i128_and_scale(value, array.scale() as u32)
1397                    .ok_or_else(|| ArrayError::from_arrow("decimal overflow"))?,
1398            };
1399            Ok(res)
1400        };
1401        array
1402            .iter()
1403            .map(|o| o.map(from_arrow).transpose())
1404            .collect::<Result<Self, Self::Error>>()
1405    }
1406}
1407
1408// Since RisingWave does not support UInt type, convert UInt64Array to Decimal.
1409impl TryFrom<&arrow_array::UInt64Array> for DecimalArray {
1410    type Error = ArrayError;
1411
1412    fn try_from(array: &arrow_array::UInt64Array) -> Result<Self, Self::Error> {
1413        let from_arrow = |value| {
1414            // Convert the value to a Decimal with scale 0
1415            let res = Decimal::from(value);
1416            Ok(res)
1417        };
1418
1419        // Map over the array and convert each value
1420        array
1421            .iter()
1422            .map(|o| o.map(from_arrow).transpose())
1423            .collect::<Result<Self, Self::Error>>()
1424    }
1425}
1426
1427impl TryFrom<&arrow_array::Float16Array> for F32Array {
1428    type Error = ArrayError;
1429
1430    fn try_from(array: &arrow_array::Float16Array) -> Result<Self, Self::Error> {
1431        let from_arrow = |value| Ok(f32::from(value));
1432
1433        array
1434            .iter()
1435            .map(|o| o.map(from_arrow).transpose())
1436            .collect::<Result<Self, Self::Error>>()
1437    }
1438}
1439
1440impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray {
1441    type Error = ArrayError;
1442
1443    fn try_from(array: &arrow_array::LargeBinaryArray) -> Result<Self, Self::Error> {
1444        array
1445            .iter()
1446            .map(|o| {
1447                o.map(|s| {
1448                    let s = std::str::from_utf8(s)
1449                        .map_err(|_| ArrayError::from_arrow(format!("invalid decimal: {s:?}")))?;
1450                    s.parse()
1451                        .map_err(|_| ArrayError::from_arrow(format!("invalid decimal: {s:?}")))
1452                })
1453                .transpose()
1454            })
1455            .try_collect()
1456    }
1457}
1458
1459impl TryFrom<&arrow_array::StringArray> for DecimalArray {
1460    type Error = ArrayError;
1461
1462    fn try_from(array: &arrow_array::StringArray) -> Result<Self, Self::Error> {
1463        array
1464            .iter()
1465            .map(|o| {
1466                o.map(|s| {
1467                    s.parse()
1468                        .map_err(|_| ArrayError::from_arrow(format!("invalid decimal: {s:?}")))
1469                })
1470                .transpose()
1471            })
1472            .try_collect()
1473    }
1474}
1475
1476impl From<&JsonbArray> for arrow_array::StringArray {
1477    fn from(array: &JsonbArray) -> Self {
1478        let mut builder =
1479            arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 16);
1480        for value in array.iter() {
1481            match value {
1482                Some(jsonb) => {
1483                    write!(&mut builder, "{}", jsonb).unwrap();
1484                    builder.append_value("");
1485                }
1486                None => builder.append_null(),
1487            }
1488        }
1489        builder.finish()
1490    }
1491}
1492
1493impl TryFrom<&arrow_array::StringArray> for JsonbArray {
1494    type Error = ArrayError;
1495
1496    fn try_from(array: &arrow_array::StringArray) -> Result<Self, Self::Error> {
1497        array
1498            .iter()
1499            .map(|o| {
1500                o.map(|s| {
1501                    s.parse()
1502                        .map_err(|_| ArrayError::from_arrow(format!("invalid json: {s}")))
1503                })
1504                .transpose()
1505            })
1506            .try_collect()
1507    }
1508}
1509
1510impl From<&IntervalArray> for arrow_array::StringArray {
1511    fn from(array: &IntervalArray) -> Self {
1512        let mut builder =
1513            arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 16);
1514        for value in array.iter() {
1515            match value {
1516                Some(interval) => {
1517                    write!(&mut builder, "{}", interval).unwrap();
1518                    builder.append_value("");
1519                }
1520                None => builder.append_null(),
1521            }
1522        }
1523        builder.finish()
1524    }
1525}
1526
1527impl From<&JsonbArray> for arrow_array::LargeStringArray {
1528    fn from(array: &JsonbArray) -> Self {
1529        let mut builder =
1530            arrow_array::builder::LargeStringBuilder::with_capacity(array.len(), array.len() * 16);
1531        for value in array.iter() {
1532            match value {
1533                Some(jsonb) => {
1534                    write!(&mut builder, "{}", jsonb).unwrap();
1535                    builder.append_value("");
1536                }
1537                None => builder.append_null(),
1538            }
1539        }
1540        builder.finish()
1541    }
1542}
1543
1544impl TryFrom<&arrow_array::LargeStringArray> for JsonbArray {
1545    type Error = ArrayError;
1546
1547    fn try_from(array: &arrow_array::LargeStringArray) -> Result<Self, Self::Error> {
1548        array
1549            .iter()
1550            .map(|o| {
1551                o.map(|s| {
1552                    s.parse()
1553                        .map_err(|_| ArrayError::from_arrow(format!("invalid json: {s}")))
1554                })
1555                .transpose()
1556            })
1557            .try_collect()
1558    }
1559}
1560
1561impl From<arrow_buffer::i256> for Int256 {
1562    fn from(value: arrow_buffer::i256) -> Self {
1563        let buffer = value.to_be_bytes();
1564        Int256::from_be_bytes(buffer)
1565    }
1566}
1567
1568impl<'a> From<Int256Ref<'a>> for arrow_buffer::i256 {
1569    fn from(val: Int256Ref<'a>) -> Self {
1570        let buffer = val.to_be_bytes();
1571        arrow_buffer::i256::from_be_bytes(buffer)
1572    }
1573}
1574
1575impl From<&Int256Array> for arrow_array::Decimal256Array {
1576    fn from(array: &Int256Array) -> Self {
1577        array
1578            .iter()
1579            .map(|o| o.map(arrow_buffer::i256::from))
1580            .collect()
1581    }
1582}
1583
1584impl From<&arrow_array::Decimal256Array> for Int256Array {
1585    fn from(array: &arrow_array::Decimal256Array) -> Self {
1586        let values = array.iter().map(|o| o.map(Int256::from)).collect_vec();
1587
1588        values
1589            .iter()
1590            .map(|i| i.as_ref().map(|v| v.as_scalar_ref()))
1591            .collect()
1592    }
1593}
1594
1595/// This function checks whether the schema of a Parquet file matches the user-defined schema in RisingWave.
1596/// It handles the following special cases:
1597/// - Arrow's `timestamp(_, None)` types (all four time units) match with RisingWave's `Timestamp` type.
1598/// - Arrow's `timestamp(_, Some)` matches with RisingWave's `Timestamptz` type.
1599/// - Since RisingWave does not have an `UInt` type:
1600///   - Arrow's `UInt8` matches with RisingWave's `Int16`.
1601///   - Arrow's `UInt16` matches with RisingWave's `Int32`.
1602///   - Arrow's `UInt32` matches with RisingWave's `Int64`.
1603///   - Arrow's `UInt64` matches with RisingWave's `Decimal`.
1604/// - Arrow's `Float16` matches with RisingWave's `Float32`.
1605///
1606/// Nested data type matching:
1607/// - Struct: Arrow's `Struct` type matches with RisingWave's `Struct` type recursively, requiring that all expected fields exist and match by name and type. Extra Arrow fields are allowed.
1608/// - List: Arrow's `List` type matches with RisingWave's `List` type recursively, requiring the same element type.
1609/// - Map: Arrow's `Map` type matches with RisingWave's `Map` type recursively, requiring the key and value types to match, and the inner struct must have exactly two fields named "key" and "value".
1610pub fn is_parquet_schema_match_source_schema(
1611    arrow_data_type: &arrow_schema::DataType,
1612    rw_data_type: &crate::types::DataType,
1613) -> bool {
1614    use arrow_schema::DataType as ArrowType;
1615
1616    use crate::types::{DataType as RwType, MapType, StructType};
1617
1618    match (arrow_data_type, rw_data_type) {
1619        // Primitive type matching and special cases
1620        (ArrowType::Boolean, RwType::Boolean)
1621        | (ArrowType::Int8 | ArrowType::Int16 | ArrowType::UInt8, RwType::Int16)
1622        | (ArrowType::Int32 | ArrowType::UInt16, RwType::Int32)
1623        | (ArrowType::Int64 | ArrowType::UInt32, RwType::Int64)
1624        | (ArrowType::UInt64 | ArrowType::Decimal128(_, _), RwType::Decimal)
1625        | (ArrowType::Decimal256(_, _), RwType::Int256)
1626        | (ArrowType::Float16 | ArrowType::Float32, RwType::Float32)
1627        | (ArrowType::Float64, RwType::Float64)
1628        | (ArrowType::Timestamp(_, None), RwType::Timestamp)
1629        | (ArrowType::Timestamp(_, Some(_)), RwType::Timestamptz)
1630        | (ArrowType::Date32, RwType::Date)
1631        | (ArrowType::Time32(_) | ArrowType::Time64(_), RwType::Time)
1632        | (ArrowType::Interval(arrow_schema::IntervalUnit::MonthDayNano), RwType::Interval)
1633        | (ArrowType::Utf8 | ArrowType::LargeUtf8, RwType::Varchar)
1634        | (ArrowType::Binary | ArrowType::LargeBinary, RwType::Bytea) => true,
1635
1636        // Struct type recursive matching
1637        // Arrow's Struct matches RisingWave's Struct if all expected field names exist and types
1638        // match recursively. Extra Arrow fields are allowed and field order is ignored.
1639        (ArrowType::Struct(arrow_fields), RwType::Struct(rw_struct)) => {
1640            if arrow_fields.len() < rw_struct.len() {
1641                return false;
1642            }
1643            for (rw_name, rw_ty) in rw_struct.iter() {
1644                let Some(arrow_field) = arrow_fields.iter().find(|f| f.name() == rw_name) else {
1645                    return false;
1646                };
1647                if !is_parquet_schema_match_source_schema(arrow_field.data_type(), rw_ty) {
1648                    return false;
1649                }
1650            }
1651            true
1652        }
1653        // List type recursive matching
1654        // Arrow's List matches RisingWave's List if the element type matches recursively
1655        (ArrowType::List(arrow_field), RwType::List(rw_list_ty)) => {
1656            is_parquet_schema_match_source_schema(arrow_field.data_type(), rw_list_ty.elem())
1657        }
1658        // Map type recursive matching
1659        // Arrow's Map matches RisingWave's Map if the key and value types match recursively,
1660        // and the inner struct has exactly two fields named "key" and "value"
1661        (ArrowType::Map(arrow_field, _), RwType::Map(rw_map_ty)) => {
1662            if let ArrowType::Struct(fields) = arrow_field.data_type() {
1663                if fields.len() != 2 {
1664                    return false;
1665                }
1666                let key_field = &fields[0];
1667                let value_field = &fields[1];
1668                if key_field.name() != "key" || value_field.name() != "value" {
1669                    return false;
1670                }
1671                let (rw_key_ty, rw_value_ty) = (rw_map_ty.key(), rw_map_ty.value());
1672                is_parquet_schema_match_source_schema(key_field.data_type(), rw_key_ty)
1673                    && is_parquet_schema_match_source_schema(value_field.data_type(), rw_value_ty)
1674            } else {
1675                false
1676            }
1677        }
1678        // Fallback: types do not match
1679        _ => false,
1680    }
1681}
#[cfg(test)]
mod tests {
    // Tests for the schema-compatibility check and for round-trip conversions between
    // RisingWave arrays and arrow arrays.

    use arrow_schema::{DataType as ArrowType, Field as ArrowField};

    use super::*;
    use crate::types::{DataType as RwType, MapType, StructType};

    #[test]
    fn test_primitive_schema_match() {
        // Unsigned Arrow integers widen to the next signed (or decimal) RisingWave type.
        assert!(is_parquet_schema_match_source_schema(&ArrowType::UInt8, &RwType::Int16));
        assert!(is_parquet_schema_match_source_schema(&ArrowType::UInt16, &RwType::Int32));
        assert!(is_parquet_schema_match_source_schema(&ArrowType::UInt32, &RwType::Int64));
        assert!(is_parquet_schema_match_source_schema(&ArrowType::UInt64, &RwType::Decimal));
        assert!(!is_parquet_schema_match_source_schema(&ArrowType::UInt32, &RwType::Int32));

        // Half-precision floats are read as `Float32`; wider floats are not narrowed.
        assert!(is_parquet_schema_match_source_schema(&ArrowType::Float16, &RwType::Float32));
        assert!(!is_parquet_schema_match_source_schema(&ArrowType::Float64, &RwType::Float32));

        // For every time unit, a time zone selects `Timestamptz` and its absence `Timestamp`.
        for unit in [
            arrow_schema::TimeUnit::Second,
            arrow_schema::TimeUnit::Millisecond,
            arrow_schema::TimeUnit::Microsecond,
            arrow_schema::TimeUnit::Nanosecond,
        ] {
            let naive = ArrowType::Timestamp(unit, None);
            let zoned = ArrowType::Timestamp(unit, Some("UTC".into()));
            assert!(is_parquet_schema_match_source_schema(&naive, &RwType::Timestamp));
            assert!(!is_parquet_schema_match_source_schema(&naive, &RwType::Timestamptz));
            assert!(is_parquet_schema_match_source_schema(&zoned, &RwType::Timestamptz));
            assert!(!is_parquet_schema_match_source_schema(&zoned, &RwType::Timestamp));
        }

        // A nested RisingWave type never matches an Arrow scalar.
        assert!(!is_parquet_schema_match_source_schema(
            &ArrowType::Float64,
            &RwType::Float64.list()
        ));
    }

    #[test]
    fn test_struct_schema_match() {
        // Arrow: struct<f1: Double, f2: Utf8>
        let arrow_struct = ArrowType::Struct(
            vec![
                ArrowField::new("f1", ArrowType::Float64, true),
                ArrowField::new("f2", ArrowType::Utf8, true),
            ]
            .into(),
        );
        // RW: struct<f1 Double, f2 Varchar>
        let rw_struct = RwType::Struct(StructType::new(vec![
            ("f1".to_owned(), RwType::Float64),
            ("f2".to_owned(), RwType::Varchar),
        ]));
        assert!(is_parquet_schema_match_source_schema(
            &arrow_struct,
            &rw_struct
        ));

        // Arrow is a superset of RW struct fields.
        let arrow_struct_superset = ArrowType::Struct(
            vec![
                ArrowField::new("f1", ArrowType::Float64, true),
                ArrowField::new("f2", ArrowType::Utf8, true),
                ArrowField::new("f3", ArrowType::Int32, true),
            ]
            .into(),
        );
        assert!(is_parquet_schema_match_source_schema(
            &arrow_struct_superset,
            &rw_struct
        ));

        // Arrow has fewer fields than the RW struct.
        let arrow_struct_subset =
            ArrowType::Struct(vec![ArrowField::new("f1", ArrowType::Float64, true)].into());
        assert!(!is_parquet_schema_match_source_schema(
            &arrow_struct_subset,
            &rw_struct
        ));

        // Field order is ignored for struct matching.
        let arrow_struct_reordered = ArrowType::Struct(
            vec![
                ArrowField::new("f2", ArrowType::Utf8, true),
                ArrowField::new("f1", ArrowType::Float64, true),
            ]
            .into(),
        );
        assert!(is_parquet_schema_match_source_schema(
            &arrow_struct_reordered,
            &rw_struct
        ));

        // Field names do not match
        let arrow_struct2 = ArrowType::Struct(
            vec![
                ArrowField::new("f1", ArrowType::Float64, true),
                ArrowField::new("f3", ArrowType::Utf8, true),
            ]
            .into(),
        );
        assert!(!is_parquet_schema_match_source_schema(
            &arrow_struct2,
            &rw_struct
        ));
    }

    #[test]
    fn test_struct_projection_from_arrow() {
        use std::sync::Arc;

        use itertools::Itertools;

        struct Dummy;
        impl FromArrow for Dummy {}

        // Actual Arrow struct: struct<foo:int32, bar:utf8, baz:int32>
        let actual_fields: arrow_schema::Fields = vec![
            ArrowField::new("foo", ArrowType::Int32, true),
            ArrowField::new("bar", ArrowType::Utf8, true),
            ArrowField::new("baz", ArrowType::Int32, true),
        ]
        .into();
        let foo: arrow_array::ArrayRef =
            Arc::new(arrow_array::Int32Array::from(vec![Some(10), Some(20)]));
        let bar: arrow_array::ArrayRef =
            Arc::new(arrow_array::StringArray::from(vec![Some("a"), Some("b")]));
        let baz: arrow_array::ArrayRef =
            Arc::new(arrow_array::Int32Array::from(vec![Some(100), Some(200)]));
        let actual_struct = arrow_array::StructArray::new(actual_fields, vec![foo, bar, baz], None);
        let actual_struct_ref: arrow_array::ArrayRef = Arc::new(actual_struct);

        // Expected struct in RW schema (via to_arrow_field): struct<foo:int32, bar:utf8>
        let expected_field = ArrowField::new(
            "s",
            ArrowType::Struct(
                vec![
                    ArrowField::new("foo", ArrowType::Int32, true),
                    ArrowField::new("bar", ArrowType::Utf8, true),
                ]
                .into(),
            ),
            true,
        );

        let array_impl = Dummy
            .from_array(&expected_field, &actual_struct_ref)
            .unwrap();

        let ArrayImpl::Struct(s) = array_impl else {
            panic!("expected RW StructArray");
        };

        let DataType::Struct(st) = s.data_type() else {
            panic!("expected RW struct type");
        };
        assert_eq!(st.len(), 2);
        assert_eq!(st.iter().map(|(n, _)| n).collect_vec(), vec!["foo", "bar"]);

        let v0 = s.value_at(0).unwrap().to_owned_scalar();
        let v1 = s.value_at(1).unwrap().to_owned_scalar();
        assert_eq!(
            v0,
            StructValue::new(vec![
                Some(ScalarImpl::Int32(10)),
                Some(ScalarImpl::Utf8("a".into()))
            ])
        );
        assert_eq!(
            v1,
            StructValue::new(vec![
                Some(ScalarImpl::Int32(20)),
                Some(ScalarImpl::Utf8("b".into()))
            ])
        );
    }

    #[test]
    fn test_list_schema_match() {
        // Arrow: list<double>
        let arrow_list =
            ArrowType::List(Box::new(ArrowField::new("item", ArrowType::Float64, true)).into());
        // RW: list<double>
        let rw_list = RwType::Float64.list();
        assert!(is_parquet_schema_match_source_schema(&arrow_list, &rw_list));

        let rw_list2 = RwType::Int32.list();
        assert!(!is_parquet_schema_match_source_schema(
            &arrow_list,
            &rw_list2
        ));
    }

    #[test]
    fn test_map_schema_match() {
        // Arrow: map<utf8, int32>
        let arrow_map = ArrowType::Map(
            Arc::new(ArrowField::new(
                "entries",
                ArrowType::Struct(
                    vec![
                        ArrowField::new("key", ArrowType::Utf8, false),
                        ArrowField::new("value", ArrowType::Int32, true),
                    ]
                    .into(),
                ),
                false,
            )),
            false,
        );
        // RW: map<varchar, int32>
        let rw_map = RwType::Map(MapType::from_kv(RwType::Varchar, RwType::Int32));
        assert!(is_parquet_schema_match_source_schema(&arrow_map, &rw_map));

        // Key type does not match
        let rw_map2 = RwType::Map(MapType::from_kv(RwType::Int32, RwType::Int32));
        assert!(!is_parquet_schema_match_source_schema(&arrow_map, &rw_map2));

        // Value type does not match
        let rw_map3 = RwType::Map(MapType::from_kv(RwType::Varchar, RwType::Float64));
        assert!(!is_parquet_schema_match_source_schema(&arrow_map, &rw_map3));

        // Arrow inner struct field name does not match
        let arrow_map2 = ArrowType::Map(
            Arc::new(ArrowField::new(
                "entries",
                ArrowType::Struct(
                    vec![
                        ArrowField::new("k", ArrowType::Utf8, false),
                        ArrowField::new("value", ArrowType::Int32, true),
                    ]
                    .into(),
                ),
                false,
            )),
            false,
        );
        assert!(!is_parquet_schema_match_source_schema(&arrow_map2, &rw_map));
    }

    #[test]
    fn bool() {
        let array = BoolArray::from_iter([None, Some(false), Some(true)]);
        let arrow = arrow_array::BooleanArray::from(&array);
        assert_eq!(BoolArray::from(&arrow), array);
    }

    #[test]
    fn i16() {
        let array = I16Array::from_iter([None, Some(-7), Some(25)]);
        let arrow = arrow_array::Int16Array::from(&array);
        assert_eq!(I16Array::from(&arrow), array);
    }

    #[test]
    fn i32() {
        let array = I32Array::from_iter([None, Some(-7), Some(25)]);
        let arrow = arrow_array::Int32Array::from(&array);
        assert_eq!(I32Array::from(&arrow), array);
    }

    #[test]
    fn i64() {
        let array = I64Array::from_iter([None, Some(-7), Some(25)]);
        let arrow = arrow_array::Int64Array::from(&array);
        assert_eq!(I64Array::from(&arrow), array);
    }

    #[test]
    fn f32() {
        let array = F32Array::from_iter([None, Some(-7.0), Some(25.0)]);
        let arrow = arrow_array::Float32Array::from(&array);
        assert_eq!(F32Array::from(&arrow), array);
    }

    #[test]
    fn f64() {
        let array = F64Array::from_iter([None, Some(-7.0), Some(25.0)]);
        let arrow = arrow_array::Float64Array::from(&array);
        assert_eq!(F64Array::from(&arrow), array);
    }

    #[test]
    fn int8() {
        let array: PrimitiveArray<i16> = I16Array::from_iter([None, Some(-128), Some(127)]);
        let arr = arrow_array::Int8Array::from(vec![None, Some(-128), Some(127)]);
        let converted: PrimitiveArray<i16> = (&arr).into();
        assert_eq!(converted, array);
    }

    #[test]
    fn uint8() {
        let array: PrimitiveArray<i16> = I16Array::from_iter([None, Some(7), Some(25)]);
        let arr = arrow_array::UInt8Array::from(vec![None, Some(7), Some(25)]);
        let converted: PrimitiveArray<i16> = (&arr).into();
        assert_eq!(converted, array);
    }

    #[test]
    fn uint16() {
        let array: PrimitiveArray<i32> = I32Array::from_iter([None, Some(7), Some(65535)]);
        let arr = arrow_array::UInt16Array::from(vec![None, Some(7), Some(65535)]);
        let converted: PrimitiveArray<i32> = (&arr).into();
        assert_eq!(converted, array);
    }

    #[test]
    fn uint32() {
        let array: PrimitiveArray<i64> = I64Array::from_iter([None, Some(7), Some(4294967295)]);
        let arr = arrow_array::UInt32Array::from(vec![None, Some(7), Some(4294967295)]);
        let converted: PrimitiveArray<i64> = (&arr).into();
        assert_eq!(converted, array);
    }

    #[test]
    fn uint64() {
        let array: PrimitiveArray<Decimal> = DecimalArray::from_iter([
            None,
            Some(Decimal::Normalized("7".parse().unwrap())),
            Some(Decimal::Normalized("18446744073709551615".parse().unwrap())),
        ]);
        let arr = arrow_array::UInt64Array::from(vec![None, Some(7), Some(18446744073709551615)]);
        let converted: PrimitiveArray<Decimal> = (&arr).try_into().unwrap();
        assert_eq!(converted, array);
    }

    #[test]
    fn date() {
        let array = DateArray::from_iter([
            None,
            Date::with_days_since_ce(12345).ok(),
            Date::with_days_since_ce(-12345).ok(),
        ]);
        let arrow = arrow_array::Date32Array::from(&array);
        assert_eq!(DateArray::from(&arrow), array);
    }

    #[test]
    fn time() {
        let array = TimeArray::from_iter([None, Time::with_micro(24 * 3600 * 1_000_000 - 1).ok()]);
        let arrow = arrow_array::Time64MicrosecondArray::from(&array);
        assert_eq!(TimeArray::from(&arrow), array);
    }

    #[test]
    fn timestamp() {
        let array =
            TimestampArray::from_iter([None, Timestamp::with_micros(123456789012345678).ok()]);
        let arrow = arrow_array::TimestampMicrosecondArray::from(&array);
        assert_eq!(TimestampArray::from(&arrow), array);
    }

    #[test]
    fn interval() {
        let array = IntervalArray::from_iter([
            None,
            Some(Interval::from_month_day_usec(
                1_000_000,
                1_000,
                1_000_000_000,
            )),
            Some(Interval::from_month_day_usec(
                -1_000_000,
                -1_000,
                -1_000_000_000,
            )),
        ]);
        let arrow = arrow_array::IntervalMonthDayNanoArray::from(&array);
        assert_eq!(IntervalArray::from(&arrow), array);
    }

    #[test]
    fn string() {
        let array = Utf8Array::from_iter([None, Some("array"), Some("arrow")]);
        let arrow = arrow_array::StringArray::from(&array);
        assert_eq!(Utf8Array::from(&arrow), array);
    }

    #[test]
    fn binary() {
        let array = BytesArray::from_iter([None, Some("array".as_bytes())]);
        let arrow = arrow_array::BinaryArray::from(&array);
        assert_eq!(BytesArray::from(&arrow), array);
    }

    #[test]
    fn decimal() {
        let array = DecimalArray::from_iter([
            None,
            Some(Decimal::NaN),
            Some(Decimal::PositiveInf),
            Some(Decimal::NegativeInf),
            Some(Decimal::Normalized("123.4".parse().unwrap())),
            Some(Decimal::Normalized("123.456".parse().unwrap())),
        ]);
        let arrow = arrow_array::LargeBinaryArray::from(&array);
        assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array);

        let arrow = arrow_array::StringArray::from(&array);
        assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array);
    }

    #[test]
    fn jsonb() {
        let array = JsonbArray::from_iter([
            None,
            Some("null".parse().unwrap()),
            Some("false".parse().unwrap()),
            Some("1".parse().unwrap()),
            Some("[1, 2, 3]".parse().unwrap()),
            Some(r#"{ "a": 1, "b": null }"#.parse().unwrap()),
        ]);
        let arrow = arrow_array::LargeStringArray::from(&array);
        assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array);

        let arrow = arrow_array::StringArray::from(&array);
        assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array);
    }

    #[test]
    fn int256() {
        let values = [
            None,
            Some(Int256::from(1)),
            Some(Int256::from(i64::MAX)),
            Some(Int256::from(i64::MAX) * Int256::from(i64::MAX)),
            Some(Int256::from(i64::MAX) * Int256::from(i64::MAX) * Int256::from(i64::MAX)),
            Some(
                Int256::from(i64::MAX)
                    * Int256::from(i64::MAX)
                    * Int256::from(i64::MAX)
                    * Int256::from(i64::MAX),
            ),
            Some(Int256::min_value()),
            Some(Int256::max_value()),
        ];

        let array =
            Int256Array::from_iter(values.iter().map(|r| r.as_ref().map(|x| x.as_scalar_ref())));
        let arrow = arrow_array::Decimal256Array::from(&array);
        assert_eq!(Int256Array::from(&arrow), array);
    }
}