risingwave_dml/
dml_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::collections::hash_map::Entry;
18use std::sync::{Arc, Weak};
19
20use parking_lot::RwLock;
21use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
23use risingwave_common::util::worker_util::WorkerNodeId;
24
25use crate::error::{DmlError, Result};
26use crate::{TableDmlHandle, TableDmlHandleRef};
27
/// Shared, reference-counted pointer to a [`DmlManager`].
pub type DmlManagerRef = Arc<DmlManager>;
29
/// Bookkeeping entry recording which DML handle is registered for a table, and the table
/// version it was registered with.
#[derive(Debug)]
pub struct TableReader {
    /// Table version the handle was registered with.
    version_id: TableVersionId,
    /// Weak reference to the handle, so this entry alone does not keep the handle alive; once
    /// every strong reference is dropped, `upgrade()` on it yields `None`.
    pub handle: Weak<TableDmlHandle>,
}
35
/// [`DmlManager`] manages the communication between batch data manipulation and streaming
/// processing.
///
/// It keeps at most one registered [`TableReader`] per table, hands out the corresponding
/// [`TableDmlHandle`] to callers that want to write DML data, and generates transaction ids.
///
/// NOTE: `TableDmlHandle` is used here as an out-of-the-box solution. We should further optimize
/// its implementation (e.g. directly expose a channel instead of offering a `write_chunk`
/// interface).
#[derive(Debug)]
pub struct DmlManager {
    /// Registered reader per table. Entries only hold weak references to their handles, so an
    /// entry may outlive the handle it points to.
    table_readers: RwLock<HashMap<TableId, TableReader>>,
    /// Generator for transaction ids, seeded with this worker node's id.
    txn_id_generator: TxnIdGenerator,
    /// Initial permit count passed to every [`TableDmlHandle`] created by this manager.
    dml_channel_initial_permits: usize,
}
47
48impl DmlManager {
49    pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
50        Self {
51            table_readers: RwLock::new(HashMap::new()),
52            txn_id_generator: TxnIdGenerator::new(worker_node_id),
53            dml_channel_initial_permits,
54        }
55    }
56
57    pub fn for_test() -> Self {
58        const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
59        Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
60    }
61
62    /// Register a new DML reader for a table. If the reader for this version of the table already
63    /// exists, returns a reference to the existing reader.
64    pub fn register_reader(
65        &self,
66        table_id: TableId,
67        table_version_id: TableVersionId,
68        column_descs: &[ColumnDesc],
69    ) -> Result<TableDmlHandleRef> {
70        let mut table_readers = self.table_readers.write();
71        // Clear invalid table readers.
72        table_readers.retain(|_, r| r.handle.strong_count() > 0);
73
74        macro_rules! new_handle {
75            ($entry:ident) => {{
76                let handle = Arc::new(TableDmlHandle::new(
77                    column_descs.to_vec(),
78                    self.dml_channel_initial_permits,
79                ));
80                $entry.insert(TableReader {
81                    version_id: table_version_id,
82                    handle: Arc::downgrade(&handle),
83                });
84                handle
85            }};
86        }
87
88        let handle = match table_readers.entry(table_id) {
89            // Create a new reader. This happens when the first `DmlExecutor` of this table is
90            // activated on this compute node.
91            Entry::Vacant(v) => new_handle!(v),
92
93            Entry::Occupied(mut o) => {
94                let TableReader { version_id, handle } = o.get();
95
96                match table_version_id.cmp(version_id) {
97                    // This should never happen as the schema change is guaranteed to happen after a
98                    // table is successfully created and all the readers are registered.
99                    Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
100
101                    // Register with the correct version.
102                    Ordering::Equal => {
103                        if let Some(handle) = handle.upgrade() {
104                            // If there's already a reader, check the schema is the same and reuse it.
105                            // This happens when the following `DmlExecutor`s of this table is activated
106                            // on this compute node.
107                            assert_eq!(
108                                handle.column_descs(),
109                                column_descs,
110                                "dml handler registers with same version but different schema"
111                            );
112                            handle
113                        } else {
114                            // Currently when scaling the fragment, we may drop all actors first before
115                            // creating new actors, which will drop the old reader. In this case, recreate
116                            // a new reader.
117                            // TODO: this will interrupt ongoing DML requests even for scaling out. We
118                            // should try preserving the old actors, thus preserving the reader.
119                            new_handle!(o)
120                        }
121                    }
122
123                    // A new version of the table is activated, overwrite the old reader.
124                    Ordering::Greater => new_handle!(o),
125                }
126            }
127        };
128
129        Ok(handle)
130    }
131
132    pub fn table_dml_handle(
133        &self,
134        table_id: TableId,
135        table_version_id: TableVersionId,
136    ) -> Result<TableDmlHandleRef> {
137        let table_dml_handle = {
138            let table_readers = self.table_readers.read();
139
140            match table_readers.get(&table_id) {
141                Some(TableReader { version_id, handle }) => {
142                    match table_version_id.cmp(version_id) {
143                        // A new version of the table is activated, but the DML request is still on
144                        // the old version.
145                        Ordering::Less => {
146                            return Err(DmlError::SchemaChanged);
147                        }
148
149                        // Write the chunk of correct version to the table.
150                        Ordering::Equal => handle.upgrade(),
151
152                        // This should never happen as the notification of the new version is
153                        // guaranteed to happen after all new readers are activated.
154                        Ordering::Greater => {
155                            unreachable!("table version `{table_version_id} not registered")
156                        }
157                    }
158                }
159                None => None,
160            }
161        }
162        .ok_or(DmlError::NoReader)?;
163
164        Ok(table_dml_handle)
165    }
166
167    pub fn clear(&self) {
168        self.table_readers.write().clear()
169    }
170
171    pub fn gen_txn_id(&self) -> TxnId {
172        self.txn_id_generator.gen_txn_id()
173    }
174}
175
#[cfg(test)]
mod tests {
    use risingwave_common::array::StreamChunk;
    use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    #[tokio::test]
    async fn test_register_and_drop() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let chunk = || StreamChunk::from_pretty("F\n+ 1");

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        let h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // They should be the same handle.
        assert!(Arc::ptr_eq(&h1, &h2));

