risingwave_dml/
table.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::sync::Arc;
17
18use futures_async_stream::try_stream;
19use parking_lot::RwLock;
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::ColumnDesc;
22use risingwave_common::transaction::transaction_id::TxnId;
23use risingwave_common::transaction::transaction_message::TxnMsg;
24use tokio::sync::oneshot;
25
26use crate::error::{DmlError, Result};
27use crate::txn_channel::{Receiver, Sender, txn_channel};
28
/// Shared, reference-counted handle to a [`TableDmlHandle`].
pub type TableDmlHandleRef = Arc<TableDmlHandle>;
30
/// The lock-guarded mutable state of a [`TableDmlHandle`]: the set of sender
/// endpoints connecting DML writers to downstream stream readers.
#[derive(Debug)]
pub struct TableDmlHandleCore {
    /// The senders of the changes channel.
    ///
    /// When a `StreamReader` is created, a channel will be created and the sender will be
    /// saved here. The insert statement will take one channel randomly.
    pub changes_txs: Vec<Sender>,
}
39
/// [`TableDmlHandle`] is a special internal source to handle table updates from user,
/// including insert/delete/update statements via SQL interface.
///
/// Changed rows will be sent to the associated "materialize" streaming task, then be written to the
/// state store. Therefore, [`TableDmlHandle`] can simply be treated as a channel without side
/// effects.
#[derive(Debug)]
pub struct TableDmlHandle {
    pub core: RwLock<TableDmlHandleCore>,

    /// All columns in this table.
    pub column_descs: Vec<ColumnDesc>,

