// risingwave_connector/source/iceberg/mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod parquet_file_handler;
16
17pub mod metrics;
18use std::collections::{BinaryHeap, HashMap, HashSet};
19use std::sync::Arc;
20
21use anyhow::anyhow;
22use async_trait::async_trait;
23use futures::StreamExt;
24use futures_async_stream::{for_await, try_stream};
25use iceberg::Catalog;
26use iceberg::expr::{BoundPredicate, Predicate as IcebergPredicate};
27use iceberg::scan::FileScanTask;
28use iceberg::spec::FormatVersion;
29use iceberg::table::Table;
30pub use parquet_file_handler::*;
31use phf::{Set, phf_set};
32use risingwave_common::array::arrow::IcebergArrowConvert;
33use risingwave_common::array::{ArrayImpl, DataChunk, I64Array, Utf8Array};
34use risingwave_common::bail;
35use risingwave_common::types::JsonbVal;
36use risingwave_common_estimate_size::EstimateSize;
37use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
38use serde::{Deserialize, Serialize};
39
40pub use self::metrics::{GLOBAL_ICEBERG_SCAN_METRICS, IcebergScanMetrics};
41use crate::connector_common::{IcebergCommon, IcebergTableIdentifier};
42use crate::enforce_secret::{EnforceSecret, EnforceSecretError};
43use crate::error::{ConnectorError, ConnectorResult};
44use crate::parser::ParserConfig;
45use crate::source::{
46    BoxSourceChunkStream, Column, SourceContextRef, SourceEnumeratorContextRef, SourceProperties,
47    SplitEnumerator, SplitId, SplitMetaData, SplitReader, UnknownFields,
48};
/// Connector name for Iceberg sources, matched against the `connector` option.
pub const ICEBERG_CONNECTOR: &str = "iceberg";
50
/// Options accepted by the Iceberg source connector (`connector = 'iceberg'`).
#[derive(Clone, Debug, Deserialize, with_options::WithOptions)]
pub struct IcebergProperties {
    /// Common Iceberg connection options (catalog, warehouse, credentials, ...).
    #[serde(flatten)]
    pub common: IcebergCommon,

    /// Identifier of the table to read.
    #[serde(flatten)]
    pub table: IcebergTableIdentifier,

    // For jdbc catalog
    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,
    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,

    /// Options not consumed above; surfaced through [`UnknownFields`] for validation.
    #[serde(flatten)]
    pub unknown_fields: HashMap<String, String>,
}
68
69impl EnforceSecret for IcebergProperties {
70    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
71        "catalog.jdbc.password",
72    };
73
74    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
75        for prop in prop_iter {
76            IcebergCommon::enforce_one(prop)?;
77            if Self::ENFORCE_SECRET_PROPERTIES.contains(prop) {
78                return Err(EnforceSecretError {
79                    key: prop.to_owned(),
80                }
81                .into());
82            }
83        }
84        Ok(())
85    }
86}
87
88impl IcebergProperties {
89    pub async fn create_catalog(&self) -> ConnectorResult<Arc<dyn Catalog>> {
90        let mut java_catalog_props = HashMap::new();
91        if let Some(jdbc_user) = self.jdbc_user.clone() {
92            java_catalog_props.insert("jdbc.user".to_owned(), jdbc_user);
93        }
94        if let Some(jdbc_password) = self.jdbc_password.clone() {
95            java_catalog_props.insert("jdbc.password".to_owned(), jdbc_password);
96        }
97        // TODO: support path_style_access and java_catalog_props for iceberg source
98        self.common.create_catalog(&java_catalog_props).await
99    }
100
101    pub async fn load_table(&self) -> ConnectorResult<Table> {
102        let mut java_catalog_props = HashMap::new();
103        if let Some(jdbc_user) = self.jdbc_user.clone() {
104            java_catalog_props.insert("jdbc.user".to_owned(), jdbc_user);
105        }
106        if let Some(jdbc_password) = self.jdbc_password.clone() {
107            java_catalog_props.insert("jdbc.password".to_owned(), jdbc_password);
108        }
109        // TODO: support java_catalog_props for iceberg source
110        self.common
111            .load_table(&self.table, &java_catalog_props)
112            .await
113    }
114}
115
// Wire the Iceberg connector into the generic source framework.
impl SourceProperties for IcebergProperties {
    type Split = IcebergSplit;
    type SplitEnumerator = IcebergSplitEnumerator;
    type SplitReader = IcebergFileReader;

