// risingwave_stream/executor/dml.rs
use std::collections::BTreeMap;
use std::mem;

use either::Either;
use futures::TryStreamExt;
use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::transaction::transaction_message::TxnMsg;
use risingwave_common_rate_limit::{MonitoredRateLimiter, RateLimit, RateLimiter};
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::codegen::BoxStream;

use crate::executor::prelude::*;
use crate::executor::stream_reader::StreamReaderWithPause;

/// Executor that merges DML data (received through the DML manager's reader channel)
/// into the upstream stream, batching rows into chunks of at most `chunk_size` rows
/// and applying a (dynamically updatable) rate limit.
pub struct DmlExecutor {
    actor_ctx: ActorContextRef,

    /// The upstream executor whose messages (chunks/barriers) are merged with DML data.
    upstream: Executor,

    /// Shared DML manager; used to register a reader for `(table_id, table_version_id)`.
    dml_manager: DmlManagerRef,

    /// Id of the table on which DML is performed.
    table_id: TableId,

    /// Version of the table; must match when registering the reader.
    table_version_id: TableVersionId,

    /// Column descriptions of the table, used for reader registration and for
    /// building output chunks.
    column_descs: Vec<ColumnDesc>,

    /// Maximum number of rows per output chunk produced by the chunk builder.
    chunk_size: usize,

    /// Rate limiter applied to incoming DML data; updated on `Mutation::Throttle`.
    rate_limiter: Arc<MonitoredRateLimiter>,
}
53
/// Upper bound on the number of chunks buffered for a single transaction to provide
/// atomicity. Once a transaction buffers more than this many chunks, the buffered
/// chunks are flushed downstream and atomicity is given up (see `TxnBuffer::overflow`).
const MAX_CHUNK_FOR_ATOMICITY: usize = 32;

