1use std::sync::Arc;
16
17use anyhow::Context;
18use futures_async_stream::try_stream;
19use itertools::Itertools;
20use risingwave_common::array::{
21 ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
22};
23use risingwave_common::catalog::{Schema, TableId, TableVersionId};
24use risingwave_common::transaction::transaction_id::TxnId;
25use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
26use risingwave_dml::dml_manager::DmlManagerRef;
27use risingwave_expr::expr::{BoxedExpression, build_from_prost};
28use risingwave_pb::batch_plan::plan_node::NodeBody;
29use risingwave_pb::plan_common::IndexAndExpr;
30
31use crate::error::{BatchError, Result};
32use crate::executor::{
33 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
34};
35
36pub struct InsertExecutor {
38 table_id: TableId,
40 table_version_id: TableVersionId,
41 dml_manager: DmlManagerRef,
42 child: BoxedExecutor,
43 chunk_size: usize,
44 schema: Schema,
45 identity: String,
46 column_indices: Vec<usize>,
47 sorted_default_columns: Vec<(usize, BoxedExpression)>,
48
49 row_id_index: Option<usize>,
50 returning: bool,
51 txn_id: TxnId,
52 session_id: u32,
53}
54
55impl InsertExecutor {
56 #[allow(clippy::too_many_arguments)]
57 pub fn new(
58 table_id: TableId,
59 table_version_id: TableVersionId,
60 dml_manager: DmlManagerRef,
61 child: BoxedExecutor,
62 chunk_size: usize,
63 identity: String,
64 column_indices: Vec<usize>,
65 sorted_default_columns: Vec<(usize, BoxedExpression)>,
66 row_id_index: Option<usize>,
67 returning: bool,
68 session_id: u32,
69 ) -> Self {
70 let table_schema = child.schema().clone();
71 let txn_id = dml_manager.gen_txn_id();
72 Self {
73 table_id,
74 table_version_id,
75 dml_manager,
76 child,
77 chunk_size,
78 schema: table_schema,
79 identity,
80 column_indices,
81 sorted_default_columns,
82 row_id_index,
83 returning,
84 txn_id,
85 session_id,
86 }
87 }
88}
89
90impl Executor for InsertExecutor {
91 fn schema(&self) -> &Schema {
92 &self.schema
93 }
94
95 fn identity(&self) -> &str {
96 &self.identity
97 }
98
99 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
100 self.do_execute()
101 }
102}
103
104impl InsertExecutor {
105 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
106 async fn do_execute(self: Box<Self>) {
107 let data_types = self.child.schema().data_types();
108 let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
109
110 let table_dml_handle = self
111 .dml_manager
112 .table_dml_handle(self.table_id, self.table_version_id)?;
113 let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;
114
115 write_handle.begin()?;
116
117 let write_txn_data = |chunk: DataChunk| async {
120 let cap = chunk.capacity();
121 let (mut columns, vis) = chunk.into_parts();
122
123 let dummy_chunk = DataChunk::new_dummy(cap);
124
125 let mut ordered_columns = self
126 .column_indices
127 .iter()
128 .enumerate()
129 .map(|(i, idx)| (*idx, columns[i].clone()))
130 .collect_vec();
131
132 ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());
133
134 for (idx, expr) in &self.sorted_default_columns {
135 let column = expr.eval(&dummy_chunk).await?;
136 ordered_columns.push((*idx, column));
137 }
138
139 ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
140 columns = ordered_columns
141 .into_iter()
142 .map(|(_, column)| column)
143 .collect_vec();
144
145 let returning_chunk = DataChunk::new(columns.clone(), vis.clone());
147
148 if let Some(row_id_index) = self.row_id_index {
151 let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
152 columns.insert(row_id_index, Arc::new(row_id_col.into()))
153 }
154
155 let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
156
157 #[cfg(debug_assertions)]
158 table_dml_handle.check_chunk_schema(&stream_chunk);
159
160 write_handle.write_chunk(stream_chunk).await?;
161
162 Result::Ok(returning_chunk)
163 };
164
165 let mut rows_inserted = 0;
166
167 #[for_await]
168 for data_chunk in self.child.execute() {
169 let data_chunk = data_chunk?;
170 for chunk in builder.append_chunk(data_chunk) {
171 let chunk = write_txn_data(chunk).await?;
172 rows_inserted += chunk.cardinality();
173 if self.returning {
174 yield chunk;
175 }
176 }
177 }
178
179 if let Some(chunk) = builder.consume_all() {
180 let chunk = write_txn_data(chunk).await?;
181 rows_inserted += chunk.cardinality();
182 if self.returning {
183 yield chunk;
184 }
185 }
186
187 write_handle.end().await?;
188
189 if !self.returning {
191 let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
192 array_builder.append(Some(rows_inserted as i64));
193
194 let array = array_builder.finish();
195 let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);
196
197 yield ret_chunk
198 }
199 }
200}
201
202impl BoxedExecutorBuilder for InsertExecutor {
203 async fn new_boxed_executor(
204 source: &ExecutorBuilder<'_>,
205 inputs: Vec<BoxedExecutor>,
206 ) -> Result<BoxedExecutor> {
207 let [child]: [_; 1] = inputs.try_into().unwrap();
208
209 let insert_node = try_match_expand!(
210 source.plan_node().get_node_body().unwrap(),
211 NodeBody::Insert
212 )?;
213
214 let table_id = TableId::new(insert_node.table_id);
215 let column_indices = insert_node
216 .column_indices
217 .iter()
218 .map(|&i| i as usize)
219 .collect();
220 let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
221 let mut default_columns = default_columns
222 .get_default_columns()
223 .iter()
224 .cloned()
225 .map(|IndexAndExpr { index: i, expr: e }| {
226 Ok((
227 i as usize,
228 build_from_prost(&e.context("expression is None")?)
229 .context("failed to build expression")?,
230 ))
231 })
232 .collect::<Result<Vec<_>>>()?;
233 default_columns.sort_unstable_by_key(|(i, _)| *i);
234 default_columns
235 } else {
236 vec![]
237 };
238
239 Ok(Box::new(Self::new(
240 table_id,
241 insert_node.table_version_id,
242 source.context().dml_manager(),
243 child,
244 source.context().get_config().developer.chunk_size,
245 source.plan_node().get_identity().clone(),
246 column_indices,
247 sorted_default_columns,
248 insert_node.row_id_index.as_ref().map(|index| *index as _),
249 insert_node.returning,
250 insert_node.session_id,
251 )))
252 }
253}
254
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::Hint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Struct type with three unnamed `Int32` fields used by the test columns.
    fn triple_int32_struct() -> StructType {
        StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32])
    }

    /// Five-row struct array whose first row is `(1, 2, 3)` and whose other rows are null.
    fn sample_struct_array() -> StructArray {
        StructArray::new(
            triple_int32_struct(),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        )
    }

    /// Inserts one five-row chunk without `RETURNING` and checks both the reported
    /// row count and the `Begin` / `Data` / `End` messages seen by the table reader.
    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        let struct_field = Field::unnamed(triple_int32_struct().into());

        // Schema of the child executor: two ints followed by the struct column.
        let mut child_schema = schema_test_utils::ii();
        child_schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(child_schema.clone());

        // Schema of the table: the child's columns plus a trailing row-id column.
        let mut table_schema = schema_test_utils::ii();
        table_schema.fields.push(struct_field);
        table_schema.fields.push(Field::unnamed(DataType::Serial));
        let row_id_index = Some(3);

        let odd_col = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let even_col = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        let struct_col = Arc::new(sample_struct_array().into());
        let input_chunk: DataChunk = DataChunk::new(vec![odd_col, even_col, struct_col], 5);
        mock_executor.add(input_chunk);

        let table_id = TableId::new(0);

        // Register a reader for the table so the written messages can be observed.
        let column_descs = table_schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            // Child columns map one-to-one onto the first three table columns.
            vec![0, 1, 2],
            // No default columns.
            vec![],
            row_id_index,
            false,
            0,
        ));

        // Drive the executor on its own task while this task consumes the reader.
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            // Without `returning`, the only output is the inserted row count.
            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)]
            );
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let expected: ArrayImpl = sample_struct_array().into();
            assert_eq!(*chunk.columns()[2], expected);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));

        // The store was never handed to the executor, so a full scan must be empty.
        let max_epoch = u64::MAX;
        let whole_key_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(
                whole_key_range,
                max_epoch,
                None,
                ReadOptions {
                    cache_policy: CachePolicy::Fill(Hint::Normal),
                    ..Default::default()
                },
            )
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }
}