    const SOURCE_NAME: &'static str = ICEBERG_CONNECTOR;
}
123
impl UnknownFields for IcebergProperties {
    /// Options that were supplied but not recognized by this connector.
    fn unknown_fields(&self) -> HashMap<String, String> {
        self.unknown_fields.clone()
    }
}
129
/// A [`FileScanTask`] serialized as a JSON string, so it can be embedded where
/// `Serialize`/`Deserialize`/`Hash` are required.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct IcebergFileScanTaskJsonStr(String);
132
133impl IcebergFileScanTaskJsonStr {
134    pub fn deserialize(&self) -> FileScanTask {
135        serde_json::from_str(&self.0).unwrap()
136    }
137
138    pub fn serialize(task: &FileScanTask) -> Self {
139        Self(serde_json::to_string(task).unwrap())
140    }
141}
142
/// File scan tasks grouped by the kind of scan that should execute them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum IcebergFileScanTask {
    /// Tasks reading data files.
    Data(Vec<FileScanTask>),
    /// Tasks reading equality-delete files.
    EqualityDelete(Vec<FileScanTask>),
    /// Tasks reading position-delete files.
    PositionDelete(Vec<FileScanTask>),
}
149
150impl IcebergFileScanTask {
151    pub fn tasks(&self) -> &[FileScanTask] {
152        match self {
153            IcebergFileScanTask::Data(file_scan_tasks)
154            | IcebergFileScanTask::EqualityDelete(file_scan_tasks)
155            | IcebergFileScanTask::PositionDelete(file_scan_tasks) => file_scan_tasks,
156        }
157    }
158
159    pub fn is_empty(&self) -> bool {
160        self.tasks().is_empty()
161    }
162
163    pub fn files(&self) -> Vec<String> {
164        self.tasks()
165            .iter()
166            .map(|task| task.data_file_path.clone())
167            .collect()
168    }
169
170    pub fn predicate(&self) -> Option<&BoundPredicate> {
171        let first_task = self.tasks().first()?;
172        first_task.predicate.as_ref()
173    }
174}
175
/// One unit of work for a reader: a numeric split id plus the scan tasks it owns.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IcebergSplit {
    pub split_id: i64,
    pub task: IcebergFileScanTask,
}
181
182impl IcebergSplit {
183    pub fn empty(iceberg_scan_type: IcebergScanType) -> Self {
184        let task = match iceberg_scan_type {
185            IcebergScanType::DataScan => IcebergFileScanTask::Data(vec![]),
186            IcebergScanType::EqualityDeleteScan => IcebergFileScanTask::EqualityDelete(vec![]),
187            IcebergScanType::PositionDeleteScan => IcebergFileScanTask::PositionDelete(vec![]),
188            _ => unimplemented!(),
189        };
190        Self { split_id: 0, task }
191    }
192}
193
194impl SplitMetaData for IcebergSplit {
195    fn id(&self) -> SplitId {
196        self.split_id.to_string().into()
197    }
198
199    fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
200        serde_json::from_value(value.take()).map_err(|e| anyhow!(e).into())
201    }
202
203    fn encode_to_json(&self) -> JsonbVal {
204        serde_json::to_value(self.clone()).unwrap().into()
205    }
206
207    fn update_offset(&mut self, _last_seen_offset: String) -> ConnectorResult<()> {
208        unimplemented!()
209    }
210}
211
/// Split enumerator for the Iceberg source; only stores the connector config,
/// since planning happens on demand (see `list_scan_tasks`).
#[derive(Debug, Clone)]
pub struct IcebergSplitEnumerator {
    config: IcebergProperties,
}
216
/// Summary of delete-file information for a planned snapshot scan.
#[derive(Debug, Clone)]
pub struct IcebergDeleteParameters {
    /// Column names forming the equality-delete key.
    pub equality_delete_columns: Vec<String>,
    /// Whether any position-delete files are present.
    pub has_position_delete: bool,
    /// Snapshot the information was derived from, if one exists.
    pub snapshot_id: Option<i64>,
}
223
224#[async_trait]
225impl SplitEnumerator for IcebergSplitEnumerator {
226    type Properties = IcebergProperties;
227    type Split = IcebergSplit;
228
229    async fn new(
230        properties: Self::Properties,
231        context: SourceEnumeratorContextRef,
232    ) -> ConnectorResult<Self> {
233        Ok(Self::new_inner(properties, context))
234    }
235
236    async fn list_splits(&mut self) -> ConnectorResult<Vec<Self::Split>> {
237        // Like file source, iceberg streaming source has a List Executor and a Fetch Executor,
238        // instead of relying on SplitEnumerator on meta.
239        // TODO: add some validation logic here.
240        Ok(vec![])
241    }
242}
impl IcebergSplitEnumerator {
    /// Infallible, non-async constructor; the enumerator only stores the config.
    pub fn new_inner(properties: IcebergProperties, _context: SourceEnumeratorContextRef) -> Self {
        Self { config: properties }
    }
}
248
/// Time-travel target: a specific snapshot id ("version") or the newest
/// snapshot at or before a timestamp in milliseconds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IcebergTimeTravelInfo {
    Version(i64),
    TimestampMs(i64),
}
254
255#[derive(Debug, Clone)]
256pub struct IcebergListResult {
257    pub data_files: Vec<FileScanTask>,
258    pub equality_delete_files: Vec<FileScanTask>,
259    pub position_delete_files: Vec<FileScanTask>,
260    pub equality_delete_columns: Vec<String>,
261    pub format_version: FormatVersion,
262    pub schema: std::sync::Arc<iceberg::spec::Schema>,
263}
264
265impl IcebergSplitEnumerator {
266    pub fn get_snapshot_id(
267        table: &Table,
268        time_travel_info: Option<IcebergTimeTravelInfo>,
269    ) -> ConnectorResult<Option<i64>> {
270        let current_snapshot = table.metadata().current_snapshot();
271        let Some(current_snapshot) = current_snapshot else {
272            return Ok(None);
273        };
274
275        let snapshot_id = match time_travel_info {
276            Some(IcebergTimeTravelInfo::Version(version)) => {
277                let Some(snapshot) = table.metadata().snapshot_by_id(version) else {
278                    bail!("Cannot find the snapshot id in the iceberg table.");
279                };
280                snapshot.snapshot_id()
281            }
282            Some(IcebergTimeTravelInfo::TimestampMs(timestamp)) => {
283                let snapshot = table
284                    .metadata()
285                    .snapshots()
286                    .filter(|snapshot| snapshot.timestamp_ms() <= timestamp)
287                    .max_by_key(|snapshot| snapshot.timestamp_ms());
288                match snapshot {
289                    Some(snapshot) => snapshot.snapshot_id(),
290                    None => {
291                        // convert unix time to human-readable time
292                        let time = chrono::DateTime::from_timestamp_millis(timestamp);
293                        if let Some(time) = time {
294                            tracing::warn!("Cannot find a snapshot older than {}", time);
295                        } else {
296                            tracing::warn!("Cannot find a snapshot");
297                        }
298                        return Ok(None);
299                    }
300                }
301            }
302            None => current_snapshot.snapshot_id(),
303        };
304        Ok(Some(snapshot_id))
305    }
306
307    pub async fn list_scan_tasks(
308        &self,
309        time_travel_info: Option<IcebergTimeTravelInfo>,
310        predicate: IcebergPredicate,
311    ) -> ConnectorResult<Option<IcebergListResult>> {
312        let table = self.config.load_table().await?;
313        let snapshot_id = Self::get_snapshot_id(&table, time_travel_info)?;
314
315        let Some(snapshot_id) = snapshot_id else {
316            return Ok(None);
317        };
318        let res = self
319            .list_scan_tasks_inner(&table, snapshot_id, predicate)
320            .await?;
321        Ok(Some(res))
322    }
323
    /// Plan the snapshot scan and bucket the discovered files into data,
    /// equality-delete, and position-delete groups, deduplicated by file path.
    ///
    /// Also derives the equality-delete key columns from the field ids on the
    /// equality-delete files; all such files must agree on one set of ids.
    async fn list_scan_tasks_inner(
        &self,
        table: &Table,
        snapshot_id: i64,
        predicate: IcebergPredicate,
    ) -> ConnectorResult<IcebergListResult> {
        let format_version = table.metadata().format_version();
        let table_schema = table.metadata().current_schema();
        tracing::debug!("iceberg_table_schema: {:?}", table_schema);

        let mut position_delete_files = vec![];
        let mut position_delete_files_set = HashSet::new();
        let mut data_files = vec![];
        let mut equality_delete_files = vec![];
        let mut equality_delete_files_set = HashSet::new();
        // Field ids of the first equality-delete file seen; all others must match.
        let mut equality_delete_ids = None;
        let mut scan_builder = table.scan().snapshot_id(snapshot_id).select_all();
        // `AlwaysTrue` filters nothing, so skip the pushdown entirely.
        if predicate != IcebergPredicate::AlwaysTrue {
            scan_builder = scan_builder.with_filter(predicate.clone());
        }
        let scan = scan_builder.build()?;
        let file_scan_stream = scan.plan_files().await?;

        #[for_await]
        for task in file_scan_stream {
            let task: FileScanTask = task?;

            // Collect delete files for separate scan types, but keep task.deletes intact
            for delete_file in &task.deletes {
                let delete_file = delete_file.as_ref().clone();
                match delete_file.data_file_content {
                    iceberg::spec::DataContentType::Data => {
                        bail!("Data file should not in task deletes");
                    }
                    iceberg::spec::DataContentType::EqualityDeletes => {
                        // The same delete file can be attached to many data tasks;
                        // dedupe by path before recording it.
                        if equality_delete_files_set.insert(delete_file.data_file_path.clone()) {
                            if equality_delete_ids.is_none() {
                                equality_delete_ids = delete_file.equality_ids.clone();
                            } else if equality_delete_ids != delete_file.equality_ids {
                                bail!(
                                    "The schema of iceberg equality delete file must be consistent"
                                );
                            }
                            equality_delete_files.push(delete_file);
                        }
                    }
                    iceberg::spec::DataContentType::PositionDeletes => {
                        if position_delete_files_set.insert(delete_file.data_file_path.clone()) {
                            position_delete_files.push(delete_file);
                        }
                    }
                }
            }

            match task.data_file_content {
                iceberg::spec::DataContentType::Data => {
                    // Keep the original task with its deletes field intact
                    data_files.push(task);
                }
                iceberg::spec::DataContentType::EqualityDeletes => {
                    bail!("Equality delete files should not be in the data files");
                }
                iceberg::spec::DataContentType::PositionDeletes => {
                    bail!("Position delete files should not be in the data files");
                }
            }
        }
        let schema = table_schema.clone();
        // Translate equality-delete field ids to column names via the current schema.
        let equality_delete_columns = equality_delete_ids
            .unwrap_or_default()
            .into_iter()
            .map(|id| match schema.name_by_field_id(id) {
                Some(name) => Ok::<std::string::String, ConnectorError>(name.to_owned()),
                None => bail!("Delete field id {} not found in schema", id),
            })
            .collect::<ConnectorResult<Vec<_>>>()?;

        Ok(IcebergListResult {
            data_files,
            equality_delete_files,
            position_delete_files,
            equality_delete_columns,
            format_version,
            schema,
        })
    }
410
411    /// Uniformly distribute scan tasks to compute nodes.
412    /// It's deterministic so that it can best utilize the data locality.
413    ///
414    /// # Arguments
415    /// * `file_scan_tasks`: The file scan tasks to be split.
416    /// * `split_num`: The number of splits to be created.
417    ///
418    /// This algorithm is based on a min-heap. It will push all groups into the heap, and then pop the smallest group and add the file scan task to it.
419    /// Ensure that the total length of each group is as balanced as possible.
420    /// The time complexity is O(n log k), where n is the number of file scan tasks and k is the number of splits.
421    /// The space complexity is O(k), where k is the number of splits.
422    /// The algorithm is stable, so the order of the file scan tasks will be preserved.
423    pub fn split_n_vecs(
424        file_scan_tasks: Vec<FileScanTask>,
425        split_num: usize,
426    ) -> Vec<Vec<FileScanTask>> {
427        use std::cmp::{Ordering, Reverse};
428
429        #[derive(Default)]
430        struct FileScanTaskGroup {
431            idx: usize,
432            tasks: Vec<FileScanTask>,
433            total_length: u64,
434        }
435
436        impl Ord for FileScanTaskGroup {
437            fn cmp(&self, other: &Self) -> Ordering {
438                // when total_length is the same, we will sort by index
439                if self.total_length == other.total_length {
440                    self.idx.cmp(&other.idx)
441                } else {
442                    self.total_length.cmp(&other.total_length)
443                }
444            }
445        }
446
447        impl PartialOrd for FileScanTaskGroup {
448            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
449                Some(self.cmp(other))
450            }
451        }
452
453        impl Eq for FileScanTaskGroup {}
454
455        impl PartialEq for FileScanTaskGroup {
456            fn eq(&self, other: &Self) -> bool {
457                self.total_length == other.total_length
458            }
459        }
460
461        let mut heap = BinaryHeap::new();
462        // push all groups into heap
463        for idx in 0..split_num {
464            heap.push(Reverse(FileScanTaskGroup {
465                idx,
466                tasks: vec![],
467                total_length: 0,
468            }));
469        }
470
471        for file_task in file_scan_tasks {
472            let mut group = heap.peek_mut().unwrap();
473            group.0.total_length += file_task.length;
474            group.0.tasks.push(file_task);
475        }
476
477        // convert heap into vec and extract tasks
478        heap.into_vec()
479            .into_iter()
480            .map(|reverse_group| reverse_group.0.tasks)
481            .collect()
482    }
483}
484
/// Tuning knobs for converting a file scan task into data chunks.
// Public type: derive `Debug`/`Clone` so callers can log and reuse it.
#[derive(Debug, Clone)]
pub struct IcebergScanOpts {
    /// Maximum rows per emitted chunk (also used as the Arrow batch size).
    pub chunk_size: usize,
    /// Append the data sequence number column to each chunk.
    pub need_seq_num: bool,
    /// Append file-path and row-position columns to each chunk.
    pub need_file_path_and_pos: bool,
    /// Apply delete files while reading; when false they are stripped from the task.
    pub handle_delete_files: bool,
}
491
/// Scan a data file. Delete files are handled by the iceberg-rust `reader.read` implementation.
///
/// Yields [`DataChunk`]s of at most `chunk_size` rows, optionally extended with
/// a sequence-number column and/or file-path + row-position columns. Per-file
/// metrics (bytes, rows, duration, approximate deletes) are recorded as a side
/// effect when `metrics` is provided.
#[try_stream(ok = DataChunk, error = ConnectorError)]
pub async fn scan_task_to_chunk_with_deletes(
    table: Table,
    mut data_file_scan_task: FileScanTask,
    IcebergScanOpts {
        chunk_size,
        need_seq_num,
        need_file_path_and_pos,
        handle_delete_files,
    }: IcebergScanOpts,
    metrics: Option<Arc<IcebergScanMetrics>>,
) {
    let table_name = table.identifier().name().to_owned();

    // Captured before `deletes` may be cleared below, so metrics see the
    // original delete-file count.
    let num_delete_files = data_file_scan_task.deletes.len();
    let expected_record_count = data_file_scan_task.record_count;
    let file_start = std::time::Instant::now();

    // Accumulate bytes and flush to the metric when the guard drops, so
    // partially-consumed streams are still accounted for.
    let mut read_bytes = scopeguard::guard(0u64, |read_bytes| {
        if let Some(metrics) = metrics.clone() {
            metrics
                .iceberg_read_bytes
                .with_guarded_label_values(&[&table_name])
                .inc_by(read_bytes as _);
        }
    });

    let data_file_path = data_file_scan_task.data_file_path.clone();
    let data_sequence_number = data_file_scan_task.sequence_number;

    tracing::debug!(
        "scan_task_to_chunk_with_deletes: data_file={}, handle_delete_files={}, total_delete_files={}",
        data_file_path,
        handle_delete_files,
        data_file_scan_task.deletes.len()
    );

    if !handle_delete_files {
        // Keep the delete files from being applied when the caller opts out.
        data_file_scan_task.deletes.clear();
    }

    // Read the data file; delete application is delegated to the reader.
    let reader = table
        .reader_builder()
        .with_batch_size(chunk_size)
        .with_row_group_filtering_enabled(true)
        .build();
    let file_scan_stream = tokio_stream::once(Ok(data_file_scan_task));

    let record_batch_stream: iceberg::scan::ArrowRecordBatchStream =
        reader.read(Box::pin(file_scan_stream))?;
    let mut record_batch_stream = record_batch_stream.enumerate();

    let mut total_rows_read: u64 = 0;

    // Process each record batch. Delete application is handled by the SDK.
    while let Some((batch_index, record_batch)) = record_batch_stream.next().await {
        let record_batch = record_batch?;
        // NOTE(review): positions are derived from batch_index * chunk_size,
        // which assumes every batch before the last holds exactly `chunk_size`
        // rows — confirm against the reader's batching guarantees.
        let batch_start_pos = (batch_index * chunk_size) as i64;

        let mut chunk = IcebergArrowConvert.chunk_from_record_batch(&record_batch)?;
        let row_count = chunk.capacity();
        total_rows_read += row_count as u64;

        // Add metadata columns if requested
        if need_seq_num {
            let (mut columns, visibility) = chunk.into_parts();
            columns.push(Arc::new(ArrayImpl::Int64(I64Array::from_iter(
                std::iter::repeat_n(data_sequence_number, row_count),
            ))));
            chunk = DataChunk::from_parts(columns.into(), visibility);
        }

        if need_file_path_and_pos {
            let (mut columns, visibility) = chunk.into_parts();
            columns.push(Arc::new(ArrayImpl::Utf8(Utf8Array::from_iter(
                std::iter::repeat_n(data_file_path.as_str(), row_count),
            ))));

            // Generate position values for each row in the batch
            let positions: Vec<i64> =
                (batch_start_pos..(batch_start_pos + row_count as i64)).collect();
            columns.push(Arc::new(ArrayImpl::Int64(I64Array::from_iter(positions))));

            chunk = DataChunk::from_parts(columns.into(), visibility);
        }

        *read_bytes += chunk.estimated_heap_size() as u64;
        yield chunk;
    }

    // Record per-file metrics after reading all batches.
    if let Some(ref metrics) = metrics {
        let label_values = [table_name.as_str()];

        // File read duration.
        metrics
            .iceberg_source_file_read_duration_seconds
            .with_guarded_label_values(&label_values)
            .observe(file_start.elapsed().as_secs_f64());

        // Rows read.
        if total_rows_read > 0 {
            metrics
                .iceberg_source_rows_read_total
                .with_guarded_label_values(&label_values)
                .inc_by(total_rows_read);
        }

        // File read count.
        metrics
            .iceberg_source_files_read_total
            .with_guarded_label_values(&[table_name.as_str(), "data"])
            .inc();

        // APPROXIMATE: Estimate delete rows applied. The delta between expected_record_count
        // and actual rows read may also include predicate pushdown / row-group pruning effects,
        // so this metric can overcount. It is still useful as an approximate signal for
        // detecting whether delete files cause significant row filtering.
        if handle_delete_files
            && num_delete_files > 0
            && let Some(expected) = expected_record_count
        {
            let deleted = expected.saturating_sub(total_rows_read);
            if deleted > 0 {
                metrics
                    .iceberg_source_delete_rows_applied_total
                    .with_guarded_label_values(&[table_name.as_str(), "sdk_applied_approx"])
                    .inc_by(deleted);
            }
        }
    }
}
627
/// Placeholder `SplitReader` implementation for the Iceberg source.
// NOTE(review): both trait methods below are `unimplemented!()` — presumably the
// iceberg scan executors read data directly instead of going through this type.
#[derive(Debug)]
pub struct IcebergFileReader {}
630
#[async_trait]
impl SplitReader for IcebergFileReader {
    type Properties = IcebergProperties;
    type Split = IcebergSplit;

