1use std::sync::Arc;
16
17use futures_async_stream::try_stream;
18use parking_lot::RwLock;
19use risingwave_common::array::StreamChunk;
20use risingwave_common::catalog::ColumnDesc;
21use risingwave_common::transaction::transaction_id::TxnId;
22use risingwave_common::transaction::transaction_message::TxnMsg;
23use risingwave_common::util::epoch::Epoch;
24use tokio::sync::oneshot;
25
26use crate::error::{DmlError, Result};
27use crate::txn_channel::{Receiver, Sender, txn_channel};
28
29pub type TableDmlHandleRef = Arc<TableDmlHandle>;
30
31#[derive(Debug)]
32pub struct TableDmlHandleCore {
33 pub changes_txs: Vec<Sender>,
38}
39
40#[derive(Debug)]
47pub struct TableDmlHandle {
48 pub core: RwLock<TableDmlHandleCore>,
49
50 pub column_descs: Vec<ColumnDesc>,
52
53 dml_channel_initial_permits: usize,
55}
56
57impl TableDmlHandle {
58 pub fn new(column_descs: Vec<ColumnDesc>, dml_channel_initial_permits: usize) -> Self {
59 let core = TableDmlHandleCore {
60 changes_txs: vec![],
61 };
62
63 Self {
64 core: RwLock::new(core),
65 column_descs,
66 dml_channel_initial_permits,
67 }
68 }
69
70 pub fn stream_reader(&self) -> TableStreamReader {
71 let mut core = self.core.write();
72 let (tx, rx) = txn_channel(self.dml_channel_initial_permits);
75 core.changes_txs.push(tx);
76
77 TableStreamReader { rx }
78 }
79
80 pub fn write_handle(&self, session_id: u32, txn_id: TxnId) -> Result<WriteHandle> {
81 loop {
89 let guard = self.core.read();
90 if guard.changes_txs.is_empty() {
91 return Err(DmlError::NoReader);
92 }
93 let len = guard.changes_txs.len();
94 let sender = guard
97 .changes_txs
98 .get((session_id % len as u32) as usize)
99 .unwrap()
100 .clone();
101
102 drop(guard);
103
104 if sender.is_closed() {
105 self.core
107 .write()
108 .changes_txs
109 .retain(|sender| !sender.is_closed());
110 } else {
111 return Ok(WriteHandle::new(txn_id, sender));
112 }
113 }
114 }
115
116 pub fn column_descs(&self) -> &[ColumnDesc] {
118 self.column_descs.as_ref()
119 }
120
121 pub fn check_chunk_schema(&self, chunk: &StreamChunk) {
122 risingwave_common::util::schema_check::schema_check(
123 self.column_descs
124 .iter()
125 .filter_map(|c| (!c.is_generated()).then_some(&c.data_type)),
126 chunk.columns(),
127 )
128 .expect("table source write txn_msg schema check failed");
129 }
130}
131
/// Lifecycle state of the transaction owned by a [`WriteHandle`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TxnState {
    /// Handle created; `begin` not called yet.
    Init,
    /// `begin` called; the transaction is open.
    Begin,
    /// `end` or `end_returning_epoch` called.
    Committed,
    /// Rolled back, either explicitly or when the handle was dropped while open.
    Rollback,
}
139
140pub struct WriteHandle {
149 txn_id: TxnId,
150 tx: Sender,
151 txn_state: TxnState,
153}
154
155impl Drop for WriteHandle {
156 fn drop(&mut self) {
157 if self.txn_state == TxnState::Begin {
158 let _ = self.rollback_inner();
159 }
160 }
161}
162
163impl WriteHandle {
164 pub fn new(txn_id: TxnId, tx: Sender) -> Self {
165 Self {
166 txn_id,
167 tx,
168 txn_state: TxnState::Init,
169 }
170 }
171
172 pub fn begin(&mut self) -> Result<()> {
173 assert_eq!(self.txn_state, TxnState::Init);
174 self.txn_state = TxnState::Begin;
175 self.write_txn_control_msg(TxnMsg::Begin(self.txn_id))?;
177 Ok(())
178 }
179
180 pub async fn write_chunk(&self, chunk: StreamChunk) -> Result<()> {
181 assert_eq!(self.txn_state, TxnState::Begin);
182 let _notifier = self
184 .write_txn_data_msg(TxnMsg::Data(self.txn_id, chunk))
185 .await?;
186 Ok(())
187 }
188
189 pub async fn end(mut self) -> Result<()> {
190 assert_eq!(self.txn_state, TxnState::Begin);
191 self.txn_state = TxnState::Committed;
192 let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id, None))?;
194 notifier.await.map_err(|_| DmlError::ReaderClosed)?;
195 Ok(())
196 }
197
198 pub async fn end_returning_epoch(mut self) -> Result<Epoch> {
199 assert_eq!(self.txn_state, TxnState::Begin);
200 self.txn_state = TxnState::Committed;
201 let (epoch_notifier_tx, epoch_notifier_rx) = oneshot::channel();
203 let notifier = self.write_txn_control_msg_returning_epoch(TxnMsg::End(
204 self.txn_id,
205 Some(epoch_notifier_tx),
206 ))?;
207 notifier.await.map_err(|_| DmlError::ReaderClosed)?;
208 let epoch = epoch_notifier_rx
209 .await
210 .map_err(|_| DmlError::ReaderClosed)?;
211 Ok(epoch)
212 }
213
214 pub fn rollback(mut self) -> Result<oneshot::Receiver<usize>> {
215 self.rollback_inner()
216 }
217
218 fn rollback_inner(&mut self) -> Result<oneshot::Receiver<usize>> {
219 assert_eq!(self.txn_state, TxnState::Begin);
220 self.txn_state = TxnState::Rollback;
221 self.write_txn_control_msg(TxnMsg::Rollback(self.txn_id))
222 }
223
224 async fn write_txn_data_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
230 assert_eq!(self.txn_id, txn_msg.txn_id());
231 let (notifier_tx, notifier_rx) = oneshot::channel();
232 match self.tx.send(txn_msg, notifier_tx).await {
233 Ok(_) => Ok(notifier_rx),
234
235 Err(_) => Err(DmlError::ReaderClosed),
238 }
239 }
240
241 fn write_txn_control_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
244 assert_eq!(self.txn_id, txn_msg.txn_id());
245 let (notifier_tx, notifier_rx) = oneshot::channel();
246 match self.tx.send_immediate(txn_msg, notifier_tx) {
247 Ok(_) => Ok(notifier_rx),
248
249 Err(_) => Err(DmlError::ReaderClosed),
252 }
253 }
254
255 fn write_txn_control_msg_returning_epoch(
256 &self,
257 txn_msg: TxnMsg,
258 ) -> Result<oneshot::Receiver<usize>> {
259 assert_eq!(self.txn_id, txn_msg.txn_id());
260 let (notifier_tx, notifier_rx) = oneshot::channel();
261 match self.tx.send_immediate(txn_msg, notifier_tx) {
262 Ok(_) => Ok(notifier_rx),
263
264 Err(_) => Err(DmlError::ReaderClosed),
267 }
268 }
269}
270
271#[derive(Debug)]
276pub struct TableStreamReader {
277 rx: Receiver,
279}
280
281impl TableStreamReader {
282 #[try_stream(boxed, ok = StreamChunk, error = DmlError)]
283 pub async fn into_data_stream_for_test(mut self) {
284 while let Some((txn_msg, notifier)) = self.rx.recv().await {
285 match txn_msg {
287 TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
288 _ = notifier.send(0);
289 }
290 TxnMsg::Data(_, chunk) => {
291 _ = notifier.send(chunk.cardinality());
292 yield chunk;
293 }
294 }
295 }
296 }
297
298 #[try_stream(boxed, ok = TxnMsg, error = DmlError)]
299 pub async fn into_stream(mut self) {
300 while let Some((txn_msg, notifier)) = self.rx.recv().await {
301 match &txn_msg {
303 TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
304 _ = notifier.send(0);
305 yield txn_msg;
306 }
307 TxnMsg::Data(_, chunk) => {
308 _ = notifier.send(chunk.cardinality());
309 yield txn_msg;
310 }
311 }
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, I64Array, Op};
    use risingwave_common::catalog::ColumnId;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    fn new_table_dml_handle() -> TableDmlHandle {
        TableDmlHandle::new(
            vec![ColumnDesc::unnamed(ColumnId::from(0), DataType::Int64)],
            32768,
        )
    }

    #[test]
    fn test_write_handle_without_reader() {
        let table_dml_handle = new_table_dml_handle();
        // `matches!` avoids requiring `Debug` on `WriteHandle`.
        assert!(matches!(
            table_dml_handle.write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID),
            Err(DmlError::NoReader)
        ));
    }

    #[tokio::test]
    async fn test_table_dml_handle() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        macro_rules! write_chunk {
            ($i:expr) => {{
                let chunk =
                    StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([$i]).into_ref()]);
                write_handle.write_chunk(chunk).await.unwrap();
            }};
        }

        write_chunk!(0);

        macro_rules! check_next_chunk {
            ($i: expr) => {
                assert_matches!(reader.next().await.unwrap()?, txn_msg => {
                    let chunk = txn_msg.as_stream_chunk().unwrap();
                    assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some($i)]);
                });
            }
        }

        check_next_chunk!(0);

        write_chunk!(1);
        check_next_chunk!(1);

        // `end` waits for the reader's acknowledgement, so it has to run concurrently with
        // `reader.next()` instead of being awaited here first.
        let end_task = tokio::spawn(async move {
            write_handle.end().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));
        // Join the task so that a panic inside `end` fails this test instead of being lost.
        end_task.await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn test_write_handle_rollback_on_drop() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        let chunk = StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([1]).into_ref()]);
        write_handle.write_chunk(chunk).await.unwrap();

        assert_matches!(reader.next().await.unwrap()?, txn_msg => {
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some(1)]);
        });

        // Dropping an open handle must emit a rollback.
        drop(write_handle);
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Rollback(_));

        Ok(())
    }
}