1use std::future::Future;
16use std::sync::Arc;
17
18use futures_async_stream::try_stream;
19use parking_lot::RwLock;
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::ColumnDesc;
22use risingwave_common::transaction::transaction_id::TxnId;
23use risingwave_common::transaction::transaction_message::TxnMsg;
24use tokio::sync::oneshot;
25
26use crate::error::{DmlError, Result};
27use crate::txn_channel::{Receiver, Sender, txn_channel};
28
29pub type TableDmlHandleRef = Arc<TableDmlHandle>;
30
31#[derive(Debug)]
32pub struct TableDmlHandleCore {
33 pub changes_txs: Vec<Sender>,
38}
39
40#[derive(Debug)]
47pub struct TableDmlHandle {
48 pub core: RwLock<TableDmlHandleCore>,
49
50 pub column_descs: Vec<ColumnDesc>,
52
53 dml_channel_initial_permits: usize,
55}
56
57impl TableDmlHandle {
58 pub fn new(column_descs: Vec<ColumnDesc>, dml_channel_initial_permits: usize) -> Self {
59 let core = TableDmlHandleCore {
60 changes_txs: vec![],
61 };
62
63 Self {
64 core: RwLock::new(core),
65 column_descs,
66 dml_channel_initial_permits,
67 }
68 }
69
70 pub fn stream_reader(&self) -> TableStreamReader {
71 let mut core = self.core.write();
72 let (tx, rx) = txn_channel(self.dml_channel_initial_permits);
75 core.changes_txs.push(tx);
76
77 TableStreamReader { rx }
78 }
79
80 pub fn write_handle(&self, session_id: u32, txn_id: TxnId) -> Result<WriteHandle> {
81 loop {
89 let guard = self.core.read();
90 if guard.changes_txs.is_empty() {
91 return Err(DmlError::NoReader);
92 }
93 let len = guard.changes_txs.len();
94 let sender = guard
97 .changes_txs
98 .get((session_id % len as u32) as usize)
99 .unwrap()
100 .clone();
101
102 drop(guard);
103
104 if sender.is_closed() {
105 self.core
107 .write()
108 .changes_txs
109 .retain(|sender| !sender.is_closed());
110 } else {
111 return Ok(WriteHandle::new(txn_id, sender));
112 }
113 }
114 }
115
116 pub fn column_descs(&self) -> &[ColumnDesc] {
118 self.column_descs.as_ref()
119 }
120
121 pub fn check_chunk_schema(&self, chunk: &StreamChunk) {
122 risingwave_common::util::schema_check::schema_check(
123 self.column_descs
124 .iter()
125 .filter_map(|c| (!c.is_generated()).then_some(&c.data_type)),
126 chunk.columns(),
127 )
128 .expect("table source write txn_msg schema check failed");
129 }
130}
131
/// Stage of the transaction driven by a [`WriteHandle`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TxnState {
    /// Initial state assigned by `WriteHandle::new`.
    Init,
    /// Set by `WriteHandle::begin`; a handle dropped in this state sends a rollback.
    Begin,
    /// Set by `WriteHandle::end` and `WriteHandle::end_wait_persistence`.
    Committed,
    /// Set when a rollback message is issued.
    Rollback,
}
139
140pub struct WriteHandle {
149 txn_id: TxnId,
150 tx: Sender,
151 txn_state: TxnState,
153}
154
155impl Drop for WriteHandle {
156 fn drop(&mut self) {
157 if self.txn_state == TxnState::Begin {
158 let _ = self.rollback_inner();
159 }
160 }
161}
162
163impl WriteHandle {
164 pub fn new(txn_id: TxnId, tx: Sender) -> Self {
165 Self {
166 txn_id,
167 tx,
168 txn_state: TxnState::Init,
169 }
170 }
171
172 pub fn begin(&mut self) -> Result<()> {
173 assert_eq!(self.txn_state, TxnState::Init);
174 self.txn_state = TxnState::Begin;
175 self.write_txn_control_msg(TxnMsg::Begin(self.txn_id))?;
177 Ok(())
178 }
179
180 pub async fn write_chunk(&self, chunk: StreamChunk) -> Result<()> {
181 assert_eq!(self.txn_state, TxnState::Begin);
182 let _notifier = self
184 .write_txn_data_msg(TxnMsg::Data(self.txn_id, chunk))
185 .await?;
186 Ok(())
187 }
188
189 pub async fn end(mut self) -> Result<()> {
190 assert_eq!(self.txn_state, TxnState::Begin);
191 self.txn_state = TxnState::Committed;
192 let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id, None))?;
194 notifier.await.map_err(|_| DmlError::ReaderClosed)?;
195 Ok(())
196 }
197
198 pub fn end_wait_persistence(mut self) -> Result<impl Future<Output = Result<()>> + 'static> {
201 assert_eq!(self.txn_state, TxnState::Begin);
202 self.txn_state = TxnState::Committed;
203 let (persistence_tx, persistence_rx) = oneshot::channel();
204 let notifier =
205 self.write_txn_control_msg(TxnMsg::End(self.txn_id, Some(persistence_tx)))?;
206 Ok(async move {
207 notifier.await.map_err(|_| DmlError::ReaderClosed)?;
208 persistence_rx.await.map_err(|_| DmlError::ReaderClosed)?;
209 Ok(())
210 })
211 }
212
213 pub fn rollback(mut self) -> Result<oneshot::Receiver<usize>> {
214 self.rollback_inner()
215 }
216
217 fn rollback_inner(&mut self) -> Result<oneshot::Receiver<usize>> {
218 assert_eq!(self.txn_state, TxnState::Begin);
219 self.txn_state = TxnState::Rollback;
220 self.write_txn_control_msg(TxnMsg::Rollback(self.txn_id))
221 }
222
223 async fn write_txn_data_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
229 assert_eq!(self.txn_id, txn_msg.txn_id());
230 let (notifier_tx, notifier_rx) = oneshot::channel();
231 match self.tx.send(txn_msg, notifier_tx).await {
232 Ok(_) => Ok(notifier_rx),
233
234 Err(_) => Err(DmlError::ReaderClosed),
237 }
238 }
239
240 fn write_txn_control_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
243 assert_eq!(self.txn_id, txn_msg.txn_id());
244 let (notifier_tx, notifier_rx) = oneshot::channel();
245 match self.tx.send_immediate(txn_msg, notifier_tx) {
246 Ok(_) => Ok(notifier_rx),
247
248 Err(_) => Err(DmlError::ReaderClosed),
251 }
252 }
253}
254
255#[derive(Debug)]
260pub struct TableStreamReader {
261 rx: Receiver,
263}
264
265impl TableStreamReader {
266 #[try_stream(boxed, ok = StreamChunk, error = DmlError)]
267 pub async fn into_data_stream_for_test(mut self) {
268 while let Some((txn_msg, notifier)) = self.rx.recv().await {
269 match txn_msg {
271 TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
272 _ = notifier.send(0);
273 }
274 TxnMsg::Data(_, chunk) => {
275 _ = notifier.send(chunk.cardinality());
276 yield chunk;
277 }
278 }
279 }
280 }
281
282 #[try_stream(boxed, ok = TxnMsg, error = DmlError)]
283 pub async fn into_stream(mut self) {
284 while let Some((txn_msg, notifier)) = self.rx.recv().await {
285 match &txn_msg {
287 TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
288 _ = notifier.send(0);
289 yield txn_msg;
290 }
291 TxnMsg::Data(_, chunk) => {
292 _ = notifier.send(chunk.cardinality());
293 yield txn_msg;
294 }
295 }
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, I64Array, Op};
    use risingwave_common::catalog::ColumnId;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Builds a handle for a table with a single unnamed `Int64` column.
    fn new_table_dml_handle() -> TableDmlHandle {
        let columns = vec![ColumnDesc::unnamed(ColumnId::from(0), DataType::Int64)];
        TableDmlHandle::new(columns, 32768)
    }

    /// Builds a chunk holding one `Insert` row with the given value.
    fn single_insert_chunk(value: i64) -> StreamChunk {
        StreamChunk::new(
            vec![Op::Insert],
            vec![I64Array::from_iter([value]).into_ref()],
        )
    }

    /// Writes two chunks in one transaction and checks that the reader observes
    /// `Begin`, both chunks in order, and `End`.
    #[tokio::test]
    async fn test_table_dml_handle() -> Result<()> {
        let dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = dml_handle.stream_reader().into_stream();
        let mut writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        writer.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Each chunk is written and then immediately read back.
        for value in [0, 1] {
            writer.write_chunk(single_insert_chunk(value)).await.unwrap();

            assert_matches!(reader.next().await.unwrap()?, txn_msg => {
                let chunk = txn_msg.as_stream_chunk().unwrap();
                assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some(value)]);
            });
        }

        // `end` waits for the reader's acknowledgement, so it runs in its own task while this
        // task keeps reading.
        tokio::spawn(async move {
            writer.end().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));

        Ok(())
    }

    /// Dropping a handle in the middle of a transaction must emit a `Rollback` message.
    #[tokio::test]
    async fn test_write_handle_rollback_on_drop() -> Result<()> {
        let dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = dml_handle.stream_reader().into_stream();
        let mut writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        writer.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        writer.write_chunk(single_insert_chunk(1)).await.unwrap();

        assert_matches!(reader.next().await.unwrap()?, txn_msg => {
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some(1)]);
        });

        // The transaction was begun but never ended, so the drop triggers the rollback.
        drop(writer);
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Rollback(_));

        Ok(())
    }

    /// The future returned by `end_wait_persistence` resolves once the reader has answered
    /// the persistence notifier carried by the `End` message.
    #[tokio::test]
    async fn test_end_wait_persistence() -> Result<()> {
        let dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = dml_handle.stream_reader().into_stream();
        let mut writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        writer.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        let end_task = tokio::spawn(async move {
            writer.end_wait_persistence().unwrap().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            persistence_notifier.send(()).unwrap();
        });

        end_task.await.unwrap();

        Ok(())
    }
}