    // Intentionally unimplemented: the Iceberg source does not read through the
    // generic SplitReader path (see the List/Fetch executor comment above).
    async fn new(
        _props: IcebergProperties,
        _splits: Vec<IcebergSplit>,
        _parser_config: ParserConfig,
        _source_ctx: SourceContextRef,
        _columns: Option<Vec<Column>>,
    ) -> ConnectorResult<Self> {
        unimplemented!()
    }

    fn into_stream(self) -> BoxSourceChunkStream {
        unimplemented!()
    }
}
650
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use iceberg::scan::FileScanTask;
    use iceberg::spec::{DataContentType, Schema};

    use super::*;

    /// Build a minimal data-file scan task with the given byte `length` and a
    /// unique file path derived from `id`.
    fn create_file_scan_task(length: u64, id: u64) -> FileScanTask {
        FileScanTask {
            length,
            start: 0,
            record_count: Some(0),
            data_file_path: format!("test_{}.parquet", id),
            data_file_content: DataContentType::Data,
            data_file_format: iceberg::spec::DataFileFormat::Parquet,
            schema: Arc::new(Schema::builder().build().unwrap()),
            project_field_ids: vec![],
            predicate: None,
            deletes: vec![],
            sequence_number: 0,
            equality_ids: None,
            file_size_in_bytes: 0,
            partition: None,
            partition_spec: None,
            name_mapping: None,
            case_sensitive: true,
        }
    }

    /// All tasks end up in some group and group byte-lengths stay balanced.
    #[test]
    fn test_split_n_vecs_basic() {
        let file_scan_tasks = (1..=12)
            .map(|i| create_file_scan_task(i + 100, i))
            .collect::<Vec<_>>();

        let groups = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks, 3);

        assert_eq!(groups.len(), 3);

        let group_lengths: Vec<u64> = groups
            .iter()
            .map(|group| group.iter().map(|task| task.length).sum())
            .collect();

        let max_length = *group_lengths.iter().max().unwrap();
        let min_length = *group_lengths.iter().min().unwrap();
        assert!(max_length - min_length <= 10, "Groups should be balanced");

        let total_tasks: usize = groups.iter().map(|group| group.len()).sum();
        assert_eq!(total_tasks, 12);
    }

    /// No tasks still yields `split_num` (empty) groups.
    #[test]
    fn test_split_n_vecs_empty() {
        let file_scan_tasks = Vec::new();
        let groups = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks, 3);
        assert_eq!(groups.len(), 3);
        assert!(groups.iter().all(|group| group.is_empty()));
    }

    /// A single task lands in exactly one of the groups.
    #[test]
    fn test_split_n_vecs_single_task() {
        let file_scan_tasks = vec![create_file_scan_task(100, 1)];
        let groups = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks, 3);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups.iter().filter(|group| !group.is_empty()).count(), 1);
    }

    /// One dominating task gets a group to itself.
    #[test]
    fn test_split_n_vecs_uneven_distribution() {
        let file_scan_tasks = vec![
            create_file_scan_task(1000, 1),
            create_file_scan_task(100, 2),
            create_file_scan_task(100, 3),
            create_file_scan_task(100, 4),
            create_file_scan_task(100, 5),
        ];

        let groups = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks, 2);
        assert_eq!(groups.len(), 2);

        let group_with_large_task = groups
            .iter()
            .find(|group| group.iter().any(|task| task.length == 1000))
            .unwrap();
        assert_eq!(group_with_large_task.len(), 1);
    }

    /// The assignment is deterministic: repeated runs on identical input
    /// produce identical groupings (relies on the idx tie-break in the heap).
    #[test]
    fn test_split_n_vecs_same_files_distribution() {
        let file_scan_tasks = vec![
            create_file_scan_task(100, 1),
            create_file_scan_task(100, 2),
            create_file_scan_task(100, 3),
            create_file_scan_task(100, 4),
            create_file_scan_task(100, 5),
            create_file_scan_task(100, 6),
            create_file_scan_task(100, 7),
            create_file_scan_task(100, 8),
        ];

        let groups = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks.clone(), 4)
            .iter()
            .map(|g| {
                g.iter()
                    .map(|task| task.data_file_path.clone())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        for _ in 0..10000 {
            let groups_2 = IcebergSplitEnumerator::split_n_vecs(file_scan_tasks.clone(), 4)
                .iter()
                .map(|g| {
                    g.iter()
                        .map(|task| task.data_file_path.clone())
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();

            assert_eq!(groups, groups_2);
        }
    }
}