    /// The initial permits of the channel between each [`TableDmlHandle`] and the dml executors.
    dml_channel_initial_permits: usize,
}
56
57impl TableDmlHandle {
58    pub fn new(column_descs: Vec<ColumnDesc>, dml_channel_initial_permits: usize) -> Self {
59        let core = TableDmlHandleCore {
60            changes_txs: vec![],
61        };
62
63        Self {
64            core: RwLock::new(core),
65            column_descs,
66            dml_channel_initial_permits,
67        }
68    }
69
70    pub fn stream_reader(&self) -> TableStreamReader {
71        let mut core = self.core.write();
72        // The `txn_channel` is used to limit the maximum chunk permits to avoid the producer
73        // produces chunks too fast and cause an out of memory error.
74        let (tx, rx) = txn_channel(self.dml_channel_initial_permits);
75        core.changes_txs.push(tx);
76
77        TableStreamReader { rx }
78    }
79
80    pub fn write_handle(&self, session_id: u32, txn_id: TxnId) -> Result<WriteHandle> {
81        // The `changes_txs` should not be empty normally, since we ensured that the channels
82        // between the `TableDmlHandle` and the `SourceExecutor`s are ready before we making the
83        // table catalog visible to the users. However, when we're recovering, it's possible
84        // that the streaming executors are not ready when the frontend is able to schedule DML
85        // tasks to the compute nodes, so this'll be temporarily unavailable, so we throw an
86        // error instead of asserting here.
87        // TODO: may reject DML when streaming executors are not recovered.
88        loop {
89            let guard = self.core.read();
90            if guard.changes_txs.is_empty() {
91                return Err(DmlError::NoReader);
92            }
93            let len = guard.changes_txs.len();
94            // Use session id instead of txn_id to choose channel so that we can preserve transaction order in the same session.
95            // PS: only hold if there's no scaling on the table.
96            let sender = guard
97                .changes_txs
98                .get((session_id % len as u32) as usize)
99                .unwrap()
100                .clone();
101
102            drop(guard);
103
104            if sender.is_closed() {
105                // Remove all closed channels.
106                self.core
107                    .write()
108                    .changes_txs
109                    .retain(|sender| !sender.is_closed());
110            } else {
111                return Ok(WriteHandle::new(txn_id, sender));
112            }
113        }
114    }
115
116    /// Get the reference of all columns in this table.
117    pub fn column_descs(&self) -> &[ColumnDesc] {
118        self.column_descs.as_ref()
119    }
120
121    pub fn check_chunk_schema(&self, chunk: &StreamChunk) {
122        risingwave_common::util::schema_check::schema_check(
123            self.column_descs
124                .iter()
125                .filter_map(|c| (!c.is_generated()).then_some(&c.data_type)),
126            chunk.columns(),
127        )
128        .expect("table source write txn_msg schema check failed");
129    }
130}
131
/// Lifecycle state of a [`WriteHandle`] transaction.
#[derive(Debug, PartialEq)]
enum TxnState {
    /// Created; `begin()` has not been called yet.
    Init,
    /// `TxnMsg::Begin` has been sent; data writes are allowed.
    Begin,
    /// `TxnMsg::End` has been sent; the transaction finished successfully.
    Committed,
    /// `TxnMsg::Rollback` has been sent (possibly from `Drop`).
    Rollback,
}
139
/// [`WriteHandle`] writes its data into a table in a transactional way.
///
/// First, it needs to call `begin()` and then write chunks by calling `write_chunk()`.
///
/// Finally call `end()` to commit the transaction or `rollback()` to rollback the transaction.
///
/// If the [`WriteHandle`] is dropped with a `Begin` transaction state, it will automatically
/// rollback the transaction.
pub struct WriteHandle {
    /// Identifier of the transaction this handle writes under.
    txn_id: TxnId,
    /// Sender side of the channel towards the stream reader.
    tx: Sender,
    // Current transaction state; indicates whether `TxnMsg::End` or `TxnMsg::Rollback`
    // have been sent to the write channel.
    txn_state: TxnState,
}
154
155impl Drop for WriteHandle {
156    fn drop(&mut self) {
157        if self.txn_state == TxnState::Begin {
158            let _ = self.rollback_inner();
159        }
160    }
161}
162
163impl WriteHandle {
164    pub fn new(txn_id: TxnId, tx: Sender) -> Self {
165        Self {
166            txn_id,
167            tx,
168            txn_state: TxnState::Init,
169        }
170    }
171
172    pub fn begin(&mut self) -> Result<()> {
173        assert_eq!(self.txn_state, TxnState::Init);
174        self.txn_state = TxnState::Begin;
175        // Ignore the notifier.
176        self.write_txn_control_msg(TxnMsg::Begin(self.txn_id))?;
177        Ok(())
178    }
179
180    pub async fn write_chunk(&self, chunk: StreamChunk) -> Result<()> {
181        assert_eq!(self.txn_state, TxnState::Begin);
182        // Ignore the notifier.
183        let _notifier = self
184            .write_txn_data_msg(TxnMsg::Data(self.txn_id, chunk))
185            .await?;
186        Ok(())
187    }
188
189    pub async fn end(mut self) -> Result<()> {
190        assert_eq!(self.txn_state, TxnState::Begin);
191        self.txn_state = TxnState::Committed;
192        // Await the notifier.
193        let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id, None))?;
194        notifier.await.map_err(|_| DmlError::ReaderClosed)?;
195        Ok(())
196    }
197
198    /// Like `end`, but waits until the data has been durably persisted before returning.
199    /// The `DmlExecutor` fires the persistence signal after `try_wait_epoch` succeeds.
200    pub fn end_wait_persistence(mut self) -> Result<impl Future<Output = Result<()>> + 'static> {
201        assert_eq!(self.txn_state, TxnState::Begin);
202        self.txn_state = TxnState::Committed;
203        let (persistence_tx, persistence_rx) = oneshot::channel();
204        let notifier =
205            self.write_txn_control_msg(TxnMsg::End(self.txn_id, Some(persistence_tx)))?;
206        Ok(async move {
207            notifier.await.map_err(|_| DmlError::ReaderClosed)?;
208            persistence_rx.await.map_err(|_| DmlError::ReaderClosed)?;
209            Ok(())
210        })
211    }
212
213    pub fn rollback(mut self) -> Result<oneshot::Receiver<usize>> {
214        self.rollback_inner()
215    }
216
217    fn rollback_inner(&mut self) -> Result<oneshot::Receiver<usize>> {
218        assert_eq!(self.txn_state, TxnState::Begin);
219        self.txn_state = TxnState::Rollback;
220        self.write_txn_control_msg(TxnMsg::Rollback(self.txn_id))
221    }
222
223    /// Asynchronously write txn messages into table. Changes written here will be simply passed to
224    /// the associated streaming task via channel, and then be materialized to storage there.
225    ///
226    /// Returns an oneshot channel which will be notified when the chunk is taken by some reader,
227    /// and the `usize` represents the cardinality of this chunk.
228    async fn write_txn_data_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
229        assert_eq!(self.txn_id, txn_msg.txn_id());
230        let (notifier_tx, notifier_rx) = oneshot::channel();
231        match self.tx.send(txn_msg, notifier_tx).await {
232            Ok(_) => Ok(notifier_rx),
233
234            // It's possible that the source executor is scaled in or migrated, so the channel
235            // is closed. To guarantee the transactional atomicity, bail out.
236            Err(_) => Err(DmlError::ReaderClosed),
237        }
238    }
239
240    /// Same as the `write_txn_data_msg`, but it is not an async function and send control message
241    /// without permit acquiring.
242    fn write_txn_control_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
243        assert_eq!(self.txn_id, txn_msg.txn_id());
244        let (notifier_tx, notifier_rx) = oneshot::channel();
245        match self.tx.send_immediate(txn_msg, notifier_tx) {
246            Ok(_) => Ok(notifier_rx),
247
248            // It's possible that the source executor is scaled in or migrated, so the channel
249            // is closed. To guarantee the transactional atomicity, bail out.
250            Err(_) => Err(DmlError::ReaderClosed),
251        }
252    }
253}
254
/// [`TableStreamReader`] reads changes from a certain table continuously.
/// This struct should be only used for associated materialize task, thus the reader should be
/// created only once. Further streaming task relying on this table source should follow the
/// structure of "`MView` on `MView`".
#[derive(Debug)]
pub struct TableStreamReader {
    /// The receiver of the changes channel.
    rx: Receiver,
}
264
265impl TableStreamReader {
266    #[try_stream(boxed, ok = StreamChunk, error = DmlError)]
267    pub async fn into_data_stream_for_test(mut self) {
268        while let Some((txn_msg, notifier)) = self.rx.recv().await {
269            // Notify about that we've taken the chunk.
270            match txn_msg {
271                TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
272                    _ = notifier.send(0);
273                }
274                TxnMsg::Data(_, chunk) => {
275                    _ = notifier.send(chunk.cardinality());
276                    yield chunk;
277                }
278            }
279        }
280    }
281
282    #[try_stream(boxed, ok = TxnMsg, error = DmlError)]
283    pub async fn into_stream(mut self) {
284        while let Some((txn_msg, notifier)) = self.rx.recv().await {
285            // Notify about that we've taken the chunk.
286            match &txn_msg {
287                TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
288                    _ = notifier.send(0);
289                    yield txn_msg;
290                }
291                TxnMsg::Data(_, chunk) => {
292                    _ = notifier.send(chunk.cardinality());
293                    yield txn_msg;
294                }
295            }
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, I64Array, Op};
    use risingwave_common::catalog::ColumnId;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Build a handle over a single `Int64` column with a large permit budget,
    /// so permit exhaustion never interferes with these tests.
    fn new_table_dml_handle() -> TableDmlHandle {
        TableDmlHandle::new(
            vec![ColumnDesc::unnamed(ColumnId::from(0), DataType::Int64)],
            32768,
        )
    }

    /// Happy path: begin, write chunks, commit — and observe `Begin`, the data,
    /// and `End` on the reader side in order.
    #[tokio::test]
    async fn test_table_dml_handle() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Writes a single-row insert chunk containing the value `$i`.
        macro_rules! write_chunk {
            ($i:expr) => {{
                let chunk =
                    StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([$i]).into_ref()]);
                write_handle.write_chunk(chunk).await.unwrap();
            }};
        }

        write_chunk!(0);

        // Asserts the next reader message is a data chunk whose only value is `$i`.
        macro_rules! check_next_chunk {
            ($i: expr) => {
                assert_matches!(reader.next().await.unwrap()?, txn_msg => {
                    let chunk = txn_msg.as_stream_chunk().unwrap();
                    assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some($i)]);
                });
            }
        }

        check_next_chunk!(0);

        write_chunk!(1);
        check_next_chunk!(1);

        // Since the end will wait the notifier which is sent by the reader,
        // we need to spawn a task here to avoid dead lock.
        tokio::spawn(async move {
            write_handle.end().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));

        Ok(())
    }

    /// Dropping a handle that is still in the `Begin` state must emit a
    /// `Rollback` message to the reader automatically.
    #[tokio::test]
    async fn test_write_handle_rollback_on_drop() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        let chunk = StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([1]).into_ref()]);
        write_handle.write_chunk(chunk).await.unwrap();

        assert_matches!(reader.next().await.unwrap()?, txn_msg => {
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some(1)]);
        });

        // Rollback on drop
        drop(write_handle);
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Rollback(_));

        Ok(())
    }

    /// `end_wait_persistence` must not resolve until the reader fires the
    /// persistence notifier carried in `TxnMsg::End`.
    #[tokio::test]
    async fn test_end_wait_persistence() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Spawned so the writer can block on the persistence signal while the
        // reader side (below) consumes the `End` message and fires the ack.
        let handle = tokio::spawn(async move {
            write_handle.end_wait_persistence().unwrap().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            persistence_notifier.send(()).unwrap();
        });

        handle.await.unwrap();

        Ok(())
    }
}