// risingwave_batch_executors/executor/insert.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::Context;
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{
    ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::plan_common::IndexAndExpr;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};

36/// [`InsertExecutor`] implements table insertion with values from its child executor.
37pub struct InsertExecutor {
38    /// Target table id.
39    table_id: TableId,
40    table_version_id: TableVersionId,
41    dml_manager: DmlManagerRef,
42    child: BoxedExecutor,
43    chunk_size: usize,
44    schema: Schema,
45    identity: String,
46    column_indices: Vec<usize>,
47    sorted_default_columns: Vec<(usize, BoxedExpression)>,
48
49    row_id_index: Option<usize>,
50    returning: bool,
51    txn_id: TxnId,
52    session_id: u32,
53    wait_for_persistence: bool,
54}
55
56impl InsertExecutor {
57    #[expect(clippy::too_many_arguments)]
58    pub fn new(
59        table_id: TableId,
60        table_version_id: TableVersionId,
61        dml_manager: DmlManagerRef,
62        child: BoxedExecutor,
63        chunk_size: usize,
64        identity: String,
65        column_indices: Vec<usize>,
66        sorted_default_columns: Vec<(usize, BoxedExpression)>,
67        row_id_index: Option<usize>,
68        returning: bool,
69        session_id: u32,
70        wait_for_persistence: bool,
71    ) -> Self {
72        let table_schema = child.schema().clone();
73        let txn_id = dml_manager.gen_txn_id();
74        Self {
75            table_id,
76            table_version_id,
77            dml_manager,
78            child,
79            chunk_size,
80            schema: table_schema,
81            identity,
82            column_indices,
83            sorted_default_columns,
84            row_id_index,
85            returning,
86            txn_id,
87            session_id,
88            wait_for_persistence,
89        }
90    }
91}
92
93impl Executor for InsertExecutor {
94    fn schema(&self) -> &Schema {
95        &self.schema
96    }
97
98    fn identity(&self) -> &str {
99        &self.identity
100    }
101
102    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
103        self.do_execute()
104    }
105}
106
107impl InsertExecutor {
108    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
109    async fn do_execute(self: Box<Self>) {
110        let data_types = self.child.schema().data_types();
111        let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
112
113        let table_dml_handle = self
114            .dml_manager
115            .table_dml_handle(self.table_id, self.table_version_id)?;
116        let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;
117
118        write_handle.begin()?;
119
120        // Transform the data chunk to a stream chunk, then write to the source.
121        // Return the returning chunk.
122        let write_txn_data = |chunk: DataChunk| async {
123            let cap = chunk.capacity();
124            let (mut columns, vis) = chunk.into_parts();
125
126            let dummy_chunk = DataChunk::new_dummy(cap);
127
128            let mut ordered_columns = self
129                .column_indices
130                .iter()
131                .enumerate()
132                .map(|(i, idx)| (*idx, columns[i].clone()))
133                .collect_vec();
134
135            ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());
136
137            for (idx, expr) in &self.sorted_default_columns {
138                let column = expr.eval(&dummy_chunk).await?;
139                ordered_columns.push((*idx, column));
140            }
141
142            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
143            columns = ordered_columns
144                .into_iter()
145                .map(|(_, column)| column)
146                .collect_vec();
147
148            // Construct the returning chunk, without the `row_id` column.
149            let returning_chunk = DataChunk::new(columns.clone(), vis.clone());
150
151            // If the user does not specify the primary key, then we need to add a column as the
152            // primary key.
153            if let Some(row_id_index) = self.row_id_index {
154                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
155                columns.insert(row_id_index, Arc::new(row_id_col.into()))
156            }
157
158            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
159
160            #[cfg(debug_assertions)]
161            table_dml_handle.check_chunk_schema(&stream_chunk);
162
163            write_handle.write_chunk(stream_chunk).await?;
164
165            Result::Ok(returning_chunk)
166        };
167
168        let mut rows_inserted = 0;
169
170        #[for_await]
171        for data_chunk in self.child.execute() {
172            let data_chunk = data_chunk?;
173            for chunk in builder.append_chunk(data_chunk) {
174                let chunk = write_txn_data(chunk).await?;
175                rows_inserted += chunk.cardinality();
176                if self.returning {
177                    yield chunk;
178                }
179            }
180        }
181
182        if let Some(chunk) = builder.consume_all() {
183            let chunk = write_txn_data(chunk).await?;
184            rows_inserted += chunk.cardinality();
185            if self.returning {
186                yield chunk;
187            }
188        }
189
190        if self.wait_for_persistence {
191            write_handle.end_wait_persistence()?.await?;
192        } else {
193            write_handle.end().await?;
194        }
195
196        // create ret value
197        if !self.returning {
198            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
199            array_builder.append(Some(rows_inserted as i64));
200
201            let array = array_builder.finish();
202            let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);
203
204            yield ret_chunk
205        }
206    }
207}
208
209impl BoxedExecutorBuilder for InsertExecutor {
210    async fn new_boxed_executor(
211        source: &ExecutorBuilder<'_>,
212        inputs: Vec<BoxedExecutor>,
213    ) -> Result<BoxedExecutor> {
214        let [child]: [_; 1] = inputs.try_into().unwrap();
215
216        let insert_node = try_match_expand!(
217            source.plan_node().get_node_body().unwrap(),
218            NodeBody::Insert
219        )?;
220
221        let table_id = insert_node.table_id;
222        let column_indices = insert_node
223            .column_indices
224            .iter()
225            .map(|&i| i as usize)
226            .collect();
227        let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
228            let mut default_columns = default_columns
229                .get_default_columns()
230                .iter()
231                .cloned()
232                .map(|IndexAndExpr { index: i, expr: e }| {
233                    Ok((
234                        i as usize,
235                        build_from_prost(&e.context("expression is None")?)
236                            .context("failed to build expression")?,
237                    ))
238                })
239                .collect::<Result<Vec<_>>>()?;
240            default_columns.sort_unstable_by_key(|(i, _)| *i);
241            default_columns
242        } else {
243            vec![]
244        };
245
246        Ok(Box::new(Self::new(
247            table_id,
248            insert_node.table_version_id,
249            source.context().dml_manager(),
250            child,
251            source.context().get_config().developer.chunk_size,
252            source.plan_node().get_identity().clone(),
253            column_indices,
254            sorted_default_columns,
255            insert_node.row_id_index.as_ref().map(|index| *index as _),
256            insert_node.returning,
257            insert_node.session_id,
258            insert_node.wait_for_persistence,
259        )))
260    }
261}
262
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::Hint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Unnamed struct type with three `Int32` fields.
    fn int32_triple_type() -> StructType {
        StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32])
    }

    /// Five-row struct array whose only non-null row is `(1, 2, 3)` at position 0.
    fn sample_struct_array() -> StructArray {
        StructArray::new(
            int32_triple_type(),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        )
    }

    /// One unnamed column descriptor per field, with column ids `0..fields.len()`.
    fn unnamed_column_descs(fields: &[Field]) -> Vec<ColumnDesc> {
        fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec()
    }

    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        let struct_field = Field::unnamed(int32_triple_type().into());

        // Columns produced by the mock child: (int, int, struct).
        let mut input_schema = schema_test_utils::ii();
        input_schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(input_schema);

        // Table columns: the child's columns followed by the `row_id` column.
        let mut table_schema = schema_test_utils::ii();
        table_schema.fields.push(struct_field);
        table_schema.fields.push(Field::unnamed(DataType::Serial));
        let row_id_index = Some(3);

        let odds = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let evens = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        let structs = Arc::new(sample_struct_array().into());
        mock_executor.add(DataChunk::new(vec![odds, evens, structs], 5));

        let table_id = TableId::new(0);
        let column_descs = unnamed_column_descs(&table_schema.fields);
        // `DmlManager` only keeps a `Weak` reference to the `Arc<TableDmlHandle>`, so hold it
        // in a local for the whole test.
        let dml_handle = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = dml_handle.stream_reader().into_stream();

        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2], // child column i goes to table column i
            vec![],
            row_id_index,
            false,
            0,
            false,
        ));
        let handle = tokio::spawn(async move {
            let mut output = insert_executor.execute();
            let count_chunk = output.next().await.unwrap().unwrap();
            assert_eq!(
                count_chunk.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // inserted rows
            );
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            let int_column =
                |idx: usize| chunk.columns()[idx].as_int32().iter().collect::<Vec<_>>();
            assert_eq!(int_column(0), vec![Some(1), Some(3), Some(5), Some(7), Some(9)]);
            assert_eq!(int_column(1), vec![Some(2), Some(4), Some(6), Some(8), Some(10)]);

            let expected_structs: ArrayImpl = sample_struct_array().into();
            assert_eq!(*chunk.columns()[2], expected_structs);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, None));

        // The state store is never handed to the executor, so it must still be empty.
        let read_options = ReadOptions {
            cache_policy: CachePolicy::Fill(Hint::Normal),
            ..Default::default()
        };
        let store_content = store
            .scan((Bound::Unbounded, Bound::Unbounded), u64::MAX, None, read_options)
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_executor_wait_for_persistence() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());
        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 2",
        ));

        let table_id = TableId::new(0);
        let column_descs = unnamed_column_descs(&schema.fields);
        let dml_handle = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = dml_handle.stream_reader().into_stream();

        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1],
            vec![],
            None,
            false,
            0,
            true,
        ));
        let handle = tokio::spawn(async move {
            let mut output = insert_executor.execute();
            let count_chunk = output.next().await.unwrap().unwrap();
            assert_eq!(
                count_chunk.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(1)]
            );
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, _));
        // The executor must still be blocked until the persistence notification is sent.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            assert!(!handle.is_finished());
            persistence_notifier.send(()).unwrap();
        });

        handle.await.unwrap();

        Ok(())
    }
}