risingwave_batch_executors/executor/insert.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::Context;
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{
    ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
};
use risingwave_common::catalog::{Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::plan_common::IndexAndExpr;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};

/// [`InsertExecutor`] implements table insertion with values from its child executor.
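///
/// Each input chunk is reordered into the table's column order via `column_indices`,
/// columns not supplied by the child are filled by evaluating the expressions in
/// `sorted_default_columns`, and a placeholder `row_id` column of nulls is inserted at
/// `row_id_index` when the table has no user-defined primary key. All chunks are
/// written to the table through the DML manager within a single transaction. With
/// `RETURNING`, the inserted rows are yielded; otherwise a single row carrying the
/// number of inserted rows is yielded.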
pub struct InsertExecutor {
    /// Target table id.
    table_id: TableId,
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    child: BoxedExecutor,
    #[expect(dead_code)]
    chunk_size: usize,
    schema: Schema,
    identity: String,
    /// For each column of the child's output, the index of the table column it maps to.
    column_indices: Vec<usize>,
    /// Default-value expressions for columns not supplied by the child, sorted by
    /// target column index.
    sorted_default_columns: Vec<(usize, BoxedExpression)>,

    /// Where to insert the placeholder `row_id` column when the table has no
    /// user-defined primary key.
    row_id_index: Option<usize>,
    /// Whether to yield the inserted rows (`INSERT ... RETURNING`).
    returning: bool,
    txn_id: TxnId,
    session_id: u32,
}

impl InsertExecutor {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        table_id: TableId,
        table_version_id: TableVersionId,
        dml_manager: DmlManagerRef,
        child: BoxedExecutor,
        chunk_size: usize,
        identity: String,
        column_indices: Vec<usize>,
        sorted_default_columns: Vec<(usize, BoxedExpression)>,
        row_id_index: Option<usize>,
        returning: bool,
        session_id: u32,
    ) -> Self {
        let table_schema = child.schema().clone();
        let txn_id = dml_manager.gen_txn_id();
        Self {
            table_id,
            table_version_id,
            dml_manager,
            child,
            chunk_size,
            schema: table_schema,
            identity,
            column_indices,
            sorted_default_columns,
            row_id_index,
            returning,
            txn_id,
            session_id,
        }
    }
}

impl Executor for InsertExecutor {
    fn schema(&self) -> &Schema {
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        self.do_execute()
    }
}

impl InsertExecutor {
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        let data_types = self.child.schema().data_types();
        let mut builder = DataChunkBuilder::new(data_types, 1024);

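        // Look up the DML handle for the target table and version, then open a
        // transactional write handle bound to this session and transaction id.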
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;
        let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;

        write_handle.begin()?;

        // Transform the data chunk into a stream chunk, then write it to the source.
        // Returns the chunk to be used for `RETURNING`.
        let write_txn_data = |chunk: DataChunk| async {
            let cap = chunk.capacity();
            let (mut columns, vis) = chunk.into_parts();

            let dummy_chunk = DataChunk::new_dummy(cap);

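            // Pair each child output column with its target table column index, append
            // the default-value columns (evaluated against a dummy chunk of the same
            // cardinality), then sort everything back into table column order.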
            let mut ordered_columns = self
                .column_indices
                .iter()
                .enumerate()
                .map(|(i, idx)| (*idx, columns[i].clone()))
                .collect_vec();

            ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());

            for (idx, expr) in &self.sorted_default_columns {
                let column = expr.eval(&dummy_chunk).await?;
                ordered_columns.push((*idx, column));
            }

            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
            columns = ordered_columns
                .into_iter()
                .map(|(_, column)| column)
                .collect_vec();

            // Construct the returning chunk, without the `row_id` column.
            let returning_chunk = DataChunk::new(columns.clone(), vis.clone());

            // If the user does not specify the primary key, then we need to add a column as the
            // primary key.
            if let Some(row_id_index) = self.row_id_index {
                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
                columns.insert(row_id_index, Arc::new(row_id_col.into()))
            }

            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);

            #[cfg(debug_assertions)]
            table_dml_handle.check_chunk_schema(&stream_chunk);

            write_handle.write_chunk(stream_chunk).await?;

            Result::Ok(returning_chunk)
        };

        let mut rows_inserted = 0;

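        // Coalesce the child's output into fixed-size chunks and write each one within
        // the transaction, yielding it back if `RETURNING` was requested.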
        #[for_await]
        for data_chunk in self.child.execute() {
            let data_chunk = data_chunk?;
            for chunk in builder.append_chunk(data_chunk) {
                let chunk = write_txn_data(chunk).await?;
                rows_inserted += chunk.cardinality();
                if self.returning {
                    yield chunk;
                }
            }
        }

        if let Some(chunk) = builder.consume_all() {
            let chunk = write_txn_data(chunk).await?;
            rows_inserted += chunk.cardinality();
            if self.returning {
                yield chunk;
            }
        }

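        // All chunks have been written; finish the DML transaction.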
        write_handle.end().await?;

        // Without `RETURNING`, return a single-row chunk containing the number of
        // inserted rows.
        if !self.returning {
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_inserted as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);

            yield ret_chunk
        }
    }
}

impl BoxedExecutorBuilder for InsertExecutor {
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        inputs: Vec<BoxedExecutor>,
    ) -> Result<BoxedExecutor> {
        let [child]: [_; 1] = inputs.try_into().unwrap();

        let insert_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::Insert
        )?;

        let table_id = TableId::new(insert_node.table_id);
        let column_indices = insert_node
            .column_indices
            .iter()
            .map(|&i| i as usize)
            .collect();
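        // Build the default-value expressions from their protobuf form and sort them by
        // target column index, as `InsertExecutor` expects.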
        let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
            let mut default_columns = default_columns
                .get_default_columns()
                .iter()
                .cloned()
                .map(|IndexAndExpr { index: i, expr: e }| {
                    Ok((
                        i as usize,
                        build_from_prost(&e.context("expression is None")?)
                            .context("failed to build expression")?,
                    ))
                })
                .collect::<Result<Vec<_>>>()?;
            default_columns.sort_unstable_by_key(|(i, _)| *i);
            default_columns
        } else {
            vec![]
        };

        Ok(Box::new(Self::new(
            table_id,
            insert_node.table_version_id,
            source.context().dml_manager(),
            child,
            source.context().get_config().developer.chunk_size,
            source.plan_node().get_identity().clone(),
            column_indices,
            sorted_default_columns,
            insert_node.row_id_index.as_ref().map(|index| *index as _),
            insert_node.returning,
            insert_node.session_id,
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::Hint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

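    // End-to-end check: the executor emits Begin/Data/End transaction messages to the
    // registered DML reader, reports the inserted row count (since `returning` is
    // false), and leaves the state store untouched.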
    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        // Make struct field
        let struct_field = Field::unnamed(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]).into(),
        );

        // Schema for mock executor.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Schema of the table
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field);
        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column

        let row_id_index = Some(3);

        let col1 = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let col2 = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        let array = StructArray::new(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        );
        let col3 = Arc::new(array.into());
        let data_chunk: DataChunk = DataChunk::new(vec![col1, col2, col3], 5);
        mock_executor.add(data_chunk.clone());

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Insert
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2], // Ignoring insertion order
            vec![],
            row_id_index,
            false,
            0,
        ));
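        // Run the insert on a separate task while this task drains the DML reader.
        // With `returning == false`, the executor yields a single Int64 chunk holding
        // the number of inserted rows.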
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // inserted rows
            );
        });

        // Read
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let array: ArrayImpl = StructArray::new(
                StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
                vec![
                    I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
                ],
                [true, false, false, false, false].into_iter().collect(),
            )
            .into();
            assert_eq!(*chunk.columns()[2], array);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));
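        // The executor writes only through the DML channel; nothing should have been
        // persisted to the state store.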
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(
                full_range,
                epoch,
                None,
                ReadOptions {
                    cache_policy: CachePolicy::Fill(Hint::Normal),
                    ..Default::default()
                },
            )
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }
}