// risingwave_batch/executor/fast_insert.rs
use std::sync::Arc;
16
17use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk};
18use risingwave_common::catalog::{TableId, TableVersionId};
19use risingwave_common::transaction::transaction_id::TxnId;
20use risingwave_dml::dml_manager::DmlManagerRef;
21use risingwave_pb::task_service::FastInsertRequest;
22
23use crate::error::{BatchError, Result};
24
/// Executor for the "fast insert" path: writes a single pre-decoded data chunk
/// into a table through the DML manager, bypassing the regular batch pipeline.
pub struct FastInsertExecutor {
    /// Target table of the insert.
    table_id: TableId,
    /// Schema version of the target table to write against.
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,

    /// Column index of the implicit row-id column, if the table has one.
    /// When present, `do_execute` inserts a NULL serial column at this index.
    row_id_index: Option<usize>,
    /// Transaction id allocated from the DML manager in `new`.
    txn_id: TxnId,
    /// Request id forwarded to the table's write handle.
    request_id: u32,
}
36
37impl FastInsertExecutor {
38 pub fn build(
39 dml_manager: DmlManagerRef,
40 insert_req: FastInsertRequest,
41 ) -> Result<(FastInsertExecutor, DataChunk)> {
42 let table_id = insert_req.table_id;
43 let data_chunk_pb = insert_req
44 .data_chunk
45 .expect("no data_chunk found in fast insert node");
46 let data_chunk = DataChunk::from_protobuf(&data_chunk_pb)?;
47
48 Ok((
49 FastInsertExecutor::new(
50 table_id,
51 insert_req.table_version_id,
52 dml_manager,
53 insert_req.row_id_index.as_ref().map(|index| *index as _),
54 insert_req.request_id,
55 ),
56 data_chunk,
57 ))
58 }
59
60 #[allow(clippy::too_many_arguments)]
61 pub fn new(
62 table_id: TableId,
63 table_version_id: TableVersionId,
64 dml_manager: DmlManagerRef,
65 row_id_index: Option<usize>,
66 request_id: u32,
67 ) -> Self {
68 let txn_id = dml_manager.gen_txn_id();
69 Self {
70 table_id,
71 table_version_id,
72 dml_manager,
73 row_id_index,
74 txn_id,
75 request_id,
76 }
77 }
78}
79
80impl FastInsertExecutor {
81 pub async fn do_execute(
82 self,
83 data_chunk_to_insert: DataChunk,
84 wait_for_persistence: bool,
85 ) -> Result<()> {
86 let table_dml_handle = self
87 .dml_manager
88 .table_dml_handle(self.table_id, self.table_version_id)?;
89 let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?;
91
92 write_handle.begin()?;
93
94 let write_txn_data = |chunk: DataChunk| async {
97 let cap = chunk.capacity();
98 let (mut columns, vis) = chunk.into_parts();
99
100 if let Some(row_id_index) = self.row_id_index {
103 let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
104 columns.insert(row_id_index, Arc::new(row_id_col.into()))
105 }
106
107 let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
108
109 #[cfg(debug_assertions)]
110 table_dml_handle.check_chunk_schema(&stream_chunk);
111
112 write_handle.write_chunk(stream_chunk).await?;
113
114 Result::Ok(())
115 };
116 write_txn_data(data_chunk_to_insert).await?;
117 if wait_for_persistence {
118 write_handle
119 .end_wait_persistence()
120 .map_err(BatchError::from)?
121 .await
122 .map_err(BatchError::from)
123 } else {
124 write_handle.end().await.map_err(Into::into)
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, JsonbArrayBuilder};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, Schema,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, JsonbVal};
    use risingwave_dml::dml_manager::DmlManager;
    use serde_json::json;

    use super::*;
    use crate::risingwave_common::array::ArrayBuilder;
    use crate::risingwave_common::types::Scalar;
    use crate::*;

    /// End-to-end test: runs a `FastInsertExecutor` against a test `DmlManager`
    /// and checks the transaction messages observed on the table's DML stream
    /// reader (`Begin` → `Data` → `End`).
    #[tokio::test]
    async fn test_fast_insert() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        // Table schema: one JSONB payload column plus a trailing serial
        // row-id column at index 1.
        let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
        schema.fields.push(Field::unnamed(DataType::Serial));
        let row_id_index = Some(1);

        // Build a one-row chunk whose single value is `{"data": "value1"}`.
        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);

        let mut header_map = HashMap::new();
        header_map.insert("data".to_owned(), "value1".to_owned());

        let json_value = json!(header_map);
        let jsonb_val = JsonbVal::from(json_value);
        builder.append(Some(jsonb_val.as_scalar_ref()));

        let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1);

        let table_id = TableId::new(0);

        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // Register the reader before spawning the writer so every message the
        // executor produces is observable below.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        let insert_executor = Box::new(FastInsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            row_id_index,
            0,
        ));
        // Run the insert on a separate task: with `wait_for_persistence = true`
        // it blocks until the persistence notifier below is triggered.
        let handle = tokio::spawn(async move {
            insert_executor.do_execute(data_chunk, true).await.unwrap();
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            // Two columns: the JSONB payload plus the spliced-in row-id column.
            assert_eq!(chunk.columns().len(), 2);
            let array = chunk.columns()[0].as_jsonb().iter().collect::<Vec<_>>();
            assert_eq!(JsonbVal::from(array[0].unwrap()), jsonb_val);
        });

        // Ack persistence so the spawned `do_execute` call can return.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            persistence_notifier.send(()).unwrap();
        });

        handle.await.unwrap();

        Ok(())
    }
}