        // Start reading.
        let r1 = h1.stream_reader();
        let r2 = h2.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Should be able to write to the table.
        write_handle.write_chunk(chunk()).await.unwrap();

        // After dropping the corresponding reader, the write handle should not be allowed to write.
        // This is to simulate the scale-in of DML executors.
        drop(r1);

        write_handle.write_chunk(chunk()).await.unwrap_err();

        // Unless we create a new write handle.
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(chunk()).await.unwrap();

        // After dropping the last reader, no more writes are allowed.
        // This is to simulate the dropping of the table.
        drop(r2);
        write_handle.write_chunk(chunk()).await.unwrap_err();
    }

    #[tokio::test]
    async fn test_versioned() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);

        let old_version_id = INITIAL_TABLE_VERSION_ID;
        let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let old_chunk = || StreamChunk::from_pretty("F\n+ 1");

        let new_version_id = old_version_id + 1;
        let new_column_descs = vec![
            ColumnDesc::unnamed(100.into(), DataType::Float64),
            ColumnDesc::unnamed(101.into(), DataType::Float64),
        ];
        let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");

        // Start reading.
        let old_h = dml_manager
            .register_reader(table_id, old_version_id, &old_column_descs)
            .unwrap();
        let _old_r = old_h.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Should be able to write to the table.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Start reading the new version.
        let new_h = dml_manager
            .register_reader(table_id, new_version_id, &new_column_descs)
            .unwrap();
        let _new_r = new_h.stream_reader();

        // Still be able to write to the old write handle, if the channel is not closed.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // However, it is no longer possible to create a `table_dml_handle` with the old version.
        assert!(matches!(
            dml_manager.table_dml_handle(table_id, old_version_id),
            Err(DmlError::SchemaChanged)
        ));

        // Should be able to write to the new version.
        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, new_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(new_chunk()).await.unwrap();
    }

    #[test]
    fn test_no_reader() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];

        // Nothing has been registered for the table yet.
        assert!(matches!(
            dml_manager.table_dml_handle(table_id, table_version_id),
            Err(DmlError::NoReader)
        ));

        // The manager only keeps a weak reference, so once the returned handle is dropped the
        // lookup fails again.
        let h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        drop(h);
        assert!(matches!(
            dml_manager.table_dml_handle(table_id, table_version_id),
            Err(DmlError::NoReader)
        ));
    }

    #[test]
    #[should_panic(expected = "dml handler registers with same version but different schema")]
    fn test_bad_schema() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;

        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];

        let _h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Should panic as the schema is different.
        let _h = dml_manager
            .register_reader(table_id, table_version_id, &other_column_descs)
            .unwrap();
    }
}