1use std::sync::Arc;
16
17use anyhow::Context;
18use futures_async_stream::try_stream;
19use itertools::Itertools;
20use risingwave_common::array::{
21 ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
22};
23use risingwave_common::catalog::{Schema, TableId, TableVersionId};
24use risingwave_common::transaction::transaction_id::TxnId;
25use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
26use risingwave_dml::dml_manager::DmlManagerRef;
27use risingwave_expr::expr::{BoxedExpression, build_from_prost};
28use risingwave_pb::batch_plan::plan_node::NodeBody;
29use risingwave_pb::plan_common::IndexAndExpr;
30
31use crate::error::{BatchError, Result};
32use crate::executor::{
33 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
34};
35
36pub struct InsertExecutor {
38 table_id: TableId,
40 table_version_id: TableVersionId,
41 dml_manager: DmlManagerRef,
42 child: BoxedExecutor,
43 #[expect(dead_code)]
44 chunk_size: usize,
45 schema: Schema,
46 identity: String,
47 column_indices: Vec<usize>,
48 sorted_default_columns: Vec<(usize, BoxedExpression)>,
49
50 row_id_index: Option<usize>,
51 returning: bool,
52 txn_id: TxnId,
53 session_id: u32,
54}
55
56impl InsertExecutor {
57 #[allow(clippy::too_many_arguments)]
58 pub fn new(
59 table_id: TableId,
60 table_version_id: TableVersionId,
61 dml_manager: DmlManagerRef,
62 child: BoxedExecutor,
63 chunk_size: usize,
64 identity: String,
65 column_indices: Vec<usize>,
66 sorted_default_columns: Vec<(usize, BoxedExpression)>,
67 row_id_index: Option<usize>,
68 returning: bool,
69 session_id: u32,
70 ) -> Self {
71 let table_schema = child.schema().clone();
72 let txn_id = dml_manager.gen_txn_id();
73 Self {
74 table_id,
75 table_version_id,
76 dml_manager,
77 child,
78 chunk_size,
79 schema: table_schema,
80 identity,
81 column_indices,
82 sorted_default_columns,
83 row_id_index,
84 returning,
85 txn_id,
86 session_id,
87 }
88 }
89}
90
91impl Executor for InsertExecutor {
92 fn schema(&self) -> &Schema {
93 &self.schema
94 }
95
96 fn identity(&self) -> &str {
97 &self.identity
98 }
99
100 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
101 self.do_execute()
102 }
103}
104
105impl InsertExecutor {
106 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
107 async fn do_execute(self: Box<Self>) {
108 let data_types = self.child.schema().data_types();
109 let mut builder = DataChunkBuilder::new(data_types, 1024);
110
111 let table_dml_handle = self
112 .dml_manager
113 .table_dml_handle(self.table_id, self.table_version_id)?;
114 let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;
115
116 write_handle.begin()?;
117
118 let write_txn_data = |chunk: DataChunk| async {
121 let cap = chunk.capacity();
122 let (mut columns, vis) = chunk.into_parts();
123
124 let dummy_chunk = DataChunk::new_dummy(cap);
125
126 let mut ordered_columns = self
127 .column_indices
128 .iter()
129 .enumerate()
130 .map(|(i, idx)| (*idx, columns[i].clone()))
131 .collect_vec();
132
133 ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());
134
135 for (idx, expr) in &self.sorted_default_columns {
136 let column = expr.eval(&dummy_chunk).await?;
137 ordered_columns.push((*idx, column));
138 }
139
140 ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
141 columns = ordered_columns
142 .into_iter()
143 .map(|(_, column)| column)
144 .collect_vec();
145
146 let returning_chunk = DataChunk::new(columns.clone(), vis.clone());
148
149 if let Some(row_id_index) = self.row_id_index {
152 let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
153 columns.insert(row_id_index, Arc::new(row_id_col.into()))
154 }
155
156 let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
157
158 #[cfg(debug_assertions)]
159 table_dml_handle.check_chunk_schema(&stream_chunk);
160
161 write_handle.write_chunk(stream_chunk).await?;
162
163 Result::Ok(returning_chunk)
164 };
165
166 let mut rows_inserted = 0;
167
168 #[for_await]
169 for data_chunk in self.child.execute() {
170 let data_chunk = data_chunk?;
171 for chunk in builder.append_chunk(data_chunk) {
172 let chunk = write_txn_data(chunk).await?;
173 rows_inserted += chunk.cardinality();
174 if self.returning {
175 yield chunk;
176 }
177 }
178 }
179
180 if let Some(chunk) = builder.consume_all() {
181 let chunk = write_txn_data(chunk).await?;
182 rows_inserted += chunk.cardinality();
183 if self.returning {
184 yield chunk;
185 }
186 }
187
188 write_handle.end().await?;
189
190 if !self.returning {
192 let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
193 array_builder.append(Some(rows_inserted as i64));
194
195 let array = array_builder.finish();
196 let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);
197
198 yield ret_chunk
199 }
200 }
201}
202
203impl BoxedExecutorBuilder for InsertExecutor {
204 async fn new_boxed_executor(
205 source: &ExecutorBuilder<'_>,
206 inputs: Vec<BoxedExecutor>,
207 ) -> Result<BoxedExecutor> {
208 let [child]: [_; 1] = inputs.try_into().unwrap();
209
210 let insert_node = try_match_expand!(
211 source.plan_node().get_node_body().unwrap(),
212 NodeBody::Insert
213 )?;
214
215 let table_id = TableId::new(insert_node.table_id);
216 let column_indices = insert_node
217 .column_indices
218 .iter()
219 .map(|&i| i as usize)
220 .collect();
221 let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
222 let mut default_columns = default_columns
223 .get_default_columns()
224 .iter()
225 .cloned()
226 .map(|IndexAndExpr { index: i, expr: e }| {
227 Ok((
228 i as usize,
229 build_from_prost(&e.context("expression is None")?)
230 .context("failed to build expression")?,
231 ))
232 })
233 .collect::<Result<Vec<_>>>()?;
234 default_columns.sort_unstable_by_key(|(i, _)| *i);
235 default_columns
236 } else {
237 vec![]
238 };
239
240 Ok(Box::new(Self::new(
241 table_id,
242 insert_node.table_version_id,
243 source.context().dml_manager(),
244 child,
245 source.context().get_config().developer.chunk_size,
246 source.plan_node().get_identity().clone(),
247 column_indices,
248 sorted_default_columns,
249 insert_node.row_id_index.as_ref().map(|index| *index as _),
250 insert_node.returning,
251 insert_node.session_id,
252 )))
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::CacheHint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;
    use risingwave_storage::store::ReadOptions;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Inserts one 5-row chunk through `InsertExecutor` (without RETURNING).
    /// It then checks the row count the executor yields, and checks the Begin / Data / End
    /// messages seen by a DML reader registered on the table.
    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        // A struct column with three Int32 fields.
        let struct_field = Field::unnamed(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]).into(),
        );

        // Input schema: the `ii()` fields (presumably two Int32 columns) plus the struct.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Table schema: the input columns plus a trailing `Serial` row-id column at index 3.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field);
        schema.fields.push(Field::unnamed(DataType::Serial));
        let row_id_index = Some(3);

        let col1 = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let col2 = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        // Only the first row has a non-null struct value.
        let array = StructArray::new(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        );
        let col3 = Arc::new(array.into());
        let data_chunk: DataChunk = DataChunk::new(vec![col1, col2, col3], 5);
        mock_executor.add(data_chunk.clone());

        let table_id = TableId::new(0);

        // Register a reader on the table so the executor's writes can be observed.
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Child columns 0, 1, 2 map to table columns 0, 1, 2; no default columns;
        // `returning = false`; session id 0.
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2],
            vec![],
            row_id_index,
            false,
            0,
        ));
        // Run the executor on its own task; without RETURNING it yields only the row count.
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // all five input rows were inserted
            );
        });

        // The reader must see Begin, then the data chunk, then End.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        // Only the first three columns are checked; the row-id column (index 3) is not.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let array: ArrayImpl = StructArray::new(
                StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
                vec![
                    I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
                ],
                [true, false, false, false, false].into_iter().collect(),
            )
            .into();
            assert_eq!(*chunk.columns()[2], array);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));
        // NOTE(review): `store` is never handed to anything in this test, so this scan only
        // confirms the fresh in-memory store is still empty.
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(
                full_range,
                epoch,
                None,
                ReadOptions {
                    cache_policy: CachePolicy::Fill(CacheHint::Normal),
                    ..Default::default()
                },
            )
            .await?;
        assert!(store_content.is_empty());

        // Propagate any assertion failure from the executor task.
        handle.await.unwrap();

        Ok(())
    }
}