/// Buffer holding the chunks of one in-flight DML transaction.
#[derive(Debug, Default)]
struct TxnBuffer {
    // Chunks received since the transaction began (and not yet flushed).
    vec: Vec<StreamChunk>,
    // True once the buffer exceeded `MAX_CHUNK_FOR_ATOMICITY` and its contents were
    // already sent downstream, i.e. atomicity can no longer be guaranteed.
    overflow: bool,
}
69
impl DmlExecutor {
    /// Creates a new `DmlExecutor`. The rate limiter is constructed here from the
    /// initial `rate_limit` and monitored under `table_id`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        actor_ctx: ActorContextRef,
        upstream: Executor,
        dml_manager: DmlManagerRef,
        table_id: TableId,
        table_version_id: TableVersionId,
        column_descs: Vec<ColumnDesc>,
        chunk_size: usize,
        rate_limit: RateLimit,
    ) -> Self {
        let rate_limiter = Arc::new(RateLimiter::new(rate_limit).monitored(table_id));
        Self {
            actor_ctx,
            upstream,
            dml_manager,
            table_id,
            table_version_id,
            column_descs,
            chunk_size,
            rate_limiter,
        }
    }

    /// Main loop: merges the upstream message stream with the (rate-limited) DML
    /// transaction stream, buffering transaction chunks for atomicity and batching
    /// small transactions together before emitting.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(self: Box<Self>) {
        let mut upstream = self.upstream.execute();

        let actor_id = self.actor_ctx.id;

        // The first barrier is consumed before the DML reader is registered, so no DML
        // data can be received before the executor is properly initialized.
        let barrier = expect_first_barrier(&mut upstream).await?;

        // Register this executor as a reader of DML data for the given table version.
        let handle = self.dml_manager.register_reader(
            self.table_id,
            self.table_version_id,
            &self.column_descs,
        )?;
        let reader = apply_dml_rate_limit(
            handle.stream_reader().into_stream(),
            self.rate_limiter.clone(),
        )
        .boxed()
        .map_err(StreamExecutorError::from);

        // Merge the upstream (barriers/chunks) with the DML txn message stream; the
        // DML side can be paused/resumed on barrier mutations.
        let mut stream = StreamReaderWithPause::<false, TxnMsg>::new(upstream, reader);

        if barrier.is_pause_on_startup() {
            stream.pause_stream();
        }

        // Track the current epoch so `TxnMsg::End` notifiers can be answered with it.
        let mut epoch = barrier.get_curr_epoch();

        yield Message::Barrier(barrier);

        // Buffers of all in-flight transactions, keyed by transaction id.
        let mut active_txn_map: BTreeMap<TxnId, TxnBuffer> = Default::default();
        // Chunks of small, already-ended transactions accumulated to form fuller chunks.
        let mut batch_group: Vec<StreamChunk> = vec![];

        let mut builder = StreamChunkBuilder::new(
            self.chunk_size,
            self.column_descs
                .iter()
                .map(|c| c.data_type.clone())
                .collect(),
        );

        while let Some(input_msg) = stream.next().await {
            match input_msg? {
                Either::Left(msg) => {
                    if let Message::Barrier(barrier) = &msg {
                        epoch = barrier.get_curr_epoch();
                        if let Some(mutation) = barrier.mutation.as_deref() {
                            match mutation {
                                Mutation::Pause => stream.pause_stream(),
                                Mutation::Resume => stream.resume_stream(),
                                Mutation::Throttle(actor_to_apply) => {
                                    if let Some(new_rate_limit) =
                                        actor_to_apply.get(&self.actor_ctx.id)
                                    {
                                        let new_rate_limit = (*new_rate_limit).into();
                                        let old_rate_limit =
                                            self.rate_limiter.update(new_rate_limit);

                                        if old_rate_limit != new_rate_limit {
                                            tracing::info!(
                                                old_rate_limit = ?old_rate_limit,
                                                new_rate_limit = ?new_rate_limit,
                                                actor_id,
                                                "dml rate limit changed",
                                            );
                                        }
                                    }
                                }
                                _ => {}
                            }
                        }

                        // Flush the batch group *before* yielding the barrier, so the
                        // batched chunks belong to the epoch that is about to close.
                        if !batch_group.is_empty() {
                            let vec = mem::take(&mut batch_group);
                            for chunk in vec {
                                for (op, row) in chunk.rows() {
                                    if let Some(chunk) = builder.append_row(op, row) {
                                        yield Message::Chunk(chunk);
                                    }
                                }
                            }
                            if let Some(chunk) = builder.take() {
                                yield Message::Chunk(chunk);
                            }
                        }
                    }
                    yield msg;
                }
                Either::Right(txn_msg) => {
                    match txn_msg {
                        TxnMsg::Begin(txn_id) => {
                            // A duplicate txn id indicates a protocol violation — panic.
                            active_txn_map
                                .try_insert(txn_id, TxnBuffer::default())
                                .unwrap_or_else(|_| {
                                    panic!("Transaction id collision txn_id = {}.", txn_id)
                                });
                        }
                        TxnMsg::End(txn_id, epoch_notifier) => {
                            // Notify the writer of the epoch this transaction lands in.
                            if let Some(sender) = epoch_notifier {
                                let _ = sender.send(epoch);
                            }
                            let mut txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("Receive an unexpected transaction end message. Active transaction map doesn't contain this transaction txn_id = {}.", txn_id));

                            let txn_buffer_cardinality = txn_buffer
                                .vec
                                .iter()
                                .map(|c| c.cardinality())
                                .sum::<usize>();
                            let batch_group_cardinality =
                                batch_group.iter().map(|c| c.cardinality()).sum::<usize>();

                            if txn_buffer_cardinality >= self.chunk_size {
                                // The transaction alone fills a chunk: flush any pending
                                // batch group first (to preserve order), then emit the
                                // transaction's chunks directly.
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                for chunk in txn_buffer.vec {
                                    yield Message::Chunk(chunk);
                                }
                            } else if txn_buffer_cardinality + batch_group_cardinality
                                <= self.chunk_size
                            {
                                // Small transaction that still fits: keep accumulating.
                                batch_group.extend(txn_buffer.vec);
                            } else {
                                // Would overflow the batch group: flush the current group,
                                // then start a new group from this transaction's chunks.
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                mem::swap(&mut txn_buffer.vec, &mut batch_group);
                            }
                        }
                        TxnMsg::Rollback(txn_id) => {
                            let txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("Receive an unexpected transaction rollback message. Active transaction map doesn't contain this transaction txn_id = {}.", txn_id));
                            // Buffered (unsent) chunks are simply dropped; if the buffer
                            // overflowed, part of the data was already emitted and the
                            // rollback cannot be fully honored — warn about it.
                            if txn_buffer.overflow {
                                tracing::warn!(
                                    "txn_id={} large transaction tries to rollback, but part of its data has already been sent to the downstream.",
                                    txn_id
                                );
                            }
                        }
                        TxnMsg::Data(txn_id, chunk) => {
                            match active_txn_map.get_mut(&txn_id) {
                                Some(txn_buffer) => {
                                    // Once overflowed, forward data directly downstream.
                                    if txn_buffer.overflow {
                                        yield Message::Chunk(chunk);
                                        continue;
                                    }
                                    txn_buffer.vec.push(chunk);
                                    if txn_buffer.vec.len() > MAX_CHUNK_FOR_ATOMICITY {
                                        // Give up atomicity: drain the buffer downstream
                                        // and mark the transaction as overflowed.
                                        tracing::warn!(
                                            "txn_id={} Too many chunks for atomicity. Sent them to the downstream anyway.",
                                            txn_id
                                        );
                                        for chunk in txn_buffer.vec.drain(..) {
                                            yield Message::Chunk(chunk);
                                        }
                                        txn_buffer.overflow = true;
                                    }
                                }
                                None => panic!(
                                    "Receive an unexpected transaction data message. Active transaction map doesn't contain this transaction txn_id = {}.",
                                    txn_id
                                ),
                            };
                        }
                    }
                }
            }
        }
    }
}
315
316impl Execute for DmlExecutor {
317 fn execute(self: Box<Self>) -> BoxedMessageStream {
318 self.execute_inner().boxed()
319 }
320}
321
/// Boxed stream of transaction messages produced by the DML table reader.
type BoxTxnMessageStream = BoxStream<'static, risingwave_dml::error::Result<TxnMsg>>;

