risingwave_stream/executor/source/batch_source/batch_adbc_snowflake_fetch.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;

use arrow::RecordBatch;
use either::Either;
use futures::stream;
use itertools::Itertools;
use parking_lot::RwLock;
use risingwave_common::array::Op;
use risingwave_common::array::arrow::arrow_array_56 as arrow;
use risingwave_common::id::TableId;
use risingwave_common::types::{JsonbVal, Scalar, ScalarRef};
use risingwave_connector::source::ConnectorProperties;
use risingwave_connector::source::adbc_snowflake::{
    AdbcSnowflakeArrowConvert, AdbcSnowflakeProperties,
};
use risingwave_connector::source::reader::desc::SourceDesc;
use thiserror_ext::AsReport;

use super::batch_adbc_snowflake_list::AdbcSnowflakeSplit;
use crate::executor::prelude::*;
use crate::executor::source::StreamSourceCore;
use crate::executor::stream_reader::StreamReaderWithPause;
use crate::task::LocalBarrierManager;

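/// Fetch executor for a refreshable batch source backed by ADBC Snowflake.
///
/// It consumes split assignments produced by the upstream list executor, reads each
/// split by issuing a query through the ADBC Snowflake connector, and reports
/// load finished to the barrier manager once every split of a refresh has been read.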
pub struct BatchAdbcSnowflakeFetchExecutor<S: StateStore> {
    actor_ctx: ActorContextRef,

    /// Core component for managing external streaming source state.
    stream_source_core: Option<StreamSourceCore<S>>,

    /// Upstream list executor that provides the splits to read.
    upstream: Option<Executor>,

    /// Barrier manager used to report that the load has finished.
    barrier_manager: LocalBarrierManager,

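    /// Id of the table associated with this source; passed to the barrier manager
    /// when reporting that the load has finished.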
    associated_table_id: TableId,
}

impl<S: StateStore> BatchAdbcSnowflakeFetchExecutor<S> {
    pub fn new(
        actor_ctx: ActorContextRef,
        stream_source_core: StreamSourceCore<S>,
        upstream: Executor,
        barrier_manager: LocalBarrierManager,
        associated_table_id: Option<TableId>,
    ) -> Self {
        let associated_table_id =
            associated_table_id.expect("associated_table_id is required for batch fetch");
        Self {
            actor_ctx,
            stream_source_core: Some(stream_source_core),
            upstream: Some(upstream),
            barrier_manager,
            associated_table_id,
        }
    }
}

impl<S: StateStore> BatchAdbcSnowflakeFetchExecutor<S> {
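    /// Main execution loop: forwards barriers, collects split assignments from the
    /// upstream list executor, reads splits one at a time, and emits the resulting
    /// chunks downstream.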
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn into_stream(mut self) {
        let mut upstream = self.upstream.take().unwrap().execute();
        let barrier = expect_first_barrier(&mut upstream).await?;
        yield Message::Barrier(barrier);

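        // Refresh bookkeeping:
        // - `is_refreshing`: set on `RefreshStart`, cleared after load finished is reported.
        // - `is_list_finished`: set when the `ListFinish` mutation for this source arrives.
        // - `splits_on_fetch`: number of splits currently being read (0 or 1 here).
        // - `is_load_finished`: shared flag set by the split reader once a split is exhausted.
        // - `split_queue`: split assignments received from upstream but not yet read.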
        let mut is_refreshing = false;
        let mut is_list_finished = false;
        let mut splits_on_fetch: usize = 0;
        let is_load_finished = Arc::new(RwLock::new(false));
        let mut split_queue = VecDeque::new();

        let mut core = self.stream_source_core.take().unwrap();
        let source_desc_builder = core.source_desc_builder.take().unwrap();
        let source_desc = source_desc_builder
            .build()
            .map_err(StreamExecutorError::connector_error)?;

        // Get column names from the source schema, filtering out additional columns and hidden columns.
        // These columns are derived from the fetch executor's schema and represent the actual
        // Snowflake table columns that should be queried.
        let column_names: Vec<String> = source_desc
            .columns
            .iter()
            .filter(|c| c.is_visible() && c.additional_column.column_type.is_none())
            .map(|c| c.name.clone())
            .collect();
        tracing::debug!("[adbc snowflake fetch] column_names: {:?}", column_names);

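        // Pair the upstream (barriers and split assignments) with the split data stream.
        // The data side starts as a pending stream and is replaced once a split is
        // popped from the queue.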
        let mut stream =
            StreamReaderWithPause::<true, StreamChunk>::new(upstream, stream::pending().boxed());

        while let Some(msg) = stream.next().await {
            match msg {
                Err(e) => {
                    tracing::error!(error = %e.as_report(), "Fetch Error");
                    split_queue.clear();
                    *is_load_finished.write() = false;
                    return Err(e);
                }
                Ok(msg) => match msg {
                    Either::Left(msg) => match msg {
                        Message::Barrier(barrier) => {
                            let mut need_rebuild_reader = false;
                            if let Some(mutation) = barrier.mutation.as_deref() {
                                match mutation {
                                    Mutation::Pause => stream.pause_stream(),
                                    Mutation::Resume => stream.resume_stream(),
                                    Mutation::RefreshStart {
                                        associated_source_id,
                                        ..
                                    } if associated_source_id == &core.source_id => {
                                        tracing::info!(
                                            ?barrier.epoch,
                                            actor_id = %self.actor_ctx.id,
                                            source_id = %core.source_id,
                                            table_id = %self.associated_table_id,
                                            "RefreshStart:"
                                        );

                                        // reset states and abort current workload
                                        split_queue.clear();
                                        splits_on_fetch = 0;
                                        is_refreshing = true;
                                        is_list_finished = false;
                                        *is_load_finished.write() = false;

                                        need_rebuild_reader = true;
                                    }
                                    Mutation::ListFinish {
                                        associated_source_id,
                                    } if associated_source_id == &core.source_id => {
                                        tracing::info!(
                                            ?barrier.epoch,
                                            actor_id = %self.actor_ctx.id,
                                            source_id = %core.source_id,
                                            table_id = %self.associated_table_id,
                                            "ListFinish:"
                                        );
                                        is_list_finished = true;
                                    }
                                    _ => {
                                        // ignore other mutations
                                    }
                                }
                            }

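                            // All splits of this refresh have been listed and fully read:
                            // on a checkpoint barrier, report load finished to the barrier
                            // manager and reset the refresh flags.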
                            if splits_on_fetch == 0
                                && split_queue.is_empty()
                                && is_list_finished
                                && is_refreshing
                                && barrier.is_checkpoint()
                            {
                                tracing::info!(
                                    ?barrier.epoch,
                                    actor_id = %self.actor_ctx.id,
                                    source_id = %core.source_id,
                                    table_id = %self.associated_table_id,
                                    "Reporting load finished"
                                );
                                self.barrier_manager.report_source_load_finished(
                                    barrier.epoch,
                                    self.actor_ctx.id,
                                    self.associated_table_id,
                                    core.source_id,
                                );

                                // reset flags
                                is_list_finished = false;
                                is_refreshing = false;
                            }

                            yield Message::Barrier(barrier);

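                            // (Re)build the split reader after the barrier: a refresh has just
                            // started, or the previous split finished and more splits are queued.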
                            if need_rebuild_reader
                                || (splits_on_fetch == 0
                                    && !split_queue.is_empty()
                                    && is_refreshing)
                            {
                                Self::replace_with_new_reader(
                                    &mut split_queue,
                                    &mut stream,
                                    &mut splits_on_fetch,
                                    source_desc.clone(),
                                    &column_names,
                                    is_load_finished.clone(),
                                )?;
                            }
                        }
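                        // Split assignments from the upstream list executor: each row carries
                        // (split_id, split encoded as JSONB).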
                        Message::Chunk(chunk) => {
                            let split_values: Vec<(String, JsonbVal)> = chunk
                                .data_chunk()
                                .rows()
                                .map(|row| {
                                    let split_id = row.datum_at(0).unwrap().into_utf8();
                                    let split = row.datum_at(1).unwrap().into_jsonb();
                                    (split_id.to_owned(), split.to_owned_scalar())
                                })
                                .collect();
                            tracing::debug!("received split assignments: {:?}", split_values);
                            split_queue.extend(split_values);
                        }
                        Message::Watermark(_) => unreachable!(),
                    },
                    Either::Right(chunk) => {
                        // The split reader sets `is_load_finished` once its data is exhausted;
                        // if the flag is set, this split has been fully read.
                        if *is_load_finished.read() {
                            splits_on_fetch -= 1;
                            tracing::debug!(
                                "split read finished, remaining splits_on_fetch: {}",
                                splits_on_fetch
                            );
                        }
                        yield Message::Chunk(chunk);
                    }
                },
            }
        }
    }

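    /// Pop the next split off the queue (if any) and install a reader for it as the
    /// data side of the paired stream; otherwise install a pending stream.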
    fn replace_with_new_reader<const BIASED: bool>(
        split_queue: &mut VecDeque<(String, JsonbVal)>,
        stream: &mut StreamReaderWithPause<BIASED, StreamChunk>,
        splits_on_fetch: &mut usize,
        source_desc: SourceDesc,
        column_names: &[String],
        read_finished: Arc<RwLock<bool>>,
    ) -> StreamExecutorResult<()> {
        // For ADBC Snowflake, we process one split at a time to manage connection resources.
        // In the future, this could be extended to batch multiple splits.

        if let Some((split_id, split_json)) = split_queue.pop_front() {
            tracing::debug!("building reader for split: {}", split_id);
            *splits_on_fetch = 1;
            *read_finished.write() = false;

            let split = AdbcSnowflakeSplit::decode(split_json.as_scalar_ref())?;
            let reader =
                Self::build_split_reader(source_desc, split, column_names.to_vec(), read_finished);
            stream.replace_data_stream(reader.boxed());
        } else {
            stream.replace_data_stream(stream::pending().boxed());
        }

        Ok(())
    }

    /// Build and execute a data reader for a single split.
    /// The reader yields chunks from the split until all data is exhausted,
    /// then sets the `read_finished` flag.
    /// Chunks are split to respect the configured `chunk_size` for rate limiting.
    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    async fn build_split_reader(
        source_desc: SourceDesc,
        split: AdbcSnowflakeSplit,
        column_names: Vec<String>,
        read_finished: Arc<RwLock<bool>>,
    ) {
        let properties = source_desc.source.config.clone();
        let properties = match properties {
            ConnectorProperties::AdbcSnowflake(props) => Box::new(*props),
            _ => unreachable!(),
        };

        let max_chunk_size = crate::config::chunk_size();
        let chunks = Self::read_split(properties, split, &column_names)?;
        for chunk in chunks {
            // Split large chunks to respect the configured chunk_size for rate limiting
            if chunk.capacity() > max_chunk_size {
                for small_chunk in chunk.split(max_chunk_size) {
                    yield small_chunk;
                }
            } else {
                yield chunk;
            }
        }

        *read_finished.write() = true;
    }

    /// Read data from a single split by executing the Snowflake query.
    /// Returns all chunks for the split. The query is built from the source schema,
    /// `table_ref`, optional snapshot timestamp (AT clause), and optional WHERE clause.
    /// Column names are derived from the fetch executor's schema, filtering out additional
    /// columns and hidden columns.
    /// If the time travel query fails, falls back to querying without the snapshot.
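    /// The generated query has the form
    /// `SELECT "c1", "c2" FROM <table_ref> [AT(TIMESTAMP => '<ts>')] [WHERE <where_clause>]`.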
    fn read_split(
        properties: Box<AdbcSnowflakeProperties>,
        split: AdbcSnowflakeSplit,
        column_names: &[String],
    ) -> StreamExecutorResult<Vec<StreamChunk>> {
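        // Double-quote column names so Snowflake treats them as case-sensitive identifiers.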
        let select_list = column_names
            .iter()
            .map(|c| format!(r#""{}""#, c))
            .collect::<Vec<_>>()
            .join(", ");

        // Try the time travel query first if a snapshot timestamp is available.
        if let Some(ref ts) = split.snapshot_timestamp {
            let table_expr_with_snapshot = format!("{} AT(TIMESTAMP => '{}')", split.table_ref, ts);
            let mut final_query = format!("SELECT {select_list} FROM {}", table_expr_with_snapshot);
            if let Some(ref where_clause) = split.where_clause {
                final_query = format!("{final_query} WHERE {where_clause}");
            }

            tracing::debug!(
                split_id = %split.split_id,
                query = %final_query,
                "executing query for split with time travel"
            );

            match properties.execute_query(&final_query) {
                Ok(batches) => {
                    return Self::convert_batches_to_chunks(&split.split_id, batches);
                }
                Err(e) => {
                    // The time travel query failed; log a warning and fall back to current data.
                    tracing::warn!(
                        split_id = %split.split_id,
                        error = %e.as_report(),
                        "Time travel query failed, falling back to current data"
                    );
                }
            }
        }

        // Fall back to querying without snapshot
        let mut final_query = format!("SELECT {select_list} FROM {}", split.table_ref);
        if let Some(ref where_clause) = split.where_clause {
            final_query = format!("{final_query} WHERE {where_clause}");
        }

        tracing::debug!(
            split_id = %split.split_id,
            query = %final_query,
            "executing query for split without time travel"
        );

        let batches = properties.execute_query(&final_query)?;
        Self::convert_batches_to_chunks(&split.split_id, batches)
    }

    /// Convert Arrow `RecordBatch`es to `StreamChunk`s.
    fn convert_batches_to_chunks(
        split_id: &str,
        batches: Vec<RecordBatch>,
    ) -> StreamExecutorResult<Vec<StreamChunk>> {
        let converter = AdbcSnowflakeArrowConvert;
        let mut chunks = Vec::new();

        for batch in batches {
            // Convert Arrow RecordBatch to RisingWave DataChunk.
            // The column order in the RecordBatch matches the Snowflake query result,
            // which is consistent with the schema inferred by get_arrow_schema() in the connector.
            let data_chunk = converter.chunk_from_record_batch(&batch)?;

            // Convert DataChunk to StreamChunk (all inserts)
            let stream_chunk = StreamChunk::from_parts(
                itertools::repeat_n(Op::Insert, data_chunk.capacity()).collect_vec(),
                data_chunk,
            );

            chunks.push(stream_chunk);
        }

        tracing::debug!(
            split_id = %split_id,
            num_chunks = chunks.len(),
            "finished reading split"
        );

        Ok(chunks)
    }
}

impl<S: StateStore> Execute for BatchAdbcSnowflakeFetchExecutor<S> {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.into_stream().boxed()
    }
}

impl<S: StateStore> Debug for BatchAdbcSnowflakeFetchExecutor<S> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(core) = &self.stream_source_core {
            f.debug_struct("BatchAdbcSnowflakeFetchExecutor")
                .field("source_id", &core.source_id)
                .field("column_ids", &core.column_ids)
                .finish()
        } else {
            f.debug_struct("BatchAdbcSnowflakeFetchExecutor").finish()
        }
    }
}