risingwave_dml/
table.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use futures_async_stream::try_stream;
18use parking_lot::RwLock;
19use risingwave_common::array::StreamChunk;
20use risingwave_common::catalog::ColumnDesc;
21use risingwave_common::transaction::transaction_id::TxnId;
22use risingwave_common::transaction::transaction_message::TxnMsg;
23use risingwave_common::util::epoch::Epoch;
24use tokio::sync::oneshot;
25
26use crate::error::{DmlError, Result};
27use crate::txn_channel::{Receiver, Sender, txn_channel};
28
/// Shared, reference-counted pointer to a [`TableDmlHandle`].
pub type TableDmlHandleRef = Arc<TableDmlHandle>;
30
/// The lock-protected state of a [`TableDmlHandle`]: the senders of all registered channels.
#[derive(Debug)]
pub struct TableDmlHandleCore {
    /// The senders of the changes channel.
    ///
    /// When a `TableStreamReader` is created via `TableDmlHandle::stream_reader`, a channel is
    /// created and its sender is saved here. `TableDmlHandle::write_handle` selects one sender
    /// by `session_id` modulo the number of senders, and removes closed senders when it
    /// encounters one.
    pub changes_txs: Vec<Sender>,
}
39
/// [`TableDmlHandle`] is a special internal source to handle table updates from users,
/// including insert/delete/update statements via the SQL interface.
///
/// Changed rows will be sent to the associated "materialize" streaming task, then be written to
/// the state store. Therefore, [`TableDmlHandle`] can simply be treated as a channel without
/// side effects.
#[derive(Debug)]
pub struct TableDmlHandle {
    /// Lock-protected list of channel senders; one sender is added per `stream_reader()` call.
    pub core: RwLock<TableDmlHandleCore>,

    /// All columns in this table.
    pub column_descs: Vec<ColumnDesc>,

