risingwave_dml/
dml_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::collections::hash_map::Entry;
18use std::sync::{Arc, Weak};
19
20use parking_lot::RwLock;
21use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
23use risingwave_common::util::worker_util::WorkerNodeId;
24
25use crate::error::{DmlError, Result};
26use crate::{TableDmlHandle, TableDmlHandleRef};
27
/// Shared, reference-counted handle to a [`DmlManager`].
pub type DmlManagerRef = Arc<DmlManager>;
29
/// Registry entry describing the DML reader of one table: the table version it was registered
/// with, plus a weak reference to the shared [`TableDmlHandle`].
#[derive(Debug)]
pub struct TableReader {
    /// Table version this reader was registered for.
    version_id: TableVersionId,
    /// Weak reference to the handle. The registry does not keep the handle alive; once every
    /// strong reference is dropped the entry is stale and is purged on the next registration.
    pub handle: Weak<TableDmlHandle>,
}
35
/// [`DmlManager`] manages the communication between batch data manipulation and streaming
/// processing.
///
/// NOTE: `TableDmlHandle` is used here as an out-of-the-box solution. We should further optimize
/// its implementation (e.g. directly expose a channel instead of offering a `write_chunk`
/// interface).
#[derive(Debug)]
pub struct DmlManager {
    /// At most one reader per table, tagged with the table version it was registered for.
    table_readers: RwLock<HashMap<TableId, TableReader>>,
    /// Generator backing [`DmlManager::gen_txn_id`].
    txn_id_generator: TxnIdGenerator,
    /// Initial permit count passed to every [`TableDmlHandle`] created by this manager.
    dml_channel_initial_permits: usize,
}
47
48impl DmlManager {
49    pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
50        Self {
51            table_readers: RwLock::new(HashMap::new()),
52            txn_id_generator: TxnIdGenerator::new(worker_node_id),
53            dml_channel_initial_permits,
54        }
55    }
56
57    pub fn for_test() -> Self {
58        const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
59        Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
60    }
61
62    /// Register a new DML reader for a table. If the reader for this version of the table already
63    /// exists, returns a reference to the existing reader.
64    pub fn register_reader(
65        &self,
66        table_id: TableId,
67        table_version_id: TableVersionId,
68        column_descs: &[ColumnDesc],
69    ) -> Result<TableDmlHandleRef> {
70        let mut table_readers = self.table_readers.write();
71        // Clear invalid table readers.
72        table_readers.retain(|_, r| r.handle.strong_count() > 0);
73
74        macro_rules! new_handle {
75            ($entry:ident) => {{
76                let handle = Arc::new(TableDmlHandle::new(
77                    column_descs.to_vec(),
78                    self.dml_channel_initial_permits,
79                ));
80                $entry.insert(TableReader {
81                    version_id: table_version_id,
82                    handle: Arc::downgrade(&handle),
83                });
84                handle
85            }};
86        }
87
88        let handle = match table_readers.entry(table_id) {
89            // Create a new reader. This happens when the first `DmlExecutor` of this table is
90            // activated on this compute node.
91            Entry::Vacant(v) => new_handle!(v),
92
93            Entry::Occupied(mut o) => {
94                let TableReader { version_id, handle } = o.get();
95
96                match table_version_id.cmp(version_id) {
97                    // This should never happen as the schema change is guaranteed to happen after a
98                    // table is successfully created and all the readers are registered.
99                    Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
100
101                    // Register with the correct version. This happens when the following
102                    // `DmlExecutor`s of this table is activated on this compute
103                    // node.
104                    Ordering::Equal => handle
105                        .upgrade()
106                        .inspect(|handle| {
107                            assert_eq!(
108                                handle.column_descs(),
109                                column_descs,
110                                "dml handler registers with same version but different schema"
111                            )
112                        })
113                        .expect("the first dml executor is gone"), // this should never happen
114
115                    // A new version of the table is activated, overwrite the old reader.
116                    Ordering::Greater => new_handle!(o),
117                }
118            }
119        };
120
121        Ok(handle)
122    }
123
124    pub fn table_dml_handle(
125        &self,
126        table_id: TableId,
127        table_version_id: TableVersionId,
128    ) -> Result<TableDmlHandleRef> {
129        let table_dml_handle = {
130            let table_readers = self.table_readers.read();
131
132            match table_readers.get(&table_id) {
133                Some(TableReader { version_id, handle }) => {
134                    match table_version_id.cmp(version_id) {
135                        // A new version of the table is activated, but the DML request is still on
136                        // the old version.
137                        Ordering::Less => {
138                            return Err(DmlError::SchemaChanged);
139                        }
140
141                        // Write the chunk of correct version to the table.
142                        Ordering::Equal => handle.upgrade(),
143
144                        // This should never happen as the notification of the new version is
145                        // guaranteed to happen after all new readers are activated.
146                        Ordering::Greater => {
147                            unreachable!("table version `{table_version_id} not registered")
148                        }
149                    }
150                }
151                None => None,
152            }
153        }
154        .ok_or(DmlError::NoReader)?;
155
156        Ok(table_dml_handle)
157    }
158
159    pub fn clear(&self) {
160        self.table_readers.write().clear()
161    }
162
163    pub fn gen_txn_id(&self) -> TxnId {
164        self.txn_id_generator.gen_txn_id()
165    }
166}
167
#[cfg(test)]
mod tests {
    use risingwave_common::array::StreamChunk;
    use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    #[tokio::test]
    async fn test_register_and_drop() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let chunk = || StreamChunk::from_pretty("F\n+ 1");

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        let h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // They should be the same handle.
        assert!(Arc::ptr_eq(&h1, &h2));

