risingwave_dml/
dml_manager.rsuse std::cmp::Ordering;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
use parking_lot::RwLock;
use risingwave_common::catalog::{ColumnDesc, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::{TxnId, TxnIdGenerator};
use risingwave_common::util::worker_util::WorkerNodeId;
use crate::error::{DmlError, Result};
use crate::{TableDmlHandle, TableDmlHandleRef};
/// Shared, reference-counted pointer to a [`DmlManager`].
pub type DmlManagerRef = Arc<DmlManager>;
/// Registry entry for one table: the table version its DML handle was created for, plus a
/// weak reference to that handle.
#[derive(Debug)]
pub struct TableReader {
// Table version the handle was registered with; compared against incoming versions in
// `DmlManager::register_reader` and `DmlManager::table_dml_handle`.
version_id: TableVersionId,
// Weak so that the registry itself does not keep the handle alive: the strong `Arc` is
// handed back to the caller of `DmlManager::register_reader`.
pub handle: Weak<TableDmlHandle>,
}
/// Tracks the DML handle registered for each table and generates transaction ids.
#[derive(Debug)]
pub struct DmlManager {
// Registered readers keyed by table id. Entries whose handle has no strong references left
// are pruned lazily, on the next `register_reader` call.
pub table_readers: RwLock<HashMap<TableId, TableReader>>,
// Transaction id generator, seeded with the worker node id passed to `new`.
txn_id_generator: TxnIdGenerator,
// Initial permits passed to every `TableDmlHandle` created by this manager.
dml_channel_initial_permits: usize,
}
impl DmlManager {
    /// Creates a manager with no registered readers.
    ///
    /// `worker_node_id` seeds the transaction id generator, and
    /// `dml_channel_initial_permits` is passed to every [`TableDmlHandle`] this manager creates.
    pub fn new(worker_node_id: WorkerNodeId, dml_channel_initial_permits: usize) -> Self {
        Self {
            table_readers: RwLock::new(HashMap::new()),
            txn_id_generator: TxnIdGenerator::new(worker_node_id),
            dml_channel_initial_permits,
        }
    }

    /// Creates a manager for tests, using the default worker node id and a fixed number of
    /// initial DML channel permits.
    pub fn for_test() -> Self {
        const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;
        Self::new(WorkerNodeId::default(), TEST_DML_CHANNEL_INIT_PERMITS)
    }

    /// Registers a reader for `table_id` at `table_version_id` and returns the DML handle
    /// shared by all readers of that table version.
    ///
    /// - No live handle registered for the table: a new handle is created and registered.
    /// - A live handle with the same version: that handle is returned, so
    ///   repeated registrations at one version share a single handle.
    /// - A newer version than the registered one: a new handle replaces the old entry.
    ///
    /// The registry keeps only a weak reference; the returned `Arc` is what keeps the handle
    /// alive.
    ///
    /// # Panics
    ///
    /// - If `table_version_id` is older than the registered version.
    /// - If a live handle with the same version was registered with different `column_descs`.
    pub fn register_reader(
        &self,
        table_id: TableId,
        table_version_id: TableVersionId,
        column_descs: &[ColumnDesc],
    ) -> Result<TableDmlHandleRef> {
        let mut table_readers = self.table_readers.write();
        // Prune entries whose handle has been dropped by every reader.
        table_readers.retain(|_, r| r.handle.strong_count() > 0);

        // Creates a fresh handle, stores a weak reference to it in `$entry`, and evaluates to
        // the strong reference. A macro (rather than a closure) lets the same code insert into
        // both a `VacantEntry` and an `OccupiedEntry`.
        macro_rules! new_handle {
            ($entry:ident) => {{
                let handle = Arc::new(TableDmlHandle::new(
                    column_descs.to_vec(),
                    self.dml_channel_initial_permits,
                ));
                $entry.insert(TableReader {
                    version_id: table_version_id,
                    handle: Arc::downgrade(&handle),
                });
                handle
            }};
        }

        let handle = match table_readers.entry(table_id) {
            Entry::Vacant(v) => new_handle!(v),
            Entry::Occupied(mut o) => {
                let TableReader { version_id, handle } = o.get();
                match table_version_id.cmp(version_id) {
                    Ordering::Less => unreachable!("table version `{table_version_id}` expired"),
                    Ordering::Equal => match handle.upgrade() {
                        Some(handle) => {
                            assert_eq!(
                                handle.column_descs(),
                                column_descs,
                                "dml handler registers with same version but different schema"
                            );
                            handle
                        }
                        // The strong `Arc`s are owned by readers, not guarded by this lock, so
                        // the last one may have been dropped after the `retain` above ran.
                        // Treat it exactly like an unregistered table instead of panicking.
                        None => new_handle!(o),
                    },
                    Ordering::Greater => new_handle!(o),
                }
            }
        };
        Ok(handle)
    }

    /// Returns the live DML handle registered for `table_id` at `table_version_id`.
    ///
    /// # Errors
    ///
    /// - [`DmlError::SchemaChanged`] if `table_version_id` is older than the registered version.
    /// - [`DmlError::NoReader`] if nothing is registered for the table, or the registered
    ///   handle has already been dropped by all of its readers.
    ///
    /// # Panics
    ///
    /// If `table_version_id` is newer than the registered version.
    pub fn table_dml_handle(
        &self,
        table_id: TableId,
        table_version_id: TableVersionId,
    ) -> Result<TableDmlHandleRef> {
        let table_readers = self.table_readers.read();
        let Some(TableReader { version_id, handle }) = table_readers.get(&table_id) else {
            return Err(DmlError::NoReader);
        };
        match table_version_id.cmp(version_id) {
            Ordering::Less => Err(DmlError::SchemaChanged),
            Ordering::Equal => handle.upgrade().ok_or(DmlError::NoReader),
            Ordering::Greater => {
                unreachable!("table version `{table_version_id}` not registered")
            }
        }
    }

    /// Removes every registered reader entry.
    pub fn clear(&self) {
        self.table_readers.write().clear()
    }

    /// Generates a new transaction id from this manager's generator.
    pub fn gen_txn_id(&self) -> TxnId {
        self.txn_id_generator.gen_txn_id()
    }
}
#[cfg(test)]
mod tests {
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::INITIAL_TABLE_VERSION_ID;
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_common::types::DataType;
use super::*;
// Fixed session / transaction ids used for every write handle in these tests.
const TEST_TRANSACTION_ID: TxnId = 0;
const TEST_SESSION_ID: u32 = 0;
/// Two registrations at the same version must share one handle, and a write is expected to
/// fail after a stream reader is dropped.
///
/// NOTE(review): the sequence below appears to rely on the first write handle being served
/// by `r1` and the second by `r2` — confirm against `TableDmlHandle::write_handle`.
#[tokio::test]
async fn test_register_and_drop() {
let dml_manager = DmlManager::for_test();
let table_id = TableId::new(1);
let table_version_id = INITIAL_TABLE_VERSION_ID;
let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
let chunk = || StreamChunk::from_pretty("F\n+ 1");
let h1 = dml_manager
.register_reader(table_id, table_version_id, &column_descs)
.unwrap();
let h2 = dml_manager
.register_reader(table_id, table_version_id, &column_descs)
.unwrap();
// Same version => the live handle is reused (the `Ordering::Equal` branch).
assert!(Arc::ptr_eq(&h1, &h2));
let r1 = h1.stream_reader();
let r2 = h2.stream_reader();
let table_dml_handle = dml_manager
.table_dml_handle(table_id, table_version_id)
.unwrap();
let mut write_handle = table_dml_handle
.write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
.unwrap();
write_handle.begin().unwrap();
write_handle.write_chunk(chunk()).await.unwrap();
drop(r1);
// Expected to fail once `r1` is gone.
write_handle.write_chunk(chunk()).await.unwrap_err();
// A fresh write handle succeeds again while `r2` is still alive.
let mut write_handle = table_dml_handle
.write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
.unwrap();
write_handle.begin().unwrap();
write_handle.write_chunk(chunk()).await.unwrap();
drop(r2);
// Expected to fail once `r2` is gone too.
write_handle.write_chunk(chunk()).await.unwrap_err();
}
/// Registering a newer table version replaces the old entry. An existing old-version write
/// handle can still write, but looking up the old version now fails (`SchemaChanged`), while
/// the new version can be looked up and written to.
#[tokio::test]
async fn test_versioned() {
let dml_manager = DmlManager::for_test();
let table_id = TableId::new(1);
let old_version_id = INITIAL_TABLE_VERSION_ID;
let old_column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
let old_chunk = || StreamChunk::from_pretty("F\n+ 1");
let new_version_id = old_version_id + 1;
let new_column_descs = vec![
ColumnDesc::unnamed(100.into(), DataType::Float64),
ColumnDesc::unnamed(101.into(), DataType::Float64),
];
let new_chunk = || StreamChunk::from_pretty("F F\n+ 1 2");
let old_h = dml_manager
.register_reader(table_id, old_version_id, &old_column_descs)
.unwrap();
let _old_r = old_h.stream_reader();
let table_dml_handle = dml_manager
.table_dml_handle(table_id, old_version_id)
.unwrap();
let mut write_handle = table_dml_handle
.write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
.unwrap();
write_handle.begin().unwrap();
write_handle.write_chunk(old_chunk()).await.unwrap();
// Newer version => `Ordering::Greater` branch replaces the registry entry.
let new_h = dml_manager
.register_reader(table_id, new_version_id, &new_column_descs)
.unwrap();
let _new_r = new_h.stream_reader();
// The already-open old-version write handle keeps working.
write_handle.write_chunk(old_chunk()).await.unwrap();
// Looking up the old version now errors (older than the registered version).
dml_manager
.table_dml_handle(table_id, old_version_id)
.unwrap_err();
let table_dml_handle = dml_manager
.table_dml_handle(table_id, new_version_id)
.unwrap();
let mut write_handle = table_dml_handle
.write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID)
.unwrap();
write_handle.begin().unwrap();
write_handle.write_chunk(new_chunk()).await.unwrap();
}
/// Registering the same version twice with different column descs must trip the schema
/// `assert_eq!` in `register_reader`.
#[test]
#[should_panic]
fn test_bad_schema() {
let dml_manager = DmlManager::for_test();
let table_id = TableId::new(1);
let table_version_id = INITIAL_TABLE_VERSION_ID;
let column_descs = vec![ColumnDesc::unnamed(100.into(), DataType::Float64)];
let other_column_descs = vec![ColumnDesc::unnamed(101.into(), DataType::Float64)];
let _h = dml_manager
.register_reader(table_id, table_version_id, &column_descs)
.unwrap();
let _h = dml_manager
.register_reader(table_id, table_version_id, &other_column_descs)
.unwrap();
}
}