risingwave_batch_executors/executor/
insert.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
use std::sync::Arc;

use anyhow::Context;
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{
    ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::plan_common::IndexAndExpr;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};
35
36/// [`InsertExecutor`] implements table insertion with values from its child executor.
37pub struct InsertExecutor {
38    /// Target table id.
39    table_id: TableId,
40    table_version_id: TableVersionId,
41    dml_manager: DmlManagerRef,
42    child: BoxedExecutor,
43    #[expect(dead_code)]
44    chunk_size: usize,
45    schema: Schema,
46    identity: String,
47    column_indices: Vec<usize>,
48    sorted_default_columns: Vec<(usize, BoxedExpression)>,
49
50    row_id_index: Option<usize>,
51    returning: bool,
52    txn_id: TxnId,
53    session_id: u32,
54}
55
56impl InsertExecutor {
57    #[allow(clippy::too_many_arguments)]
58    pub fn new(
59        table_id: TableId,
60        table_version_id: TableVersionId,
61        dml_manager: DmlManagerRef,
62        child: BoxedExecutor,
63        chunk_size: usize,
64        identity: String,
65        column_indices: Vec<usize>,
66        sorted_default_columns: Vec<(usize, BoxedExpression)>,
67        row_id_index: Option<usize>,
68        returning: bool,
69        session_id: u32,
70    ) -> Self {
71        let table_schema = child.schema().clone();
72        let txn_id = dml_manager.gen_txn_id();
73        Self {
74            table_id,
75            table_version_id,
76            dml_manager,
77            child,
78            chunk_size,
79            schema: table_schema,
80            identity,
81            column_indices,
82            sorted_default_columns,
83            row_id_index,
84            returning,
85            txn_id,
86            session_id,
87        }
88    }
89}
90
91impl Executor for InsertExecutor {
92    fn schema(&self) -> &Schema {
93        &self.schema
94    }
95
96    fn identity(&self) -> &str {
97        &self.identity
98    }
99
100    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
101        self.do_execute()
102    }
103}
104
105impl InsertExecutor {
106    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
107    async fn do_execute(self: Box<Self>) {
108        let data_types = self.child.schema().data_types();
109        let mut builder = DataChunkBuilder::new(data_types, 1024);
110
111        let table_dml_handle = self
112            .dml_manager
113            .table_dml_handle(self.table_id, self.table_version_id)?;
114        let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;
115
116        write_handle.begin()?;
117
118        // Transform the data chunk to a stream chunk, then write to the source.
119        // Return the returning chunk.
120        let write_txn_data = |chunk: DataChunk| async {
121            let cap = chunk.capacity();
122            let (mut columns, vis) = chunk.into_parts();
123
124            let dummy_chunk = DataChunk::new_dummy(cap);
125
126            let mut ordered_columns = self
127                .column_indices
128                .iter()
129                .enumerate()
130                .map(|(i, idx)| (*idx, columns[i].clone()))
131                .collect_vec();
132
133            ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());
134
135            for (idx, expr) in &self.sorted_default_columns {
136                let column = expr.eval(&dummy_chunk).await?;
137                ordered_columns.push((*idx, column));
138            }
139
140            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
141            columns = ordered_columns
142                .into_iter()
143                .map(|(_, column)| column)
144                .collect_vec();
145
146            // Construct the returning chunk, without the `row_id` column.
147            let returning_chunk = DataChunk::new(columns.clone(), vis.clone());
148
149            // If the user does not specify the primary key, then we need to add a column as the
150            // primary key.
151            if let Some(row_id_index) = self.row_id_index {
152                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
153                columns.insert(row_id_index, Arc::new(row_id_col.into()))
154            }
155
156            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
157
158            #[cfg(debug_assertions)]
159            table_dml_handle.check_chunk_schema(&stream_chunk);
160
161            write_handle.write_chunk(stream_chunk).await?;
162
163            Result::Ok(returning_chunk)
164        };
165
166        let mut rows_inserted = 0;
167
168        #[for_await]
169        for data_chunk in self.child.execute() {
170            let data_chunk = data_chunk?;
171            for chunk in builder.append_chunk(data_chunk) {
172                let chunk = write_txn_data(chunk).await?;
173                rows_inserted += chunk.cardinality();
174                if self.returning {
175                    yield chunk;
176                }
177            }
178        }
179
180        if let Some(chunk) = builder.consume_all() {
181            let chunk = write_txn_data(chunk).await?;
182            rows_inserted += chunk.cardinality();
183            if self.returning {
184                yield chunk;
185            }
186        }
187
188        write_handle.end().await?;
189
190        // create ret value
191        if !self.returning {
192            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
193            array_builder.append(Some(rows_inserted as i64));
194
195            let array = array_builder.finish();
196            let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);
197
198            yield ret_chunk
199        }
200    }
201}
202
203impl BoxedExecutorBuilder for InsertExecutor {
204    async fn new_boxed_executor(
205        source: &ExecutorBuilder<'_>,
206        inputs: Vec<BoxedExecutor>,
207    ) -> Result<BoxedExecutor> {
208        let [child]: [_; 1] = inputs.try_into().unwrap();
209
210        let insert_node = try_match_expand!(
211            source.plan_node().get_node_body().unwrap(),
212            NodeBody::Insert
213        )?;
214
215        let table_id = TableId::new(insert_node.table_id);
216        let column_indices = insert_node
217            .column_indices
218            .iter()
219            .map(|&i| i as usize)
220            .collect();
221        let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
222            let mut default_columns = default_columns
223                .get_default_columns()
224                .iter()
225                .cloned()
226                .map(|IndexAndExpr { index: i, expr: e }| {
227                    Ok((
228                        i as usize,
229                        build_from_prost(&e.context("expression is None")?)
230                            .context("failed to build expression")?,
231                    ))
232                })
233                .collect::<Result<Vec<_>>>()?;
234            default_columns.sort_unstable_by_key(|(i, _)| *i);
235            default_columns
236        } else {
237            vec![]
238        };
239
240        Ok(Box::new(Self::new(
241            table_id,
242            insert_node.table_version_id,
243            source.context().dml_manager(),
244            child,
245            source.context().get_config().developer.chunk_size,
246            source.plan_node().get_identity().clone(),
247            column_indices,
248            sorted_default_columns,
249            insert_node.row_id_index.as_ref().map(|index| *index as _),
250            insert_node.returning,
251            insert_node.session_id,
252        )))
253    }
254}
255
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::CacheHint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;
    use risingwave_storage::store::ReadOptions;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Struct type made of three `Int32` fields, used for the third test column.
    fn triple_int32_type() -> StructType {
        StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32])
    }

    /// Five-row struct array: the first row is `(1, 2, 3)`, the remaining rows are null.
    fn sample_struct_array() -> StructArray {
        let field = |first: i32| {
            I32Array::from_iter([Some(first), None, None, None, None]).into_ref()
        };
        StructArray::new(
            triple_int32_type(),
            vec![field(1), field(2), field(3)],
            [true, false, false, false, false].into_iter().collect(),
        )
    }

    /// Inserts one five-row chunk and checks both the reported row count and the
    /// transaction messages (`Begin`, `Data`, `End`) observed by a registered reader.
    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        let struct_field = Field::unnamed(triple_int32_type().into());

        // Child (mock executor) schema: two ints plus the struct column.
        let mut child_schema = schema_test_utils::ii();
        child_schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(child_schema);

        // Table schema: the child's columns followed by the row-id column.
        let mut table_schema = schema_test_utils::ii();
        table_schema.fields.push(struct_field);
        table_schema.fields.push(Field::unnamed(DataType::Serial));

        let row_id_index = Some(3);

        let odd_ints = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let even_ints = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        let structs = Arc::new(sample_struct_array().into());
        let input_chunk: DataChunk = DataChunk::new(vec![odd_ints, even_ints, structs], 5);
        mock_executor.add(input_chunk);

        let table_id = TableId::new(0);

        let column_descs: Vec<_> = table_schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect();
        // The handle returned by `register_reader` must stay bound to a variable for the
        // whole test; `DmlManager` only keeps a `Weak` reference to it, so it would be
        // dropped otherwise.
        let dml_handle = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = dml_handle.stream_reader().into_stream();

        // Run the insert on a separate task so the reader below can consume concurrently.
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2], // Ignoring insertion order
            vec![],
            row_id_index,
            false,
            0,
        ));
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // inserted rows
            );
        });

        // The reader must observe Begin, one Data chunk, then End.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            let int32_column =
                |idx: usize| chunk.columns()[idx].as_int32().iter().collect::<Vec<_>>();

            assert_eq!(
                int32_column(0),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                int32_column(1),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let expected_structs: ArrayImpl = sample_struct_array().into();
            assert_eq!(*chunk.columns()[2], expected_structs);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));

        // Nothing has been written to the state store itself.
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let read_options = ReadOptions {
            cache_policy: CachePolicy::Fill(CacheHint::Normal),
            ..Default::default()
        };
        let store_content = store.scan(full_range, epoch, None, read_options).await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }
}