risingwave_dml/
dml_manager.rs1use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::collections::hash_map::Entry;
18use std::sync::{Arc, Weak};
19
20use parking_lot::RwLock;
21use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
23use risingwave_common::util::worker_util::WorkerNodeId;
24
25use crate::error::{DmlError, Result};
26use crate::{TableDmlHandle, TableDmlHandleRef};
27
28pub type DmlManagerRef = Arc<DmlManager>;
29
30#[derive(Debug)]
31pub struct TableReader {
32    version_id: TableVersionId,
33    pub handle: Weak<TableDmlHandle>,
34}
35
36#[derive(Debug)]
42pub struct DmlManager {
43    table_readers: RwLock<HashMap<TableId, TableReader>>,
44    txn_id_generator: TxnIdGenerator,
45    dml_channel_initial_permits: usize,
46}
47
48impl DmlManager {
49    pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
50        Self {
51            table_readers: RwLock::new(HashMap::new()),
52            txn_id_generator: TxnIdGenerator::new(worker_node_id),
53            dml_channel_initial_permits,
54        }
55    }
56
57    pub fn for_test() -> Self {
58        const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
59        Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
60    }
61
62    pub fn register_reader(
65        &self,
66        table_id: TableId,
67        table_version_id: TableVersionId,
68        column_descs: &[ColumnDesc],
69    ) -> Result<TableDmlHandleRef> {
70        let mut table_readers = self.table_readers.write();
71        table_readers.retain(|_, r| r.handle.strong_count() > 0);
73
74        macro_rules! new_handle {
75            ($entry:ident) => {{
76                let handle = Arc::new(TableDmlHandle::new(
77                    column_descs.to_vec(),
78                    self.dml_channel_initial_permits,
79                ));
80                $entry.insert(TableReader {
81                    version_id: table_version_id,
82                    handle: Arc::downgrade(&handle),
83                });
84                handle
85            }};
86        }
87
88        let handle = match table_readers.entry(table_id) {
89            Entry::Vacant(v) => new_handle!(v),
92
93            Entry::Occupied(mut o) => {
94                let TableReader { version_id, handle } = o.get();
95
96                match table_version_id.cmp(version_id) {
97                    Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
100
101                    Ordering::Equal => handle
105                        .upgrade()
106                        .inspect(|handle| {
107                            assert_eq!(
108                                handle.column_descs(),
109                                column_descs,
110                                "dml handler registers with same version but different schema"
111                            )
112                        })
113                        .expect("the first dml executor is gone"), Ordering::Greater => new_handle!(o),
117                }
118            }
119        };
120
121        Ok(handle)
122    }
123
124    pub fn table_dml_handle(
125        &self,
126        table_id: TableId,
127        table_version_id: TableVersionId,
128    ) -> Result<TableDmlHandleRef> {
129        let table_dml_handle = {
130            let table_readers = self.table_readers.read();
131
132            match table_readers.get(&table_id) {
133                Some(TableReader { version_id, handle }) => {
134                    match table_version_id.cmp(version_id) {
135                        Ordering::Less => {
138                            return Err(DmlError::SchemaChanged);
139                        }
140
141                        Ordering::Equal => handle.upgrade(),
143
144                        Ordering::Greater => {
147                            unreachable!("table version `{table_version_id} not registered")
148                        }
149                    }
150                }
151                None => None,
152            }
153        }
154        .ok_or(DmlError::NoReader)?;
155
156        Ok(table_dml_handle)
157    }
158
159    pub fn clear(&self) {
160        self.table_readers.write().clear()
161    }
162
163    pub fn gen_txn_id(&self) -> TxnId {
164        self.txn_id_generator.gen_txn_id()
165    }
166}
167
#[cfg(test)]
mod tests {
    use risingwave_common::array::StreamChunk;
    use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Registering the same table version twice yields one shared handle, and
    /// writes start failing once a stream reader has been dropped.
    #[tokio::test]
    async fn test_register_and_drop() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let chunk = || StreamChunk::from_pretty("F\n+ 1");

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        let h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Both registrations must resolve to the very same handle.
        assert!(Arc::ptr_eq(&h1, &h2));

        // Two readers on the shared handle.
        let r1 = h1.stream_reader();
        let r2 = h2.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Both readers are alive: the write succeeds.
        write_handle.write_chunk(chunk()).await.unwrap();

        drop(r1);

        // After dropping the first reader, the existing write handle is
        // expected to fail.
        write_handle.write_chunk(chunk()).await.unwrap_err();

        // A fresh write handle is expected to work while `r2` is alive.
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(chunk()).await.unwrap();

        // With the second reader gone too, the write is expected to fail again.
        drop(r2);
        write_handle.write_chunk(chunk()).await.unwrap_err();
    }

    /// Registering a newer table version replaces the old one for lookups,
    /// while a write handle obtained earlier keeps working.
    #[tokio::test]
    async fn test_versioned() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);

        let old_version_id = INITIAL_TABLE_VERSION_ID;
        let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let old_chunk = || StreamChunk::from_pretty("F\n+ 1");

        let new_version_id = old_version_id + 1;
        let new_column_descs = vec![
            ColumnDesc::unnamed(100.into(), DataType::Float64),
            ColumnDesc::unnamed(101.into(), DataType::Float64),
        ];
        let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");

        // Register the old version and keep its reader alive.
        let old_h = dml_manager
            .register_reader(table_id, old_version_id, &old_column_descs)
            .unwrap();
        let _old_r = old_h.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Writing on the old version works.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Register the new version with a wider schema.
        let new_h = dml_manager
            .register_reader(table_id, new_version_id, &new_column_descs)
            .unwrap();
        let _new_r = new_h.stream_reader();

        // The write handle obtained before the version bump still works.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Looking up the old version is now rejected.
        dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap_err();

        // Looking up and writing on the new version works.
        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, new_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(new_chunk()).await.unwrap();
    }

    /// Registering the same version with a different schema must panic with
    /// the schema-mismatch assertion, not with some unrelated failure.
    #[test]
    #[should_panic(expected = "dml handler registers with same version but different schema")]
    fn test_bad_schema() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;

        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];

        let _h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Same version, different column id: expected to hit the assertion.
        let _h = dml_manager
            .register_reader(table_id, table_version_id, &other_column_descs)
            .unwrap();
    }
}