use std::collections::BTreeMap;
use std::future::Future;
use std::mem;

use either::Either;
use futures::future::{Either as FutureEither, select};
use futures::stream::FuturesOrdered;
use futures::{StreamExt, TryStreamExt};
use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::transaction::transaction_message::TxnMsg;
use risingwave_common_rate_limit::{MonitoredRateLimiter, RateLimit, RateLimiter};
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::codegen::BoxStream;
use risingwave_hummock_sdk::HummockReadEpoch;
use risingwave_pb::common::ThrottleType;
use risingwave_storage::StateStore;
use risingwave_storage::store::TryWaitEpochOptions;
use tokio::sync::oneshot;

use crate::executor::prelude::*;
use crate::executor::stream_reader::StreamReaderWithPause;

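/// [`DmlExecutor`] merges the stream of messages from its upstream executor with the
/// batch DML data (`INSERT` / `UPDATE` / `DELETE` statements) read from the
/// [`DmlManagerRef`], and emits the merged stream downstream.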
pub struct DmlExecutor<S: StateStore> {
    actor_ctx: ActorContextRef,

    upstream: Executor,

    /// Stores the information of batch queries.
    dml_manager: DmlManagerRef,

    /// Id of the table on which DML performs.
    table_id: TableId,

    /// Version of the table on which DML performs.
    table_version_id: TableVersionId,

    /// Column descriptions of the table.
    column_descs: Vec<ColumnDesc>,

    /// Maximum number of rows in an emitted chunk.
    chunk_size: usize,

    /// Rate limiter applied to the DML (batch) side of the stream.
    rate_limiter: Arc<MonitoredRateLimiter>,

    /// State store handle, used to wait for epochs to be committed before notifying
    /// DML sessions that their data is persisted.
    state_store: S,
}

/// An ongoing transaction buffers at most `MAX_CHUNK_FOR_ATOMICITY` chunks. Within
/// this limit the whole transaction is emitted atomically once it ends; beyond it,
/// the buffered chunks are flushed downstream and atomicity is given up.
const MAX_CHUNK_FOR_ATOMICITY: usize = 32;

/// Buffer for the data chunks of an ongoing transaction.
#[derive(Debug, Default)]
struct TxnBuffer {
    vec: Vec<StreamChunk>,
    /// Set to `true` once the buffer exceeds `MAX_CHUNK_FOR_ATOMICITY` chunks and its
    /// contents have already been sent downstream, i.e. atomicity has been given up.
    overflow: bool,
}
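
// How an ended transaction is routed, by cardinality (illustrative numbers assuming
// `chunk_size = 1024`; see the `TxnMsg::End` arm in `execute_inner`):
// - a 1500-row transaction (>= chunk_size) flushes the batch group and is then
//   yielded on its own;
// - a 3-row transaction joining a 10-row batch group (sum <= chunk_size) is simply
//   appended to the group;
// - an 800-row transaction meeting a 400-row batch group (sum > chunk_size) flushes
//   the group first and then becomes the new batch group.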

impl<S: StateStore> DmlExecutor<S> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        actor_ctx: ActorContextRef,
        upstream: Executor,
        dml_manager: DmlManagerRef,
        table_id: TableId,
        table_version_id: TableVersionId,
        column_descs: Vec<ColumnDesc>,
        chunk_size: usize,
        rate_limit: RateLimit,
        state_store: S,
    ) -> Self {
        let rate_limiter = Arc::new(RateLimiter::new(rate_limit).monitored(table_id));
        Self {
            actor_ctx,
            upstream,
            dml_manager,
            table_id,
            table_version_id,
            column_descs,
            chunk_size,
            rate_limiter,
            state_store,
        }
    }

    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(self: Box<Self>) {
        let mut upstream = self.upstream.execute();

        let actor_id = self.actor_ctx.id;

        // The first barrier message should be propagated before registering the reader.
        let barrier = expect_first_barrier(&mut upstream).await?;

        // Register as a reader of batch DML data for this table version, and wrap the
        // reader stream with the rate limiter.
        let handle = self.dml_manager.register_reader(
            self.table_id,
            self.table_version_id,
            &self.column_descs,
        )?;
        let reader = apply_dml_rate_limit(
            handle.stream_reader().into_stream(),
            self.rate_limiter.clone(),
        )
        .boxed()
        .map_err(StreamExecutorError::from);

        // Merge the upstream and the DML reader with `StreamReaderWithPause`, so that
        // the DML side can be paused on `Mutation::Pause`.
        let mut stream = StreamReaderWithPause::<false, TxnMsg>::new(upstream, reader);

        // Pause the data stream on startup if required by the first barrier.
        if barrier.is_pause_on_startup() {
            stream.pause_stream();
        }

        yield Message::Barrier(barrier);

        // Active (not yet ended) transactions, keyed by transaction id.
        let mut active_txn_map: BTreeMap<TxnId, TxnBuffer> = Default::default();
        // Chunks of small ended transactions, batched together and flushed at the next
        // barrier to avoid emitting many tiny chunks.
        let mut batch_group: Vec<StreamChunk> = vec![];

        let mut builder = StreamChunkBuilder::new(
            self.chunk_size,
            self.column_descs
                .iter()
                .map(|c| c.data_type.clone())
                .collect(),
        );

        // Senders used to notify DML sessions once their data is persisted. They are
        // resolved when the epoch that contains the data is committed.
        let mut pending_persistence_notifiers: Vec<oneshot::Sender<()>> = Vec::new();

        // In-order futures waiting for epochs to be committed in the state store.
        let mut persistence_futures = FuturesOrdered::new();

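        // Main loop: poll the merged stream for the next message while also driving
        // the pending persistence futures in the background.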
        while let Some(input_msg) =
            next_input_driving_persistence(&mut stream, &mut persistence_futures).await?
        {
            match input_msg {
                Either::Left(msg) => {
                    // Stream data from the upstream executor.
                    if let Message::Barrier(barrier) = &msg {
                        // The barrier closes the previous epoch. Register a future that
                        // waits for that epoch to be committed, then notifies the
                        // pending DML sessions that their data has been persisted.
                        if !pending_persistence_notifiers.is_empty() {
                            let notifiers = mem::take(&mut pending_persistence_notifiers);
                            let closing_epoch = barrier.epoch.prev;
                            let store = self.state_store.clone();
                            let table_id = self.table_id;
                            persistence_futures.push_back(async move {
                                store
                                    .try_wait_epoch(
                                        HummockReadEpoch::Committed(closing_epoch),
                                        TryWaitEpochOptions { table_id },
                                    )
                                    .await
                                    .map_err(StreamExecutorError::from)?;
                                for tx in notifiers {
                                    let _ = tx.send(());
                                }
                                Ok(())
                            });
                        }

                        // Handle barrier mutations: pause or resume the data stream,
                        // or update the DML rate limit.
                        if let Some(mutation) = barrier.mutation.as_deref() {
                            match mutation {
                                Mutation::Pause => stream.pause_stream(),
                                Mutation::Resume => stream.resume_stream(),
                                Mutation::Throttle(fragment_to_apply) => {
                                    if let Some(entry) =
                                        fragment_to_apply.get(&self.actor_ctx.fragment_id)
                                        && entry.throttle_type() == ThrottleType::Dml
                                    {
                                        let new_rate_limit = entry.rate_limit.into();
                                        let old_rate_limit =
                                            self.rate_limiter.update(new_rate_limit);

                                        if old_rate_limit != new_rate_limit {
                                            tracing::info!(
                                                old_rate_limit = ?old_rate_limit,
                                                new_rate_limit = ?new_rate_limit,
                                                %actor_id,
                                                "dml rate limit changed",
                                            );
                                        }
                                    }
                                }
                                _ => {}
                            }
                        }

                        // Flush the batched small transactions before yielding the
                        // barrier, so their data is included in the closing epoch.
                        if !batch_group.is_empty() {
                            let vec = mem::take(&mut batch_group);
                            for chunk in vec {
                                for (op, row) in chunk.rows() {
                                    if let Some(chunk) = builder.append_row(op, row) {
                                        yield Message::Chunk(chunk);
                                    }
                                }
                            }
                            if let Some(chunk) = builder.take() {
                                yield Message::Chunk(chunk);
                            }
                        }
                    }
                    yield msg;
                }
                Either::Right(txn_msg) => {
                    // Batch data (DML statements), wrapped in transaction messages.
                    match txn_msg {
                        TxnMsg::Begin(txn_id) => {
                            active_txn_map
                                .try_insert(txn_id, TxnBuffer::default())
                                .unwrap_or_else(|_| {
                                    panic!("transaction id collision: txn_id = {}", txn_id)
                                });
                        }
                        TxnMsg::End(txn_id, persistence_notifier) => {
                            if let Some(tx) = persistence_notifier {
                                pending_persistence_notifiers.push(tx);
                            }
                            let mut txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("received an unexpected transaction end message: the active transaction map doesn't contain txn_id = {}", txn_id));

                            let txn_buffer_cardinality = txn_buffer
                                .vec
                                .iter()
                                .map(|c| c.cardinality())
                                .sum::<usize>();
                            let batch_group_cardinality =
                                batch_group.iter().map(|c| c.cardinality()).sum::<usize>();

                            if txn_buffer_cardinality >= self.chunk_size {
                                // Large transaction: flush the pending batch group first
                                // to preserve ordering, then yield its chunks directly.
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                for chunk in txn_buffer.vec {
                                    yield Message::Chunk(chunk);
                                }
                            } else if txn_buffer_cardinality + batch_group_cardinality
                                <= self.chunk_size
                            {
                                // Small transaction that still fits: batch it with other
                                // small transactions in the batch group.
                                batch_group.extend(txn_buffer.vec);
                            } else {
                                // The batch group would overflow: flush it first, then
                                // let this transaction start a new batch group.
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                mem::swap(&mut txn_buffer.vec, &mut batch_group);
                            }
                        }
                        TxnMsg::Rollback(txn_id) => {
                            let txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("received an unexpected transaction rollback message: the active transaction map doesn't contain txn_id = {}", txn_id));
                            if txn_buffer.overflow {
                                tracing::warn!(
                                    "txn_id={} rolled back, but part of its data has already been sent downstream and cannot be retracted",
                                    txn_id
                                );
                            }
                        }
                        TxnMsg::Data(txn_id, chunk) => {
                            match active_txn_map.get_mut(&txn_id) {
                                Some(txn_buffer) => {
                                    // The transaction has already overflowed its buffer:
                                    // forward the chunk downstream without buffering.
                                    if txn_buffer.overflow {
                                        yield Message::Chunk(chunk);
                                        continue;
                                    }
                                    txn_buffer.vec.push(chunk);
                                    if txn_buffer.vec.len() > MAX_CHUNK_FOR_ATOMICITY {
                                        // Too many chunks to buffer: give up atomicity
                                        // and flush them downstream.
                                        tracing::warn!(
                                            "txn_id={} has too many chunks to buffer for atomicity; sending them downstream anyway",
                                            txn_id
                                        );
                                        for chunk in txn_buffer.vec.drain(..) {
                                            yield Message::Chunk(chunk);
                                        }
                                        txn_buffer.overflow = true;
                                    }
                                }
                                None => panic!(
                                    "received an unexpected transaction data message: the active transaction map doesn't contain txn_id = {}",
                                    txn_id
                                ),
                            };
                        }
                    }
                }
            }
        }
    }
}

impl<S: StateStore> Execute for DmlExecutor<S> {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.execute_inner().boxed()
    }
}

/// Polls the next input message from `stream` while concurrently driving the pending
/// `persistence_futures`, so that epoch-commit waits make progress even when no input
/// is arriving. Returns `None` when the input stream is exhausted.
async fn next_input_driving_persistence(
    stream: &mut StreamReaderWithPause<false, TxnMsg>,
    persistence_futures: &mut FuturesOrdered<impl Future<Output = StreamExecutorResult<()>>>,
) -> StreamExecutorResult<Option<Either<Message, TxnMsg>>> {
    loop {
        if persistence_futures.is_empty() {
            return stream.next().await.transpose();
        }

        match select(persistence_futures.next(), stream.next()).await {
            FutureEither::Left((Some(Ok(())), _)) => continue,
            FutureEither::Left((Some(Err(err)), _)) => return Err(err),
            FutureEither::Left((None, _)) => {
                unreachable!("persistence_futures is known to be non-empty")
            }
            FutureEither::Right((stream_item, _)) => return stream_item.transpose(),
        }
    }
}

type BoxTxnMessageStream = BoxStream<'static, risingwave_dml::error::Result<TxnMsg>>;

/// Applies the DML rate limit to the data chunks of a transaction message stream.
/// Control messages (`Begin` / `End` / `Rollback`) are forwarded as-is.
#[try_stream(ok = TxnMsg, error = risingwave_dml::error::DmlError)]
async fn apply_dml_rate_limit(
    stream: BoxTxnMessageStream,
    rate_limiter: Arc<MonitoredRateLimiter>,
) {
    #[for_await]
    for txn_msg in stream {
        match txn_msg? {
            TxnMsg::Begin(txn_id) => {
                yield TxnMsg::Begin(txn_id);
            }
            TxnMsg::End(txn_id, persistence_notifier) => {
                yield TxnMsg::End(txn_id, persistence_notifier);
            }
            TxnMsg::Rollback(txn_id) => {
                yield TxnMsg::Rollback(txn_id);
            }
            TxnMsg::Data(txn_id, chunk) => {
                let chunk_size = chunk.capacity();
                if chunk_size == 0 {
                    // Empty chunks are forwarded without consuming any permits.
                    yield TxnMsg::Data(txn_id, chunk);
                    continue;
                }
                // If rate limiting is paused, block until a concrete limit is set.
                let rate_limit = loop {
                    match rate_limiter.rate_limit() {
                        RateLimit::Pause => rate_limiter.wait(0).await,
                        limit => break limit,
                    }
                };

                match rate_limit {
                    RateLimit::Pause => unreachable!(),
                    RateLimit::Disabled => {
                        yield TxnMsg::Data(txn_id, chunk);
                        continue;
                    }
                    RateLimit::Fixed(limit) => {
                        let max_permits = limit.get();
                        let required_permits = chunk.rate_limit_permits();
                        if required_permits <= max_permits {
                            rate_limiter.wait(required_permits).await;
                            yield TxnMsg::Data(txn_id, chunk);
                        } else {
                            // The chunk requires more permits than the limit allows at
                            // once: split it and emit the pieces one by one.
                            for small_chunk in chunk.split(max_permits as _) {
                                rate_limiter.wait_chunk(&small_chunk).await;
                                yield TxnMsg::Data(txn_id, small_chunk);
                            }
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use futures::FutureExt;
    use risingwave_common::catalog::{ColumnId, Field, INITIAL_TABLE_VERSION_ID};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_hummock_sdk::key::TableKeyRange;
    use risingwave_storage::error::StorageResult;
    use risingwave_storage::memory::MemoryStateStore;
    use risingwave_storage::panic_store::{PanicStateStore, PanicStateStoreIter};
    use risingwave_storage::store::*;

    use super::*;
    use crate::executor::test_utils::MockSource;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    type WaitEpochCallSender = oneshot::Sender<(HummockReadEpoch, TryWaitEpochOptions)>;

    /// A state store whose `try_wait_epoch` reports its arguments through
    /// `wait_epoch_called_tx` and then blocks until `wait_epoch_release_rx` fires, so
    /// tests can observe and control when epoch persistence "completes". All other
    /// operations panic.
    #[derive(Clone)]
    struct MockWaitEpochStateStore {
        wait_epoch_called_tx: Arc<Mutex<Option<WaitEpochCallSender>>>,
        wait_epoch_release_rx: Arc<tokio::sync::Mutex<Option<oneshot::Receiver<()>>>>,
    }

    impl StateStoreReadLog for MockWaitEpochStateStore {
        type ChangeLogIter = PanicStateStoreIter<StateStoreReadLogItem>;

        async fn next_epoch(&self, _epoch: u64, _options: NextEpochOptions) -> StorageResult<u64> {
            panic!("should not read changelog from MockWaitEpochStateStore")
        }

        async fn iter_log(
            &self,
            _epoch_range: (u64, u64),
            _key_range: TableKeyRange,
            _options: ReadLogOptions,
        ) -> StorageResult<Self::ChangeLogIter> {
            panic!("should not read changelog from MockWaitEpochStateStore")
        }
    }

    impl StateStore for MockWaitEpochStateStore {
        type Local = PanicStateStore;
        type ReadSnapshot = PanicStateStore;
        type VectorWriter = PanicStateStore;

        async fn try_wait_epoch(
            &self,
            epoch: HummockReadEpoch,
            options: TryWaitEpochOptions,
        ) -> StorageResult<()> {
            if let Some(tx) = self.wait_epoch_called_tx.lock().unwrap().take() {
                assert!(tx.send((epoch, options)).is_ok());
            }
            let rx = self.wait_epoch_release_rx.lock().await.take().unwrap();
            rx.await.unwrap();
            Ok(())
        }

        async fn new_local(&self, _option: NewLocalOptions) -> Self::Local {
            panic!("should not create local state from MockWaitEpochStateStore")
        }

        async fn new_read_snapshot(
            &self,
            _epoch: HummockReadEpoch,
            _options: NewReadSnapshotOptions,
        ) -> StorageResult<Self::ReadSnapshot> {
            panic!("should not read snapshot from MockWaitEpochStateStore")
        }

        async fn new_vector_writer(&self, _options: NewVectorWriterOptions) -> Self::VectorWriter {
            panic!("should not create vector writer from MockWaitEpochStateStore")
        }
    }

    #[tokio::test]
    async fn test_dml_executor() {
        let table_id = TableId::default();
        let schema = Schema::new(vec![
            Field::unnamed(DataType::Int64),
            Field::unnamed(DataType::Int64),
        ]);
        let column_descs = vec![
            ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64),
            ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64),
        ];
        let stream_key = vec![0];
        let dml_manager = Arc::new(DmlManager::for_test());

        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema, stream_key);

        let dml_executor = DmlExecutor::new(
            ActorContext::for_test(0),
            source,
            dml_manager.clone(),
            table_id,
            INITIAL_TABLE_VERSION_ID,
            column_descs,
            1024,
            RateLimit::Disabled,
            MemoryStateStore::new(),
        );
        let mut dml_executor = dml_executor.boxed().execute();

        let stream_chunk1 = StreamChunk::from_pretty(
            " I I
            + 1 1
            + 2 2
            + 3 6",
        );
        let stream_chunk2 = StreamChunk::from_pretty(
            " I I
            + 88 43",
        );
        let stream_chunk3 = StreamChunk::from_pretty(
            " I I
            + 199 40
            + 978 72
            + 134 41
            + 398 98",
        );
        let batch_chunk = StreamChunk::from_pretty(
            " I I
            U+ 1 11
            U+ 2 22",
        );

        // The first barrier.
        tx.push_barrier(test_epoch(1), false);
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));

        // Messages from the upstream streaming executor.
        tx.push_chunk(stream_chunk1);
        tx.push_chunk(stream_chunk2);
        tx.push_chunk(stream_chunk3);

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, INITIAL_TABLE_VERSION_ID)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();

        // A small transaction from the batch (DML) side.
        write_handle.begin().unwrap();
        write_handle.write_chunk(batch_chunk).await.unwrap();
        // End the transaction and push a barrier from another task: the batched DML
        // data is only flushed downstream when a barrier arrives.
        tokio::spawn(async move {
            write_handle.end().await.unwrap();
            tx.push_barrier(test_epoch(2), false);
        });

        // The upstream chunks are forwarded as-is, in order.
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 1 1
                + 2 2
                + 3 6",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 88 43",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 199 40
                + 978 72
                + 134 41
                + 398 98",
            )
        );

        // The batched DML chunk is flushed right before the barrier is yielded.
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                U+ 1 11
                U+ 2 22",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));
    }
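
    // Not part of the original suite: a minimal sketch of exercising
    // `apply_dml_rate_limit` in isolation, assuming `TxnMsg` values can be constructed
    // and matched as in the module above. With the limiter disabled, every message
    // should pass through unchanged and no chunk should be split.
    #[tokio::test]
    async fn test_apply_dml_rate_limit_disabled_passthrough() {
        let rate_limiter =
            Arc::new(RateLimiter::new(RateLimit::Disabled).monitored(TableId::default()));
        let chunk = StreamChunk::from_pretty(
            " I
            + 1",
        );
        // A tiny transaction: begin, one data chunk, end (without persistence notifier).
        let input: BoxTxnMessageStream = futures::stream::iter(vec![
            Ok(TxnMsg::Begin(TEST_TRANSACTION_ID)),
            Ok(TxnMsg::Data(TEST_TRANSACTION_ID, chunk.clone())),
            Ok(TxnMsg::End(TEST_TRANSACTION_ID, None)),
        ])
        .boxed();

        let output: Vec<TxnMsg> = apply_dml_rate_limit(input, rate_limiter)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(output.len(), 3);
        assert!(matches!(&output[0], TxnMsg::Begin(_)));
        assert!(matches!(&output[1], TxnMsg::Data(_, c) if *c == chunk));
        assert!(matches!(&output[2], TxnMsg::End(..)));
    }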

    #[tokio::test]
    async fn test_dml_executor_waits_for_barrier_prev_epoch_persistence() {
        let table_id = TableId::new(233);
        let schema = Schema::new(vec![Field::unnamed(DataType::Int64)]);
        let column_descs = vec![ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64)];
        let stream_key = vec![0];
        let dml_manager = Arc::new(DmlManager::for_test());
        let (wait_epoch_called_tx, wait_epoch_called_rx) = oneshot::channel();
        let (wait_epoch_release_tx, wait_epoch_release_rx) = oneshot::channel();

        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema, stream_key);
        let dml_executor = DmlExecutor::new(
            ActorContext::for_test(0),
            source,
            dml_manager.clone(),
            table_id,
            INITIAL_TABLE_VERSION_ID,
            column_descs,
            1024,
            RateLimit::Disabled,
            MockWaitEpochStateStore {
                wait_epoch_called_tx: Arc::new(Mutex::new(Some(wait_epoch_called_tx))),
                wait_epoch_release_rx: Arc::new(tokio::sync::Mutex::new(Some(
                    wait_epoch_release_rx,
                ))),
            },
        );
        let mut dml_executor = dml_executor.boxed().execute();

        tx.push_barrier_with_prev_epoch_for_test(test_epoch(10), test_epoch(9), false);
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, INITIAL_TABLE_VERSION_ID)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle
            .write_chunk(StreamChunk::from_pretty(
                " I
                + 7",
            ))
            .await
            .unwrap();

        let mut persistence_future = Box::pin(write_handle.end_wait_persistence().unwrap());

        // Nothing should be emitted yet: the small transaction stays in the batch
        // group until the next barrier arrives.
        assert!(dml_executor.next().now_or_never().is_none());

        let drain_handle = tokio::spawn(async move {
            while let Some(msg) = dml_executor.next().await {
                let _ = msg.unwrap();
            }
        });

        tx.push_barrier_with_prev_epoch_for_test(test_epoch(11), test_epoch(10), false);

        // The executor should wait for epoch 10 (the epoch containing the DML data) to
        // be committed before notifying the session of persistence.
        let (wait_epoch, options) = wait_epoch_called_rx.await.unwrap();
        assert!(matches!(
            wait_epoch,
            HummockReadEpoch::Committed(epoch) if epoch == test_epoch(10)
        ));
        assert_eq!(options.table_id, table_id);
        assert!(persistence_future.as_mut().now_or_never().is_none());

        // Releasing the epoch wait resolves the persistence notification.
        wait_epoch_release_tx.send(()).unwrap();
        persistence_future.await.unwrap();

        drain_handle.abort();
    }
}