// risingwave_stream/executor/dml.rs

use std::collections::BTreeMap;
use std::mem;

use either::Either;
use futures::TryStreamExt;
use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::transaction::transaction_message::TxnMsg;
use risingwave_common_rate_limit::{MonitoredRateLimiter, RateLimit, RateLimiter};
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::codegen::BoxStream;
use risingwave_pb::common::ThrottleType;

use crate::executor::prelude::*;
use crate::executor::stream_reader::StreamReaderWithPause;

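/// `DmlExecutor` merges the upstream streaming messages (barriers and data chunks)
/// with DML data received for the table through the DML manager. Transactional DML
/// chunks are buffered so that a transaction's data is emitted together, and small
/// transactions are batched and compacted into larger chunks before being sent
/// downstream.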
pub struct DmlExecutor {
    actor_ctx: ActorContextRef,

    upstream: Executor,

    dml_manager: DmlManagerRef,

    table_id: TableId,

    table_version_id: TableVersionId,

    column_descs: Vec<ColumnDesc>,

    chunk_size: usize,

    rate_limiter: Arc<MonitoredRateLimiter>,
}

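/// Upper bound on the number of chunks buffered per transaction to preserve atomicity.
/// When a transaction exceeds this bound, its buffered chunks are flushed downstream
/// and the transaction is marked as overflowed, giving up atomicity for that
/// transaction (see `TxnBuffer::overflow`).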
const MAX_CHUNK_FOR_ATOMICITY: usize = 32;

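/// Buffered chunks of an in-flight DML transaction. `overflow` is set once the buffer
/// has exceeded `MAX_CHUNK_FOR_ATOMICITY` and its contents have already been forwarded
/// downstream.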
#[derive(Debug, Default)]
struct TxnBuffer {
    vec: Vec<StreamChunk>,
    overflow: bool,
}

impl DmlExecutor {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        actor_ctx: ActorContextRef,
        upstream: Executor,
        dml_manager: DmlManagerRef,
        table_id: TableId,
        table_version_id: TableVersionId,
        column_descs: Vec<ColumnDesc>,
        chunk_size: usize,
        rate_limit: RateLimit,
    ) -> Self {
        let rate_limiter = Arc::new(RateLimiter::new(rate_limit).monitored(table_id));
        Self {
            actor_ctx,
            upstream,
            dml_manager,
            table_id,
            table_version_id,
            column_descs,
            chunk_size,
            rate_limiter,
        }
    }

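    /// Core execution loop: registers a DML reader for the table, merges it with the
    /// upstream message stream, and forwards data downstream. Barrier mutations control
    /// pausing, resuming and throttling of the DML side, and each transaction's chunks
    /// are buffered so they can be emitted together.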
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(self: Box<Self>) {
        let mut upstream = self.upstream.execute();

        let actor_id = self.actor_ctx.id;

        let barrier = expect_first_barrier(&mut upstream).await?;

        let handle = self.dml_manager.register_reader(
            self.table_id,
            self.table_version_id,
            &self.column_descs,
        )?;
        let reader = apply_dml_rate_limit(
            handle.stream_reader().into_stream(),
            self.rate_limiter.clone(),
        )
        .boxed()
        .map_err(StreamExecutorError::from);

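        // Merge the upstream message stream with the (rate-limited) DML reader.
        // `pause_stream`/`resume_stream` below only pause the DML side, so barriers
        // from upstream keep flowing even while DML ingestion is paused.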
        let mut stream = StreamReaderWithPause::<false, TxnMsg>::new(upstream, reader);

        if barrier.is_pause_on_startup() {
            stream.pause_stream();
        }

        let mut epoch = barrier.get_curr_epoch();

        yield Message::Barrier(barrier);

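        // Chunks of in-flight transactions, keyed by transaction id. `batch_group`
        // collects the data of small, already-ended transactions so that several of
        // them can be compacted into one output chunk.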
        let mut active_txn_map: BTreeMap<TxnId, TxnBuffer> = Default::default();
        let mut batch_group: Vec<StreamChunk> = vec![];

        let mut builder = StreamChunkBuilder::new(
            self.chunk_size,
            self.column_descs
                .iter()
                .map(|c| c.data_type.clone())
                .collect(),
        );

        while let Some(input_msg) = stream.next().await {
            match input_msg? {
                Either::Left(msg) => {
                    if let Message::Barrier(barrier) = &msg {
                        epoch = barrier.get_curr_epoch();
                        if let Some(mutation) = barrier.mutation.as_deref() {
                            match mutation {
                                Mutation::Pause => stream.pause_stream(),
                                Mutation::Resume => stream.resume_stream(),
                                Mutation::Throttle(fragment_to_apply) => {
                                    if let Some(entry) =
                                        fragment_to_apply.get(&self.actor_ctx.fragment_id)
                                        && entry.throttle_type() == ThrottleType::Dml
                                    {
                                        let new_rate_limit = entry.rate_limit.into();
                                        let old_rate_limit =
                                            self.rate_limiter.update(new_rate_limit);

                                        if old_rate_limit != new_rate_limit {
                                            tracing::info!(
                                                old_rate_limit = ?old_rate_limit,
                                                new_rate_limit = ?new_rate_limit,
                                                %actor_id,
                                                "dml rate limit changed",
                                            );
                                        }
                                    }
                                }
                                _ => {}
                            }
                        }

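                        // Flush any batched DML data before forwarding the barrier, so
                        // that the data is committed within the current epoch.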
                        if !batch_group.is_empty() {
                            let vec = mem::take(&mut batch_group);
                            for chunk in vec {
                                for (op, row) in chunk.rows() {
                                    if let Some(chunk) = builder.append_row(op, row) {
                                        yield Message::Chunk(chunk);
                                    }
                                }
                            }
                            if let Some(chunk) = builder.take() {
                                yield Message::Chunk(chunk);
                            }
                        }
                    }
                    yield msg;
                }
                Either::Right(txn_msg) => {
                    match txn_msg {
                        TxnMsg::Begin(txn_id) => {
                            active_txn_map
                                .try_insert(txn_id, TxnBuffer::default())
                                .unwrap_or_else(|_| {
                                    panic!("Transaction id collision txn_id = {}.", txn_id)
                                });
                        }
                        TxnMsg::End(txn_id, epoch_notifier) => {
                            if let Some(sender) = epoch_notifier {
                                let _ = sender.send(epoch);
                            }
                            let mut txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("Receive an unexpected transaction end message. Active transaction map doesn't contain this transaction txn_id = {}.", txn_id));

                            let txn_buffer_cardinality = txn_buffer
                                .vec
                                .iter()
                                .map(|c| c.cardinality())
                                .sum::<usize>();
                            let batch_group_cardinality =
                                batch_group.iter().map(|c| c.cardinality()).sum::<usize>();

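                            // Decide how to emit the finished transaction:
                            // - large transactions (>= chunk_size rows) are flushed directly;
                            // - small ones that still fit are appended to `batch_group`;
                            // - otherwise the current batch group is flushed first and the
                            //   transaction's chunks become the new batch group.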
                            if txn_buffer_cardinality >= self.chunk_size {
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                for chunk in txn_buffer.vec {
                                    yield Message::Chunk(chunk);
                                }
                            } else if txn_buffer_cardinality + batch_group_cardinality
                                <= self.chunk_size
                            {
                                batch_group.extend(txn_buffer.vec);
                            } else {
                                if !batch_group.is_empty() {
                                    let vec = mem::take(&mut batch_group);
                                    for chunk in vec {
                                        for (op, row) in chunk.rows() {
                                            if let Some(chunk) = builder.append_row(op, row) {
                                                yield Message::Chunk(chunk);
                                            }
                                        }
                                    }
                                    if let Some(chunk) = builder.take() {
                                        yield Message::Chunk(chunk);
                                    }
                                }

                                mem::swap(&mut txn_buffer.vec, &mut batch_group);
                            }
                        }
                        TxnMsg::Rollback(txn_id) => {
                            let txn_buffer = active_txn_map.remove(&txn_id)
                                .unwrap_or_else(|| panic!("Receive an unexpected transaction rollback message. Active transaction map doesn't contain this transaction txn_id = {}.", txn_id));
                            if txn_buffer.overflow {
                                tracing::warn!(
                                    "txn_id={} large transaction tries to rollback, but part of its data has already been sent to the downstream.",
                                    txn_id
                                );
                            }
                        }
                        TxnMsg::Data(txn_id, chunk) => {
                            match active_txn_map.get_mut(&txn_id) {
                                Some(txn_buffer) => {
                                    if txn_buffer.overflow {
                                        yield Message::Chunk(chunk);
                                        continue;
                                    }
                                    txn_buffer.vec.push(chunk);
                                    if txn_buffer.vec.len() > MAX_CHUNK_FOR_ATOMICITY {
                                        tracing::warn!(
                                            "txn_id={} Too many chunks for atomicity. Sent them to the downstream anyway.",
                                            txn_id
                                        );
                                        for chunk in txn_buffer.vec.drain(..) {
                                            yield Message::Chunk(chunk);
                                        }
                                        txn_buffer.overflow = true;
                                    }
                                }
                                None => panic!(
                                    "Receive an unexpected transaction data message. Active transaction map doesn't contain this transaction txn_id = {}.",
                                    txn_id
                                ),
                            };
                        }
                    }
                }
            }
        }
    }
}

impl Execute for DmlExecutor {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.execute_inner().boxed()
    }
}

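/// A boxed stream of transaction messages coming from the table's DML channel.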
type BoxTxnMessageStream = BoxStream<'static, risingwave_dml::error::Result<TxnMsg>>;
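
/// Applies the table's DML rate limit to the transaction stream: `Begin`, `End` and
/// `Rollback` messages pass through unchanged, while `Data` chunks wait for permits
/// and are split into smaller chunks when they exceed the fixed limit. While the limit
/// is `RateLimit::Pause`, the stream blocks until it is resumed.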
#[try_stream(ok = TxnMsg, error = risingwave_dml::error::DmlError)]
async fn apply_dml_rate_limit(
    stream: BoxTxnMessageStream,
    rate_limiter: Arc<MonitoredRateLimiter>,
) {
    #[for_await]
    for txn_msg in stream {
        match txn_msg? {
            TxnMsg::Begin(txn_id) => {
                yield TxnMsg::Begin(txn_id);
            }
            TxnMsg::End(txn_id, epoch_notifier) => {
                yield TxnMsg::End(txn_id, epoch_notifier);
            }
            TxnMsg::Rollback(txn_id) => {
                yield TxnMsg::Rollback(txn_id);
            }
            TxnMsg::Data(txn_id, chunk) => {
                let chunk_size = chunk.capacity();
                if chunk_size == 0 {
                    yield TxnMsg::Data(txn_id, chunk);
                    continue;
                }
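                // Block while the limiter is paused; once it is resumed, take a
                // snapshot of the current limit to apply to this chunk.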
                let rate_limit = loop {
                    match rate_limiter.rate_limit() {
                        RateLimit::Pause => rate_limiter.wait(0).await,
                        limit => break limit,
                    }
                };

                match rate_limit {
                    RateLimit::Pause => unreachable!(),
                    RateLimit::Disabled => {
                        yield TxnMsg::Data(txn_id, chunk);
                        continue;
                    }
                    RateLimit::Fixed(limit) => {
                        let max_permits = limit.get();
                        let required_permits = chunk.rate_limit_permits();
                        if required_permits <= max_permits {
                            rate_limiter.wait(required_permits).await;
                            yield TxnMsg::Data(txn_id, chunk);
                        } else {
                            for small_chunk in chunk.split(max_permits as _) {
                                rate_limiter.wait_chunk(&small_chunk).await;
                                yield TxnMsg::Data(txn_id, small_chunk);
                            }
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{ColumnId, Field, INITIAL_TABLE_VERSION_ID};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::dml_manager::DmlManager;

    use super::*;
    use crate::executor::test_utils::MockSource;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

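    /// End-to-end test: stream chunks from the upstream and a batched DML transaction
    /// are both forwarded downstream, with the DML data flushed before the next barrier.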
    #[tokio::test]
    async fn test_dml_executor() {
        let table_id = TableId::default();
        let schema = Schema::new(vec![
            Field::unnamed(DataType::Int64),
            Field::unnamed(DataType::Int64),
        ]);
        let column_descs = vec![
            ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64),
            ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64),
        ];
        let stream_key = vec![0];
        let dml_manager = Arc::new(DmlManager::for_test());

        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema, stream_key);

        let dml_executor = DmlExecutor::new(
            ActorContext::for_test(0),
            source,
            dml_manager.clone(),
            table_id,
            INITIAL_TABLE_VERSION_ID,
            column_descs,
            1024,
            RateLimit::Disabled,
        );
        let mut dml_executor = dml_executor.boxed().execute();

        let stream_chunk1 = StreamChunk::from_pretty(
            " I I
            + 1 1
            + 2 2
            + 3 6",
        );
        let stream_chunk2 = StreamChunk::from_pretty(
            " I I
            + 88 43",
        );
        let stream_chunk3 = StreamChunk::from_pretty(
            " I I
            + 199 40
            + 978 72
            + 134 41
            + 398 98",
        );
        let batch_chunk = StreamChunk::from_pretty(
            " I I
            U+ 1 11
            U+ 2 22",
        );

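        // Push the initial barrier and expect it to be forwarded as the first message.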
        tx.push_barrier(test_epoch(1), false);
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));

        tx.push_chunk(stream_chunk1);
        tx.push_chunk(stream_chunk2);
        tx.push_chunk(stream_chunk3);

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, INITIAL_TABLE_VERSION_ID)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();

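        // Write the batch chunk within a DML transaction. The transaction end and the
        // next barrier are issued from a separate task so the test can keep polling the
        // executor concurrently.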
        write_handle.begin().unwrap();
        write_handle.write_chunk(batch_chunk).await.unwrap();
        tokio::spawn(async move {
            write_handle.end().await.unwrap();
            tx.push_barrier(test_epoch(2), false);
        });

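        // The three upstream stream chunks are forwarded first, in order.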
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 1 1
                + 2 2
                + 3 6",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 88 43",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                + 199 40
                + 978 72
                + 134 41
                + 398 98",
            )
        );

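        // Then the batched DML chunk, flushed when the second barrier arrives, followed
        // by the barrier itself.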
        let msg = dml_executor.next().await.unwrap().unwrap();
        assert_eq!(
            msg.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I
                U+ 1 11
                U+ 2 22",
            )
        );

        let msg = dml_executor.next().await.unwrap().unwrap();
        assert!(matches!(msg, Message::Barrier(_)));
    }
}