risingwave_dml/
dml_manager.rs1use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::collections::hash_map::Entry;
18use std::sync::{Arc, Weak};
19
20use parking_lot::RwLock;
21use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
23use risingwave_common::util::worker_util::WorkerNodeId;
24
25use crate::error::{DmlError, Result};
26use crate::{TableDmlHandle, TableDmlHandleRef};
27
/// Shared, reference-counted [`DmlManager`].
pub type DmlManagerRef = Arc<DmlManager>;
29
/// Registry entry for one table: the schema version a DML handle was created for, plus a weak
/// reference to that handle.
#[derive(Debug)]
pub struct TableReader {
    /// Table version the handle was registered with.
    version_id: TableVersionId,
    /// Weak reference to the shared handle. The registry does not keep the handle alive: the
    /// strong references are the ones returned by `DmlManager::register_reader`, and the entry
    /// becomes stale once all of them are dropped.
    pub handle: Weak<TableDmlHandle>,
}
35
/// Registry of per-table DML handles, keyed by table id, together with the generator used to
/// hand out transaction ids.
///
/// Readers register a handle with [`DmlManager::register_reader`]; DML writers look it up with
/// [`DmlManager::table_dml_handle`].
#[derive(Debug)]
pub struct DmlManager {
    /// Latest registered reader per table. Lookups take the read lock; registration and
    /// `clear` take the write lock.
    table_readers: RwLock<HashMap<TableId, TableReader>>,
    /// Source of the ids returned by `gen_txn_id`.
    txn_id_generator: TxnIdGenerator,
    /// Passed to every `TableDmlHandle` created by `register_reader`.
    /// NOTE(review): presumably the initial capacity of the DML channel — confirm in
    /// `TableDmlHandle::new`.
    dml_channel_initial_permits: usize,
}
47
48impl DmlManager {
49 pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
50 Self {
51 table_readers: RwLock::new(HashMap::new()),
52 txn_id_generator: TxnIdGenerator::new(worker_node_id),
53 dml_channel_initial_permits,
54 }
55 }
56
57 pub fn for_test() -> Self {
58 const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
59 Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
60 }
61
62 pub fn register_reader(
65 &self,
66 table_id: TableId,
67 table_version_id: TableVersionId,
68 column_descs: &[ColumnDesc],
69 ) -> Result<TableDmlHandleRef> {
70 let mut table_readers = self.table_readers.write();
71 table_readers.retain(|_, r| r.handle.strong_count() > 0);
73
74 macro_rules! new_handle {
75 ($entry:ident) => {{
76 let handle = Arc::new(TableDmlHandle::new(
77 column_descs.to_vec(),
78 self.dml_channel_initial_permits,
79 ));
80 $entry.insert(TableReader {
81 version_id: table_version_id,
82 handle: Arc::downgrade(&handle),
83 });
84 handle
85 }};
86 }
87
88 let handle = match table_readers.entry(table_id) {
89 Entry::Vacant(v) => new_handle!(v),
92
93 Entry::Occupied(mut o) => {
94 let TableReader { version_id, handle } = o.get();
95
96 match table_version_id.cmp(version_id) {
97 Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
100
101 Ordering::Equal => handle
105 .upgrade()
106 .inspect(|handle| {
107 assert_eq!(
108 handle.column_descs(),
109 column_descs,
110 "dml handler registers with same version but different schema"
111 )
112 })
113 .expect("the first dml executor is gone"), Ordering::Greater => new_handle!(o),
117 }
118 }
119 };
120
121 Ok(handle)
122 }
123
124 pub fn table_dml_handle(
125 &self,
126 table_id: TableId,
127 table_version_id: TableVersionId,
128 ) -> Result<TableDmlHandleRef> {
129 let table_dml_handle = {
130 let table_readers = self.table_readers.read();
131
132 match table_readers.get(&table_id) {
133 Some(TableReader { version_id, handle }) => {
134 match table_version_id.cmp(version_id) {
135 Ordering::Less => {
138 return Err(DmlError::SchemaChanged);
139 }
140
141 Ordering::Equal => handle.upgrade(),
143
144 Ordering::Greater => {
147 unreachable!("table version `{table_version_id} not registered")
148 }
149 }
150 }
151 None => None,
152 }
153 }
154 .ok_or(DmlError::NoReader)?;
155
156 Ok(table_dml_handle)
157 }
158
159 pub fn clear(&self) {
160 self.table_readers.write().clear()
161 }
162
163 pub fn gen_txn_id(&self) -> TxnId {
164 self.txn_id_generator.gen_txn_id()
165 }
166}
167
#[cfg(test)]
mod tests {
    use risingwave_common::array::StreamChunk;
    use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::DataType;