/// Wraps a transaction message stream with rate limiting: `Begin`/`End`/`Rollback`
/// pass through untouched, while `Data` chunks wait for permits (and are split into
/// smaller chunks when they exceed the per-wait permit limit).
#[try_stream(ok = TxnMsg, error = risingwave_dml::error::DmlError)]
async fn apply_dml_rate_limit(
    stream: BoxTxnMessageStream,
    rate_limiter: Arc<MonitoredRateLimiter>,
) {
    #[for_await]
    for txn_msg in stream {
        match txn_msg? {
            TxnMsg::Begin(txn_id) => {
                yield TxnMsg::Begin(txn_id);
            }
            TxnMsg::End(txn_id, epoch_notifier) => {
                yield TxnMsg::End(txn_id, epoch_notifier);
            }
            TxnMsg::Rollback(txn_id) => {
                yield TxnMsg::Rollback(txn_id);
            }
            TxnMsg::Data(txn_id, chunk) => {
                let chunk_size = chunk.capacity();
                // Empty chunks consume no permits — forward immediately.
                if chunk_size == 0 {
                    yield TxnMsg::Data(txn_id, chunk);
                    continue;
                }
                // While paused, block until the limiter transitions to a usable state,
                // then re-read the (possibly updated) limit.
                let rate_limit = loop {
                    match rate_limiter.rate_limit() {
                        RateLimit::Pause => rate_limiter.wait(0).await,
                        limit => break limit,
                    }
                };

                match rate_limit {
                    // The loop above only breaks on non-`Pause` values.
                    RateLimit::Pause => unreachable!(),
                    RateLimit::Disabled => {
                        yield TxnMsg::Data(txn_id, chunk);
                        continue;
                    }
                    RateLimit::Fixed(limit) => {
                        let max_permits = limit.get();
                        let required_permits = chunk.compute_rate_limit_chunk_permits();
                        if required_permits <= max_permits {
                            rate_limiter.wait(required_permits).await;
                            yield TxnMsg::Data(txn_id, chunk);
                        } else {
                            // Chunk needs more permits than can be acquired at once:
                            // split it and rate-limit each piece separately.
                            for small_chunk in chunk.split(max_permits as _) {
                                let required_permits =
                                    small_chunk.compute_rate_limit_chunk_permits();
                                rate_limiter.wait(required_permits).await;
                                yield TxnMsg::Data(txn_id, small_chunk);
                            }
                        }
                    }
                }
            }
        }
    }
}
381
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{ColumnId, Field, INITIAL_TABLE_VERSION_ID};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::dml_manager::DmlManager;

    use super::*;
    use crate::executor::test_utils::MockSource;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// End-to-end test: upstream chunks pass through unchanged, and a DML transaction
    /// written via the DML manager is emitted before the barrier that closes its epoch.
    #[tokio::test]
    async fn test_dml_executor() {
        let table_id = TableId::default();
        let schema = Schema::new(vec![
            Field::unnamed(DataType::Int64),
            Field::unnamed(DataType::Int64),
        ]);
        let column_descs = vec![
            ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64),
            ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64),
        ];
        let pk_indices = vec![0];
        let dml_manager = Arc::new(DmlManager::for_test());

        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema, pk_indices);

        let dml_executor = DmlExecutor::new(
            ActorContext::for_test(0),
            source,
            dml_manager.clone(),
            table_id,
            INITIAL_TABLE_VERSION_ID,
            column_descs,
            1024,
            RateLimit::Disabled,
        );
        let mut dml_executor = dml_executor.boxed().execute();

        // Upstream chunks that should pass through unchanged.
        let stream_chunk1 = StreamChunk::from_pretty(
            " I I
            + 1 1
            + 2 2
            + 3 6",
        );
        let stream_chunk2 = StreamChunk::from_pretty(
            " I I
            + 88 43",
        );
        let stream_chunk3 = StreamChunk::from_pretty(
            " I I
            + 199 40
            + 978 72
            + 134 41
            + 398 98",
        );
        // Chunk written through the DML channel.
        let batch_chunk = StreamChunk::from_pretty(
            " I I
            U+ 1 11
            U+ 2 22",
        );

        // The executor needs the first barrier before it registers the DML reader.
        tx.push_barrier(test_epoch(1), false);
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));

        tx.push_chunk(stream_chunk1);
        tx.push_chunk(stream_chunk2);
        tx.push_chunk(stream_chunk3);

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, INITIAL_TABLE_VERSION_ID)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();

        write_handle.begin().unwrap();
        write_handle.write_chunk(batch_chunk).await.unwrap();
        // `end()` blocks until the epoch notifier is answered, which happens while the
        // executor processes the barrier pushed right after — hence the spawned task.
        tokio::spawn(async move {
            write_handle.end().await.unwrap();
            tx.push_barrier(test_epoch(2), false);
        });

        // Upstream chunks arrive first, in order.
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 1 1
                + 2 2
                + 3 6",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 88 43",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 199 40
                + 978 72
                + 134 41
                + 398 98",
            )
        );

        // The DML transaction's chunk is flushed before the closing barrier.
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                U+ 1 11
                U+ 2 22",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));
    }
}