1use std::sync::Arc;
16
17use anyhow::Context;
18use futures_async_stream::try_stream;
19use itertools::Itertools;
20use risingwave_common::array::{
21 ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
22};
23use risingwave_common::catalog::{Schema, TableId, TableVersionId};
24use risingwave_common::transaction::transaction_id::TxnId;
25use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
26use risingwave_dml::dml_manager::DmlManagerRef;
27use risingwave_expr::expr::{BoxedExpression, build_from_prost};
28use risingwave_pb::batch_plan::plan_node::NodeBody;
29use risingwave_pb::plan_common::IndexAndExpr;
30
31use crate::error::{BatchError, Result};
32use crate::executor::{
33 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
34};
35
/// Batch executor that writes the rows produced by its child into a table through the DML
/// manager, inside a single DML transaction.
pub struct InsertExecutor {
    /// Target table; used to look up the table's DML handle.
    table_id: TableId,
    /// Table version the insert was planned against; used together with `table_id` to look up
    /// the DML handle.
    table_version_id: TableVersionId,
    /// Source of the transaction id and of the table's DML handle.
    dml_manager: DmlManagerRef,
    /// Executor producing the rows to insert.
    child: BoxedExecutor,
    /// Capacity of the `DataChunkBuilder` used to re-batch the child's chunks.
    chunk_size: usize,
    /// Schema reported by `Executor::schema`.
    schema: Schema,
    /// Identity string reported by `Executor::identity`.
    identity: String,
    /// `column_indices[i]` is the table column position of the child's `i`-th output column.
    column_indices: Vec<usize>,
    /// `(table column index, default expression)` pairs for columns the child does not
    /// provide; the executor builder sorts them by index.
    sorted_default_columns: Vec<(usize, BoxedExpression)>,

