// risingwave_stream/executor/source/batch_source/batch_adbc_snowflake_fetch.rs

use std::collections::VecDeque;

use arrow::RecordBatch;
use either::Either;
use futures::stream;
use itertools::Itertools;
use parking_lot::RwLock;
use risingwave_common::array::Op;
use risingwave_common::array::arrow::arrow_array_56 as arrow;
use risingwave_common::id::TableId;
use risingwave_common::types::{JsonbVal, Scalar, ScalarRef};
use risingwave_connector::source::ConnectorProperties;
use risingwave_connector::source::adbc_snowflake::{
    AdbcSnowflakeArrowConvert, AdbcSnowflakeProperties,
};
use risingwave_connector::source::reader::desc::SourceDesc;
use thiserror_ext::AsReport;

use super::batch_adbc_snowflake_list::AdbcSnowflakeSplit;
use crate::executor::prelude::*;
use crate::executor::source::StreamSourceCore;
use crate::executor::stream_reader::StreamReaderWithPause;
use crate::task::LocalBarrierManager;

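/// Fetch-side executor of the refreshable batch ADBC Snowflake source.
///
/// It receives split assignments (split id plus serialized [`AdbcSnowflakeSplit`])
/// from the upstream list executor, reads one split at a time from Snowflake via
/// ADBC, and forwards the resulting stream chunks downstream. Once the list phase
/// has finished and every assigned split has been read, it reports load completion
/// for the associated table to the barrier manager.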
pub struct BatchAdbcSnowflakeFetchExecutor<S: StateStore> {
    actor_ctx: ActorContextRef,

    stream_source_core: Option<StreamSourceCore<S>>,

    upstream: Option<Executor>,

    barrier_manager: LocalBarrierManager,

    associated_table_id: TableId,
}

impl<S: StateStore> BatchAdbcSnowflakeFetchExecutor<S> {
    pub fn new(
        actor_ctx: ActorContextRef,
        stream_source_core: StreamSourceCore<S>,
        upstream: Executor,
        barrier_manager: LocalBarrierManager,
        associated_table_id: Option<TableId>,
    ) -> Self {
        Self {
            actor_ctx,
            stream_source_core: Some(stream_source_core),
            upstream: Some(upstream),
            barrier_manager,
            associated_table_id: associated_table_id
                .expect("associated_table_id must be set for the batch fetch executor"),
        }
    }
}

impl<S: StateStore> BatchAdbcSnowflakeFetchExecutor<S> {
    /// The main executor loop: forwards barriers, collects split assignments from
    /// upstream, and drives one split reader at a time.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn into_stream(mut self) {
        let mut upstream = self.upstream.take().unwrap().execute();
        // The first barrier must be received and forwarded before anything else.
        let barrier = expect_first_barrier(&mut upstream).await?;
        yield Message::Barrier(barrier);

        // Refresh-cycle state: whether a refresh is in progress, whether listing
        // has finished, how many splits are currently being fetched, and whether
        // the current split reader has drained.
        let mut is_refreshing = false;
        let mut is_list_finished = false;
        let mut splits_on_fetch: usize = 0;
        let is_load_finished = Arc::new(RwLock::new(false));
        let mut split_queue = VecDeque::new();

        let mut core = self.stream_source_core.take().unwrap();
        let source_desc_builder = core.source_desc_builder.take().unwrap();
        let source_desc = source_desc_builder
            .build()
            .map_err(StreamExecutorError::connector_error)?;

        // Only visible, non-additional columns are selected from Snowflake.
        let column_names: Vec<String> = source_desc
            .columns
            .iter()
            .filter(|c| c.is_visible() && c.additional_column.column_type.is_none())
            .map(|c| c.name.clone())
            .collect();
        tracing::debug!("[adbc snowflake fetch] column_names: {:?}", column_names);

        // Pair upstream messages with the split reader; the reader side starts as a
        // pending stream until the first split is assigned.
        let mut stream =
            StreamReaderWithPause::<true, StreamChunk>::new(upstream, stream::pending().boxed());

        while let Some(msg) = stream.next().await {
            match msg {
                Err(e) => {
                    // On reader error, drop queued splits and reset the finish flag
                    // before propagating the error.
                    tracing::error!(error = %e.as_report(), "Fetch Error");
                    split_queue.clear();
                    *is_load_finished.write() = false;
                    return Err(e);
                }
                Ok(msg) => match msg {
                    Either::Left(msg) => match msg {
                        Message::Barrier(barrier) => {
                            let mut need_rebuild_reader = false;
                            if let Some(mutation) = barrier.mutation.as_deref() {
                                match mutation {
                                    Mutation::Pause => stream.pause_stream(),
                                    Mutation::Resume => stream.resume_stream(),
                                    Mutation::RefreshStart {
                                        associated_source_id,
                                        ..
                                    } if associated_source_id == &core.source_id => {
                                        tracing::info!(
                                            ?barrier.epoch,
                                            actor_id = %self.actor_ctx.id,
                                            source_id = %core.source_id,
                                            table_id = %self.associated_table_id,
                                            "RefreshStart:"
                                        );

                                        // A new refresh cycle begins: drop any leftover
                                        // splits and reset the fetch state.
                                        split_queue.clear();
                                        splits_on_fetch = 0;
                                        is_refreshing = true;
                                        is_list_finished = false;
                                        *is_load_finished.write() = false;

                                        need_rebuild_reader = true;
                                    }
                                    Mutation::ListFinish {
                                        associated_source_id,
                                    } if associated_source_id == &core.source_id => {
                                        tracing::info!(
                                            ?barrier.epoch,
                                            actor_id = %self.actor_ctx.id,
                                            source_id = %core.source_id,
                                            table_id = %self.associated_table_id,
                                            "ListFinish:"
                                        );
                                        is_list_finished = true;
                                    }
                                    _ => {}
                                }
                            }

                            // All splits have been listed and read: report load completion
                            // on a checkpoint barrier.
                            if splits_on_fetch == 0
                                && split_queue.is_empty()
                                && is_list_finished
                                && is_refreshing
                                && barrier.is_checkpoint()
                            {
                                tracing::info!(
                                    ?barrier.epoch,
                                    actor_id = %self.actor_ctx.id,
                                    source_id = %core.source_id,
                                    table_id = %self.associated_table_id,
                                    "Reporting load finished"
                                );
                                self.barrier_manager.report_source_load_finished(
                                    barrier.epoch,
                                    self.actor_ctx.id,
                                    self.associated_table_id,
                                    core.source_id,
                                );

                                is_list_finished = false;
                                is_refreshing = false;
                            }

                            yield Message::Barrier(barrier);

                            // Install a reader for the next split if the current reader is
                            // idle, or rebuild it after a `RefreshStart`.
                            if need_rebuild_reader
                                || (splits_on_fetch == 0
                                    && !split_queue.is_empty()
                                    && is_refreshing)
                            {
                                Self::replace_with_new_reader(
                                    &mut split_queue,
                                    &mut stream,
                                    &mut splits_on_fetch,
                                    source_desc.clone(),
                                    &column_names,
                                    is_load_finished.clone(),
                                )?;
                            }
                        }
                        Message::Chunk(chunk) => {
                            // Each row from the upstream list executor is a
                            // (split_id, split as jsonb) pair.
                            let split_values: Vec<(String, JsonbVal)> = chunk
                                .data_chunk()
                                .rows()
                                .map(|row| {
                                    let split_id = row.datum_at(0).unwrap().into_utf8();
                                    let split = row.datum_at(1).unwrap().into_jsonb();
                                    (split_id.to_owned(), split.to_owned_scalar())
                                })
                                .collect();
                            tracing::debug!("received split assignments: {:?}", split_values);
                            split_queue.extend(split_values);
                        }
                        Message::Watermark(_) => unreachable!(),
                    },
                    Either::Right(chunk) => {
                        // A data chunk from the split reader. If the reader has already
                        // marked itself finished, the current split is done.
                        if *is_load_finished.read() {
                            splits_on_fetch -= 1;
                            tracing::debug!(
                                "split read finished, remaining splits_on_fetch: {}",
                                splits_on_fetch
                            );
                        }
                        yield Message::Chunk(chunk);
                    }
                },
            }
        }
    }

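    /// Pops the next split off the queue and installs a reader for it as the data
    /// side of the paused stream. When the queue is empty, a pending stream is
    /// installed instead so that only upstream messages are polled.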
    fn replace_with_new_reader<const BIASED: bool>(
        split_queue: &mut VecDeque<(String, JsonbVal)>,
        stream: &mut StreamReaderWithPause<BIASED, StreamChunk>,
        splits_on_fetch: &mut usize,
        source_desc: SourceDesc,
        column_names: &[String],
        read_finished: Arc<RwLock<bool>>,
    ) -> StreamExecutorResult<()> {
        if let Some((split_id, split_json)) = split_queue.pop_front() {
            tracing::debug!("building reader for split: {}", split_id);
            *splits_on_fetch = 1;
            *read_finished.write() = false;

            let split = AdbcSnowflakeSplit::decode(split_json.as_scalar_ref())?;
            let reader =
                Self::build_split_reader(source_desc, split, column_names.to_vec(), read_finished);
            stream.replace_data_stream(reader.boxed());
        } else {
            stream.replace_data_stream(stream::pending().boxed());
        }

        Ok(())
    }

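    /// Reads a single split from Snowflake and yields its data as stream chunks,
    /// splitting oversized chunks down to the configured chunk size. Marks
    /// `read_finished` once the whole split has been emitted.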
    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    async fn build_split_reader(
        source_desc: SourceDesc,
        split: AdbcSnowflakeSplit,
        column_names: Vec<String>,
        read_finished: Arc<RwLock<bool>>,
    ) {
        let properties = source_desc.source.config.clone();
        let properties = match properties {
            ConnectorProperties::AdbcSnowflake(props) => Box::new(*props),
            _ => unreachable!(),
        };

        let max_chunk_size = crate::config::chunk_size();
        let chunks = Self::read_split(properties, split, &column_names)?;
        for chunk in chunks {
            // Respect the configured chunk size: split oversized chunks before yielding.
            if chunk.capacity() > max_chunk_size {
                for small_chunk in chunk.split(max_chunk_size) {
                    yield small_chunk;
                }
            } else {
                yield chunk;
            }
        }

        // Signal the main loop that this split has been fully read.
        *read_finished.write() = true;
    }

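    /// Builds and executes the `SELECT` statement for one split. If the split
    /// carries a snapshot timestamp, a Snowflake time-travel query
    /// (`AT(TIMESTAMP => ...)`) is attempted first, falling back to the current
    /// table data if that query fails.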
    fn read_split(
        properties: Box<AdbcSnowflakeProperties>,
        split: AdbcSnowflakeSplit,
        column_names: &[String],
    ) -> StreamExecutorResult<Vec<StreamChunk>> {
        // Double-quote column names so case-sensitive Snowflake identifiers are preserved.
        let select_list = column_names
            .iter()
            .map(|c| format!(r#""{}""#, c))
            .collect::<Vec<_>>()
            .join(", ");

        if let Some(ref ts) = split.snapshot_timestamp {
            let table_expr_with_snapshot = format!("{} AT(TIMESTAMP => '{}')", split.table_ref, ts);
            let mut final_query = format!("SELECT {select_list} FROM {}", table_expr_with_snapshot);
            if let Some(ref where_clause) = split.where_clause {
                final_query = format!("{final_query} WHERE {where_clause}");
            }

            tracing::debug!(
                split_id = %split.split_id,
                query = %final_query,
                "executing query for split with time travel"
            );

            match properties.execute_query(&final_query) {
                Ok(batches) => {
                    return Self::convert_batches_to_chunks(&split.split_id, batches);
                }
                Err(e) => {
                    tracing::warn!(
                        split_id = %split.split_id,
                        error = %e.as_report(),
                        "Time travel query failed, falling back to current data"
                    );
                }
            }
        }

        // No snapshot timestamp, or the time-travel query failed: read the current table data.
        let mut final_query = format!("SELECT {select_list} FROM {}", split.table_ref);
        if let Some(ref where_clause) = split.where_clause {
            final_query = format!("{final_query} WHERE {where_clause}");
        }

        tracing::debug!(
            split_id = %split.split_id,
            query = %final_query,
            "executing query for split without time travel"
        );

        let batches = properties.execute_query(&final_query)?;
        Self::convert_batches_to_chunks(&split.split_id, batches)
    }

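    /// Converts the Arrow record batches returned by ADBC into stream chunks,
    /// marking every row as an `Insert`.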
    fn convert_batches_to_chunks(
        split_id: &str,
        batches: Vec<RecordBatch>,
    ) -> StreamExecutorResult<Vec<StreamChunk>> {
        let converter = AdbcSnowflakeArrowConvert;
        let mut chunks = Vec::new();

        for batch in batches {
            let data_chunk = converter.chunk_from_record_batch(&batch)?;

            // Batch-loaded rows are appended as plain inserts.
            let stream_chunk = StreamChunk::from_parts(
                itertools::repeat_n(Op::Insert, data_chunk.capacity()).collect_vec(),
                data_chunk,
            );

            chunks.push(stream_chunk);
        }

        tracing::debug!(
            split_id = %split_id,
            num_chunks = chunks.len(),
            "finished reading split"
        );

        Ok(chunks)
    }
}

impl<S: StateStore> Execute for BatchAdbcSnowflakeFetchExecutor<S> {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.into_stream().boxed()
    }
}

impl<S: StateStore> Debug for BatchAdbcSnowflakeFetchExecutor<S> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(core) = &self.stream_source_core {
            f.debug_struct("BatchAdbcSnowflakeFetchExecutor")
                .field("source_id", &core.source_id)
                .field("column_ids", &core.column_ids)
                .finish()
        } else {
            f.debug_struct("BatchAdbcSnowflakeFetchExecutor").finish()
        }
    }
}