    use super::*;

    const TEST_TRANSACTION_ID: TxnId = 0;
    const TEST_SESSION_ID: u32 = 0;

    /// Registering the same table version twice shares one handle, and dropping readers makes
    /// writes through an existing write handle fail.
    #[tokio::test]
    async fn test_register_and_drop() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;
        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let chunk = || StreamChunk::from_pretty("F\n+ 1");

        let h1 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();
        let h2 = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        // Same table and version: both registrations return the same handle.
        assert!(Arc::ptr_eq(&h1, &h2));

        // Two stream readers on the shared handle.
        let r1 = h1.stream_reader();
        let r2 = h2.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, table_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Writing succeeds while both readers are alive.
        write_handle.write_chunk(chunk()).await.unwrap();

        // After `r1` is dropped, this write handle can no longer write.
        // NOTE(review): presumably this write handle was routed to `r1`'s channel — confirm in
        // `TableDmlHandle::write_handle`.
        drop(r1);

        write_handle.write_chunk(chunk()).await.unwrap_err();

        // A fresh write handle on the same table handle can write again.
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(chunk()).await.unwrap();

        // Once the last reader is dropped, writes fail.
        drop(r2);
        write_handle.write_chunk(chunk()).await.unwrap_err();
    }

    /// Registering a newer table version redirects lookups to the new handle without breaking a
    /// write handle obtained for the old version.
    #[tokio::test]
    async fn test_versioned() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);

        let old_version_id = INITIAL_TABLE_VERSION_ID;
        let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let old_chunk = || StreamChunk::from_pretty("F\n+ 1");

        let new_version_id = old_version_id + 1;
        let new_column_descs = vec![
            ColumnDesc::unnamed(100.into(), DataType::Float64),
            ColumnDesc::unnamed(101.into(), DataType::Float64),
        ];
        let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");

        // Register a reader for the old version; `_old_r` keeps it alive for the whole test.
        let old_h = dml_manager
            .register_reader(table_id, old_version_id, &old_column_descs)
            .unwrap();
        let _old_r = old_h.stream_reader();

        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();

        // Write with the old schema.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Register a reader for the new version; this replaces the registry entry.
        let new_h = dml_manager
            .register_reader(table_id, new_version_id, &new_column_descs)
            .unwrap();
        let _new_r = new_h.stream_reader();

        // The write handle obtained for the old version still accepts old-schema chunks.
        write_handle.write_chunk(old_chunk()).await.unwrap();

        // Looking up the old version now fails: it is older than the registered one.
        dml_manager
            .table_dml_handle(table_id, old_version_id)
            .unwrap_err();

        // Looking up the new version succeeds and accepts new-schema chunks.
        let table_dml_handle = dml_manager
            .table_dml_handle(table_id, new_version_id)
            .unwrap();
        let mut write_handle = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
            .unwrap();
        write_handle.begin().unwrap();
        write_handle.write_chunk(new_chunk()).await.unwrap();
    }

    /// Registering the same version again with a different schema panics.
    #[test]
    #[should_panic]
    fn test_bad_schema() {
        let dml_manager = DmlManager::for_test();
        let table_id = TableId::new(1);
        let table_version_id = INITIAL_TABLE_VERSION_ID;

        let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
        let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];

        // `_h` is a named binding (not `_`), so the first handle stays alive and the second
        // registration hits the same-version path.
        let _h = dml_manager
            .register_reader(table_id, table_version_id, &column_descs)
            .unwrap();

        let _h = dml_manager
            .register_reader(table_id, table_version_id, &other_column_descs)
            .unwrap();
    }
}