    /// The initial permits of the channel between each [`TableDmlHandle`] and the dml executors.
    dml_channel_initial_permits: usize,
}
56
57impl TableDmlHandle {
58    pub fn new(column_descs: Vec<ColumnDesc>, dml_channel_initial_permits: usize) -> Self {
59        let core = TableDmlHandleCore {
60            changes_txs: vec![],
61        };
62
63        Self {
64            core: RwLock::new(core),
65            column_descs,
66            dml_channel_initial_permits,
67        }
68    }
69
70    pub fn stream_reader(&self) -> TableStreamReader {
71        let mut core = self.core.write();
72        // The `txn_channel` is used to limit the maximum chunk permits to avoid the producer
73        // produces chunks too fast and cause an out of memory error.
74        let (tx, rx) = txn_channel(self.dml_channel_initial_permits);
75        core.changes_txs.push(tx);
76
77        TableStreamReader { rx }
78    }
79
80    pub fn write_handle(&self, session_id: u32, txn_id: TxnId) -> Result<WriteHandle> {
81        // The `changes_txs` should not be empty normally, since we ensured that the channels
82        // between the `TableDmlHandle` and the `SourceExecutor`s are ready before we making the
83        // table catalog visible to the users. However, when we're recovering, it's possible
84        // that the streaming executors are not ready when the frontend is able to schedule DML
85        // tasks to the compute nodes, so this'll be temporarily unavailable, so we throw an
86        // error instead of asserting here.
87        // TODO: may reject DML when streaming executors are not recovered.
88        loop {
89            let guard = self.core.read();
90            if guard.changes_txs.is_empty() {
91                return Err(DmlError::NoReader);
92            }
93            let len = guard.changes_txs.len();
94            // Use session id instead of txn_id to choose channel so that we can preserve transaction order in the same session.
95            // PS: only hold if there's no scaling on the table.
96            let sender = guard
97                .changes_txs
98                .get((session_id % len as u32) as usize)
99                .unwrap()
100                .clone();
101
102            drop(guard);
103
104            if sender.is_closed() {
105                // Remove all closed channels.
106                self.core
107                    .write()
108                    .changes_txs
109                    .retain(|sender| !sender.is_closed());
110            } else {
111                return Ok(WriteHandle::new(txn_id, sender));
112            }
113        }
114    }
115
116    /// Get the reference of all columns in this table.
117    pub fn column_descs(&self) -> &[ColumnDesc] {
118        self.column_descs.as_ref()
119    }
120
121    pub fn check_chunk_schema(&self, chunk: &StreamChunk) {
122        risingwave_common::util::schema_check::schema_check(
123            self.column_descs
124                .iter()
125                .filter_map(|c| (!c.is_generated()).then_some(&c.data_type)),
126            chunk.columns(),
127        )
128        .expect("table source write txn_msg schema check failed");
129    }
130}
131
/// Lifecycle state of the transaction driven by a [`WriteHandle`].
#[derive(Debug, PartialEq)]
enum TxnState {
    /// Initial state assigned by `WriteHandle::new`; `begin()` has not been called yet.
    Init,
    /// Assigned by `begin()`. Dropping the handle in this state triggers a rollback.
    Begin,
    /// Assigned by `end()` / `end_returning_epoch()` right before `TxnMsg::End` is sent.
    Committed,
    /// Assigned by the rollback path right before `TxnMsg::Rollback` is sent.
    Rollback,
}
139
/// [`WriteHandle`] writes its data into a table in a transactional way.
///
/// First, it needs to call `begin()` and then write chunks by calling `write_chunk()`.
///
/// Finally, call `end()` to commit the transaction or `rollback()` to roll it back.
///
/// If the [`WriteHandle`] is dropped with a `Begin` transaction state, it will automatically
/// roll back the transaction.
pub struct WriteHandle {
    // Id carried by every `TxnMsg` sent through this handle.
    txn_id: TxnId,
    // Channel sender the transaction messages are written to.
    tx: Sender,
    // Tracks the transaction lifecycle. It is updated right before the corresponding control
    // message (`Begin`, `End` or `Rollback`) is sent, and tells `Drop` whether a rollback is
    // still needed.
    txn_state: TxnState,
}
154
155impl Drop for WriteHandle {
156    fn drop(&mut self) {
157        if self.txn_state == TxnState::Begin {
158            let _ = self.rollback_inner();
159        }
160    }
161}
162
163impl WriteHandle {
164    pub fn new(txn_id: TxnId, tx: Sender) -> Self {
165        Self {
166            txn_id,
167            tx,
168            txn_state: TxnState::Init,
169        }
170    }
171
172    pub fn begin(&mut self) -> Result<()> {
173        assert_eq!(self.txn_state, TxnState::Init);
174        self.txn_state = TxnState::Begin;
175        // Ignore the notifier.
176        self.write_txn_control_msg(TxnMsg::Begin(self.txn_id))?;
177        Ok(())
178    }
179
180    pub async fn write_chunk(&self, chunk: StreamChunk) -> Result<()> {
181        assert_eq!(self.txn_state, TxnState::Begin);
182        // Ignore the notifier.
183        let _notifier = self
184            .write_txn_data_msg(TxnMsg::Data(self.txn_id, chunk))
185            .await?;
186        Ok(())
187    }
188
189    pub async fn end(mut self) -> Result<()> {
190        assert_eq!(self.txn_state, TxnState::Begin);
191        self.txn_state = TxnState::Committed;
192        // Await the notifier.
193        let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id, None))?;
194        notifier.await.map_err(|_| DmlError::ReaderClosed)?;
195        Ok(())
196    }
197
198    pub async fn end_returning_epoch(mut self) -> Result<Epoch> {
199        assert_eq!(self.txn_state, TxnState::Begin);
200        self.txn_state = TxnState::Committed;
201        // Await the notifier.
202        let (epoch_notifier_tx, epoch_notifier_rx) = oneshot::channel();
203        let notifier = self.write_txn_control_msg_returning_epoch(TxnMsg::End(
204            self.txn_id,
205            Some(epoch_notifier_tx),
206        ))?;
207        notifier.await.map_err(|_| DmlError::ReaderClosed)?;
208        let epoch = epoch_notifier_rx
209            .await
210            .map_err(|_| DmlError::ReaderClosed)?;
211        Ok(epoch)
212    }
213
214    pub fn rollback(mut self) -> Result<oneshot::Receiver<usize>> {
215        self.rollback_inner()
216    }
217
218    fn rollback_inner(&mut self) -> Result<oneshot::Receiver<usize>> {
219        assert_eq!(self.txn_state, TxnState::Begin);
220        self.txn_state = TxnState::Rollback;
221        self.write_txn_control_msg(TxnMsg::Rollback(self.txn_id))
222    }
223
224    /// Asynchronously write txn messages into table. Changes written here will be simply passed to
225    /// the associated streaming task via channel, and then be materialized to storage there.
226    ///
227    /// Returns an oneshot channel which will be notified when the chunk is taken by some reader,
228    /// and the `usize` represents the cardinality of this chunk.
229    async fn write_txn_data_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
230        assert_eq!(self.txn_id, txn_msg.txn_id());
231        let (notifier_tx, notifier_rx) = oneshot::channel();
232        match self.tx.send(txn_msg, notifier_tx).await {
233            Ok(_) => Ok(notifier_rx),
234
235            // It's possible that the source executor is scaled in or migrated, so the channel
236            // is closed. To guarantee the transactional atomicity, bail out.
237            Err(_) => Err(DmlError::ReaderClosed),
238        }
239    }
240
241    /// Same as the `write_txn_data_msg`, but it is not an async function and send control message
242    /// without permit acquiring.
243    fn write_txn_control_msg(&self, txn_msg: TxnMsg) -> Result<oneshot::Receiver<usize>> {
244        assert_eq!(self.txn_id, txn_msg.txn_id());
245        let (notifier_tx, notifier_rx) = oneshot::channel();
246        match self.tx.send_immediate(txn_msg, notifier_tx) {
247            Ok(_) => Ok(notifier_rx),
248
249            // It's possible that the source executor is scaled in or migrated, so the channel
250            // is closed. To guarantee the transactional atomicity, bail out.
251            Err(_) => Err(DmlError::ReaderClosed),
252        }
253    }
254
255    fn write_txn_control_msg_returning_epoch(
256        &self,
257        txn_msg: TxnMsg,
258    ) -> Result<oneshot::Receiver<usize>> {
259        assert_eq!(self.txn_id, txn_msg.txn_id());
260        let (notifier_tx, notifier_rx) = oneshot::channel();
261        match self.tx.send_immediate(txn_msg, notifier_tx) {
262            Ok(_) => Ok(notifier_rx),
263
264            // It's possible that the source executor is scaled in or migrated, so the channel
265            // is closed. To guarantee the transactional atomicity, bail out.
266            Err(_) => Err(DmlError::ReaderClosed),
267        }
268    }
269}
270
/// [`TableStreamReader`] reads changes from a certain table continuously.
/// This struct should only be used by the associated materialize task, thus the reader should
/// be created only once. Further streaming tasks relying on this table source should follow the
/// structure of "`MView` on `MView`".
#[derive(Debug)]
pub struct TableStreamReader {
    /// The receiver of the changes channel.
    rx: Receiver,
}
280
281impl TableStreamReader {
282    #[try_stream(boxed, ok = StreamChunk, error = DmlError)]
283    pub async fn into_data_stream_for_test(mut self) {
284        while let Some((txn_msg, notifier)) = self.rx.recv().await {
285            // Notify about that we've taken the chunk.
286            match txn_msg {
287                TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
288                    _ = notifier.send(0);
289                }
290                TxnMsg::Data(_, chunk) => {
291                    _ = notifier.send(chunk.cardinality());
292                    yield chunk;
293                }
294            }
295        }
296    }
297
298    #[try_stream(boxed, ok = TxnMsg, error = DmlError)]
299    pub async fn into_stream(mut self) {
300        while let Some((txn_msg, notifier)) = self.rx.recv().await {
301            // Notify about that we've taken the chunk.
302            match &txn_msg {
303                TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => {
304                    _ = notifier.send(0);
305                    yield txn_msg;
306                }
307                TxnMsg::Data(_, chunk) => {
308                    _ = notifier.send(chunk.cardinality());
309                    yield txn_msg;
310                }
311            }
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, I64Array, Op};
    use risingwave_common::catalog::ColumnId;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Builds a handle for a table with a single unnamed `Int64` column.
    fn new_table_dml_handle() -> TableDmlHandle {
        TableDmlHandle::new(
            vec![ColumnDesc::unnamed(ColumnId::from(0), DataType::Int64)],
            32768,
        )
    }

    /// Happy path: `Begin`, two data chunks and `End` arrive at the reader in order.
    #[tokio::test]
    async fn test_table_dml_handle() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        // The reader must exist before `write_handle`, otherwise no channel is registered.
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Writes a single-row insert chunk holding the value `$i`.
        macro_rules! write_chunk {
            ($i:expr) => {{
                let chunk =
                    StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([$i]).into_ref()]);
                write_handle.write_chunk(chunk).await.unwrap();
            }};
        }

        write_chunk!(0);

        // Asserts that the next message is a data chunk whose only value is `$i`.
        macro_rules! check_next_chunk {
            ($i: expr) => {
                assert_matches!(reader.next().await.unwrap()?, txn_msg => {
                    let chunk = txn_msg.as_stream_chunk().unwrap();
                    assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some($i)]);
                });
            }
        }

        check_next_chunk!(0);

        write_chunk!(1);
        check_next_chunk!(1);

        // Since `end` waits for the notifier which is sent by the reader,
        // we need to spawn a task here to avoid a deadlock.
        tokio::spawn(async move {
            write_handle.end().await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));

        Ok(())
    }

    /// Dropping a handle that is still in the `Begin` state must emit a `Rollback` message.
    #[tokio::test]
    async fn test_write_handle_rollback_on_drop() -> Result<()> {
        let table_dml_handle = Arc::new(new_table_dml_handle());
        let mut reader = table_dml_handle.stream_reader().into_stream();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        let chunk = StreamChunk::new(vec![Op::Insert], vec![I64Array::from_iter([1]).into_ref()]);
        write_handle.write_chunk(chunk).await.unwrap();

        assert_matches!(reader.next().await.unwrap()?, txn_msg => {
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(chunk.columns()[0].as_int64().iter().collect_vec(), vec![Some(1)]);
        });

        // Rollback on drop
        drop(write_handle);
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Rollback(_));

        Ok(())
    }
}
408}