        // Start reading.
        let r1 = h1.stream_reader();
        let r2 = h2.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Should be able to write to the table.
        write_handle.write_chunk(chunk()).await.unwrap();

        // After dropping the corresponding reader, the write handle should be not allowed to write.
        // This is to simulate the scale-in of DML executors.
        drop(r1);

        write_handle.write_chunk(chunk()).await.unwrap_err();

        // Unless we create a new write handle.
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(chunk()).await.unwrap();

        // After dropping the last reader, no more writes are allowed.
        // This is to simulate the dropping of the table.
        drop(r2);
        write_handle.write_chunk(chunk()).await.unwrap_err();
    }

    #[tokio::test]
    async fn test_versioned() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);

        let old_version_id = INITIAL_TABLE_VERSION_ID;
        let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let old_chunk = || StreamChunk::from_pretty("F\n+ 1");

        let new_version_id = old_version_id + 1;
        let new_column_descs = vec![
            ColumnDesc::unnamed(100.into(), DataType::Float64),
            ColumnDesc::unnamed(101.into(), DataType::Float64),
        ];
        let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");

        // Start reading.
        let old_h = dml_manager
            .register_reader(table_id, old_version_id, &old_column_descs)
            .unwrap();
        let _old_r = old_h.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Should be able to write to the table.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Start reading the new version.
        let new_h = dml_manager
            .register_reader(table_id, new_version_id, &new_column_descs)
            .unwrap();
        let _new_r = new_h.stream_reader();

        // Still be able to write to the old write handle, if the channel is not closed.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // However, it is no longer possible to create a `table_dml_handle` with the old version;
        dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap_err();

        // Should be able to write to the new version.
        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, new_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(new_chunk()).await.unwrap();
    }

    #[test]
    fn test_reregister_after_all_handles_dropped() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        drop(h1);

        // The stale entry is still recorded, but its handle can no longer be upgraded.
        dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap_err();

        // Registering again with the same version must yield a fresh, usable handle.
        let _h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "dml handler registers with same version but different schema")]
    fn test_bad_schema() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;

        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];

        let _h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Should panic as the schema is different.
        let _h = dml_manager
            .register_reader(table_id, table_version_id, &other_column_descs)
            .unwrap();
    }
}