    /// If set, the position at which an all-NULL `Serial` row-id column is inserted before the
    /// chunk is written.
    row_id_index: Option<usize>,
    /// Whether to yield the inserted rows (`RETURNING`) instead of a row count.
    returning: bool,
    /// Transaction id generated by `dml_manager` when the executor is created.
    txn_id: TxnId,
    /// Session on whose behalf the write handle is opened.
    session_id: u32,
    /// Whether ending the transaction should wait for persistence before completing.
    wait_for_persistence: bool,
}
55
56impl InsertExecutor {
57 #[expect(clippy::too_many_arguments)]
58 pub fn new(
59 table_id: TableId,
60 table_version_id: TableVersionId,
61 dml_manager: DmlManagerRef,
62 child: BoxedExecutor,
63 chunk_size: usize,
64 identity: String,
65 column_indices: Vec<usize>,
66 sorted_default_columns: Vec<(usize, BoxedExpression)>,
67 row_id_index: Option<usize>,
68 returning: bool,
69 session_id: u32,
70 wait_for_persistence: bool,
71 ) -> Self {
72 let table_schema = child.schema().clone();
73 let txn_id = dml_manager.gen_txn_id();
74 Self {
75 table_id,
76 table_version_id,
77 dml_manager,
78 child,
79 chunk_size,
80 schema: table_schema,
81 identity,
82 column_indices,
83 sorted_default_columns,
84 row_id_index,
85 returning,
86 txn_id,
87 session_id,
88 wait_for_persistence,
89 }
90 }
91}
92
93impl Executor for InsertExecutor {
94 fn schema(&self) -> &Schema {
95 &self.schema
96 }
97
98 fn identity(&self) -> &str {
99 &self.identity
100 }
101
102 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
103 self.do_execute()
104 }
105}
106
107impl InsertExecutor {
108 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
109 async fn do_execute(self: Box<Self>) {
110 let data_types = self.child.schema().data_types();
111 let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
112
113 let table_dml_handle = self
114 .dml_manager
115 .table_dml_handle(self.table_id, self.table_version_id)?;
116 let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;
117
118 write_handle.begin()?;
119
120 let write_txn_data = |chunk: DataChunk| async {
123 let cap = chunk.capacity();
124 let (mut columns, vis) = chunk.into_parts();
125
126 let dummy_chunk = DataChunk::new_dummy(cap);
127
128 let mut ordered_columns = self
129 .column_indices
130 .iter()
131 .enumerate()
132 .map(|(i, idx)| (*idx, columns[i].clone()))
133 .collect_vec();
134
135 ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());
136
137 for (idx, expr) in &self.sorted_default_columns {
138 let column = expr.eval(&dummy_chunk).await?;
139 ordered_columns.push((*idx, column));
140 }
141
142 ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
143 columns = ordered_columns
144 .into_iter()
145 .map(|(_, column)| column)
146 .collect_vec();
147
148 let returning_chunk = DataChunk::new(columns.clone(), vis.clone());
150
151 if let Some(row_id_index) = self.row_id_index {
154 let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
155 columns.insert(row_id_index, Arc::new(row_id_col.into()))
156 }
157
158 let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
159
160 #[cfg(debug_assertions)]
161 table_dml_handle.check_chunk_schema(&stream_chunk);
162
163 write_handle.write_chunk(stream_chunk).await?;
164
165 Result::Ok(returning_chunk)
166 };
167
168 let mut rows_inserted = 0;
169
170 #[for_await]
171 for data_chunk in self.child.execute() {
172 let data_chunk = data_chunk?;
173 for chunk in builder.append_chunk(data_chunk) {
174 let chunk = write_txn_data(chunk).await?;
175 rows_inserted += chunk.cardinality();
176 if self.returning {
177 yield chunk;
178 }
179 }
180 }
181
182 if let Some(chunk) = builder.consume_all() {
183 let chunk = write_txn_data(chunk).await?;
184 rows_inserted += chunk.cardinality();
185 if self.returning {
186 yield chunk;
187 }
188 }
189
190 if self.wait_for_persistence {
191 write_handle.end_wait_persistence()?.await?;
192 } else {
193 write_handle.end().await?;
194 }
195
196 if !self.returning {
198 let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
199 array_builder.append(Some(rows_inserted as i64));
200
201 let array = array_builder.finish();
202 let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);
203
204 yield ret_chunk
205 }
206 }
207}
208
209impl BoxedExecutorBuilder for InsertExecutor {
210 async fn new_boxed_executor(
211 source: &ExecutorBuilder<'_>,
212 inputs: Vec<BoxedExecutor>,
213 ) -> Result<BoxedExecutor> {
214 let [child]: [_; 1] = inputs.try_into().unwrap();
215
216 let insert_node = try_match_expand!(
217 source.plan_node().get_node_body().unwrap(),
218 NodeBody::Insert
219 )?;
220
221 let table_id = insert_node.table_id;
222 let column_indices = insert_node
223 .column_indices
224 .iter()
225 .map(|&i| i as usize)
226 .collect();
227 let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
228 let mut default_columns = default_columns
229 .get_default_columns()
230 .iter()
231 .cloned()
232 .map(|IndexAndExpr { index: i, expr: e }| {
233 Ok((
234 i as usize,
235 build_from_prost(&e.context("expression is None")?)
236 .context("failed to build expression")?,
237 ))
238 })
239 .collect::<Result<Vec<_>>>()?;
240 default_columns.sort_unstable_by_key(|(i, _)| *i);
241 default_columns
242 } else {
243 vec![]
244 };
245
246 Ok(Box::new(Self::new(
247 table_id,
248 insert_node.table_version_id,
249 source.context().dml_manager(),
250 child,
251 source.context().get_config().developer.chunk_size,
252 source.plan_node().get_identity().clone(),
253 column_indices,
254 sorted_default_columns,
255 insert_node.row_id_index.as_ref().map(|index| *index as _),
256 insert_node.returning,
257 insert_node.session_id,
258 insert_node.wait_for_persistence,
259 )))
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::Hint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Inserts five rows into a table that has a row-id column. The child provides two
    /// `Int32` columns plus a struct column.
    ///
    /// Checks the transaction messages seen by a registered DML reader (`Begin`, one `Data`,
    /// `End`) and the row count returned by the executor.
    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        // A struct column with three `Int32` fields.
        let struct_field = Field::unnamed(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]).into(),
        );

        // Child (mock executor) schema: two `Int32` columns followed by the struct column.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Table schema: the child's columns plus a `Serial` row-id column at index 3.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field);
        schema.fields.push(Field::unnamed(DataType::Serial));
        let row_id_index = Some(3);

        let col1 = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let col2 = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        // Only the first struct value is non-null.
        let array = StructArray::new(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        );
        let col3 = Arc::new(array.into());
        let data_chunk: DataChunk = DataChunk::new(vec![col1, col2, col3], 5);
        mock_executor.add(data_chunk.clone());

        let table_id = TableId::new(0);

        // Register a DML reader for the table so that the executor's writes can be observed.
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Child columns map 1:1 onto table columns 0..=2; there are no default columns and
        // `returning` is false, so the executor yields a row count.
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2],
            vec![],
            row_id_index,
            false,
            0,
            false,
        ));
        // Run the executor in its own task so that this task can drain the reader concurrently.
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)]
            );
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Only columns 0..=2 are checked; column 3 is the row-id column.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let array: ArrayImpl = StructArray::new(
                StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
                vec![
                    I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
                ],
                [true, false, false, false, false].into_iter().collect(),
            )
            .into();
            assert_eq!(*chunk.columns()[2], array);
        });

        // `wait_for_persistence` is false, so `End` carries no persistence notifier.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, None));
        // The executor only talks to the DML reader; the state store is expected to be empty.
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(
                full_range,
                epoch,
                None,
                ReadOptions {
                    cache_policy: CachePolicy::Fill(Hint::Normal),
                    ..Default::default()
                },
            )
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }

    /// With `wait_for_persistence` set, checks that the executor has not finished when the
    /// reader receives `End`. The executor completes after the persistence notifier carried by
    /// `End` is signalled.
    #[tokio::test]
    async fn test_insert_executor_wait_for_persistence() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());
        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 2",
        ));

        let table_id = TableId::new(0);
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // No row-id column, no defaults, no `RETURNING`; `wait_for_persistence` is true.
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1],
            vec![],
            None,
            false,
            0,
            true,
        ));
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();
            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(1)]
            );
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, _));
        // The executor must still be waiting until persistence is acknowledged.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            assert!(!handle.is_finished());
            persistence_notifier.send(()).unwrap();
        });

        handle.await.unwrap();

        Ok(())
    }
}