risingwave_dml/
dml_manager.rs1use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::collections::hash_map::Entry;
18use std::sync::{Arc, Weak};
19
20use parking_lot::RwLock;
21use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
23use risingwave_common::util::worker_util::WorkerNodeId;
24
25use crate::error::{DmlError, Result};
26use crate::{TableDmlHandle, TableDmlHandleRef};
27
28pub type DmlManagerRef = Arc<DmlManager>;
29
30#[derive(Debug)]
31pub struct TableReader {
32 version_id: TableVersionId,
33 pub handle: Weak<TableDmlHandle>,
34}
35
36#[derive(Debug)]
42pub struct DmlManager {
43 table_readers: RwLock<HashMap<TableId, TableReader>>,
44 txn_id_generator: TxnIdGenerator,
45 dml_channel_initial_permits: usize,
46}
47
48impl DmlManager {
49 pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
50 Self {
51 table_readers: RwLock::new(HashMap::new()),
52 txn_id_generator: TxnIdGenerator::new(worker_node_id),
53 dml_channel_initial_permits,
54 }
55 }
56
57 pub fn for_test() -> Self {
58 const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
59 Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
60 }
61
62 pub fn register_reader(
65 &self,
66 table_id: TableId,
67 table_version_id: TableVersionId,
68 column_descs: &[ColumnDesc],
69 ) -> Result<TableDmlHandleRef> {
70 let mut table_readers = self.table_readers.write();
71 table_readers.retain(|_, r| r.handle.strong_count() > 0);
73
74 macro_rules! new_handle {
75 ($entry:ident) => {{
76 let handle = Arc::new(TableDmlHandle::new(
77 column_descs.to_vec(),
78 self.dml_channel_initial_permits,
79 ));
80 $entry.insert(TableReader {
81 version_id: table_version_id,
82 handle: Arc::downgrade(&handle),
83 });
84 handle
85 }};
86 }
87
88 let handle = match table_readers.entry(table_id) {
89 Entry::Vacant(v) => new_handle!(v),
92
93 Entry::Occupied(mut o) => {
94 let TableReader { version_id, handle } = o.get();
95
96 match table_version_id.cmp(version_id) {
97 Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
100
101 Ordering::Equal => {
103 if let Some(handle) = handle.upgrade() {
104 assert_eq!(
108 handle.column_descs(),
109 column_descs,
110 "dml handler registers with same version but different schema"
111 );
112 handle
113 } else {
114 new_handle!(o)
120 }
121 }
122
123 Ordering::Greater => new_handle!(o),
125 }
126 }
127 };
128
129 Ok(handle)
130 }
131
132 pub fn table_dml_handle(
133 &self,
134 table_id: TableId,
135 table_version_id: TableVersionId,
136 ) -> Result<TableDmlHandleRef> {
137 let table_dml_handle = {
138 let table_readers = self.table_readers.read();
139
140 match table_readers.get(&table_id) {
141 Some(TableReader { version_id, handle }) => {
142 match table_version_id.cmp(version_id) {
143 Ordering::Less => {
146 return Err(DmlError::SchemaChanged);
147 }
148
149 Ordering::Equal => handle.upgrade(),
151
152 Ordering::Greater => {
155 unreachable!("table version `{table_version_id} not registered")
156 }
157 }
158 }
159 None => None,
160 }
161 }
162 .ok_or(DmlError::NoReader)?;
163
164 Ok(table_dml_handle)
165 }
166
167 pub fn clear(&self) {
168 self.table_readers.write().clear()
169 }
170
171 pub fn gen_txn_id(&self) -> TxnId {
172 self.txn_id_generator.gen_txn_id()
173 }
174}
175
#[cfg(test)]
mod tests {
    use risingwave_common::array::StreamChunk;
    use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Registering the same table version twice yields one shared handle, and writes fail once
    /// stream readers of that handle are dropped.
    #[tokio::test]
    async fn test_register_and_drop() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let chunk = || StreamChunk::from_pretty("F\n+ 1");

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        let h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Same table and version: the handle is shared.
        assert!(Arc::ptr_eq(&h1, &h2));

        // Two stream readers on the shared handle.
        let r1 = h1.stream_reader();
        let r2 = h2.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        write_handle.write_chunk(chunk()).await.unwrap();

        // After dropping `r1`, the existing write handle can no longer write.
        drop(r1);

        write_handle.write_chunk(chunk()).await.unwrap_err();

        // A fresh write handle can still write while `r2` is alive.
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(chunk()).await.unwrap();

        // With the last reader gone, writes fail again.
        drop(r2);
        write_handle.write_chunk(chunk()).await.unwrap_err();
    }

    /// Registering a newer table version keeps existing writers working but rejects new
    /// lookups for the old version.
    #[tokio::test]
    async fn test_versioned() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);

        let old_version_id = INITIAL_TABLE_VERSION_ID;
        let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let old_chunk = || StreamChunk::from_pretty("F\n+ 1");

        let new_version_id = old_version_id + 1;
        let new_column_descs = vec![
            ColumnDesc::unnamed(100.into(), DataType::Float64),
            ColumnDesc::unnamed(101.into(), DataType::Float64),
        ];
        let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");

        // Register and write with the old version.
        let old_h = dml_manager
            .register_reader(table_id, old_version_id, &old_column_descs)
            .unwrap();
        let _old_r = old_h.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Register the new version.
        let new_h = dml_manager
            .register_reader(table_id, new_version_id, &new_column_descs)
            .unwrap();
        let _new_r = new_h.stream_reader();

        // The existing old-version writer is unaffected.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // New lookups for the old version are rejected.
        dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap_err();

        // The new version is usable.
        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, new_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(new_chunk()).await.unwrap();
    }

    /// Registering the same version with a different schema must trip the schema assertion
    /// (and not some other panic).
    #[test]
    #[should_panic(expected = "dml handler registers with same version but different schema")]
    fn test_bad_schema() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;

        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];

        let _h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // `_h` is still alive, so this hits the same-version schema check.
        let _h = dml_manager
            .register_reader(table_id, table_version_id, &other_column_descs)
            .unwrap();
    }
}