risingwave_batch_executors/executor/insert.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::Context;
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{
    ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, SerialArray, StreamChunk,
};
use risingwave_common::catalog::{Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::plan_common::IndexAndExpr;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};

/// [`InsertExecutor`] implements table insertion with values from its child executor.
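///
/// Rows from the child are coalesced into chunks of at most `chunk_size` rows and written
/// through the table's DML handle within a single `begin`/`end` transaction. With a
/// `RETURNING` clause the inserted rows are yielded back to the caller; otherwise a single
/// row carrying the number of inserted rows is emitted.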
pub struct InsertExecutor {
    /// Target table id.
    table_id: TableId,
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    child: BoxedExecutor,
    chunk_size: usize,
    schema: Schema,
    identity: String,
    column_indices: Vec<usize>,
    sorted_default_columns: Vec<(usize, BoxedExpression)>,

    row_id_index: Option<usize>,
    returning: bool,
    txn_id: TxnId,
    session_id: u32,
}

impl InsertExecutor {
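    /// Creates a new [`InsertExecutor`]. The executor's schema is cloned from the child,
    /// and a fresh transaction id is allocated from the DML manager.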
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        table_id: TableId,
        table_version_id: TableVersionId,
        dml_manager: DmlManagerRef,
        child: BoxedExecutor,
        chunk_size: usize,
        identity: String,
        column_indices: Vec<usize>,
        sorted_default_columns: Vec<(usize, BoxedExpression)>,
        row_id_index: Option<usize>,
        returning: bool,
        session_id: u32,
    ) -> Self {
        let table_schema = child.schema().clone();
        let txn_id = dml_manager.gen_txn_id();
        Self {
            table_id,
            table_version_id,
            dml_manager,
            child,
            chunk_size,
            schema: table_schema,
            identity,
            column_indices,
            sorted_default_columns,
            row_id_index,
            returning,
            txn_id,
            session_id,
        }
    }
}

impl Executor for InsertExecutor {
    fn schema(&self) -> &Schema {
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        self.do_execute()
    }
}

impl InsertExecutor {
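    /// Drives the insert: acquires a DML write handle for the target table, wraps all
    /// writes in a `begin`/`end` transaction, and streams the child's chunks into it.
    /// Yields the inserted rows when `returning` is set, or a single affected-row count
    /// otherwise.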
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        let data_types = self.child.schema().data_types();
        let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);

        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;
        let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?;

        write_handle.begin()?;

        // Transform the data chunk to a stream chunk, then write to the source.
        // Return the returning chunk.
        let write_txn_data = |chunk: DataChunk| async {
            let cap = chunk.capacity();
            let (mut columns, vis) = chunk.into_parts();

            let dummy_chunk = DataChunk::new_dummy(cap);

            let mut ordered_columns = self
                .column_indices
                .iter()
                .enumerate()
                .map(|(i, idx)| (*idx, columns[i].clone()))
                .collect_vec();

            ordered_columns.reserve(ordered_columns.len() + self.sorted_default_columns.len());

            for (idx, expr) in &self.sorted_default_columns {
                let column = expr.eval(&dummy_chunk).await?;
                ordered_columns.push((*idx, column));
            }

            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
            columns = ordered_columns
                .into_iter()
                .map(|(_, column)| column)
                .collect_vec();

            // Construct the returning chunk, without the `row_id` column.
            let returning_chunk = DataChunk::new(columns.clone(), vis.clone());

            // If the user does not specify the primary key, then we need to add a column as the
            // primary key.
            if let Some(row_id_index) = self.row_id_index {
                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
                columns.insert(row_id_index, Arc::new(row_id_col.into()))
            }

            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);

            #[cfg(debug_assertions)]
            table_dml_handle.check_chunk_schema(&stream_chunk);

            write_handle.write_chunk(stream_chunk).await?;

            Result::Ok(returning_chunk)
        };

        let mut rows_inserted = 0;

        #[for_await]
        for data_chunk in self.child.execute() {
            let data_chunk = data_chunk?;
            for chunk in builder.append_chunk(data_chunk) {
                let chunk = write_txn_data(chunk).await?;
                rows_inserted += chunk.cardinality();
                if self.returning {
                    yield chunk;
                }
            }
        }

        if let Some(chunk) = builder.consume_all() {
            let chunk = write_txn_data(chunk).await?;
            rows_inserted += chunk.cardinality();
            if self.returning {
                yield chunk;
            }
        }

        write_handle.end().await?;

        // Without `RETURNING`, emit a single row containing the number of inserted rows.
        if !self.returning {
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_inserted as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![Arc::new(array.into())], 1);

            yield ret_chunk
        }
    }
}

impl BoxedExecutorBuilder for InsertExecutor {
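    /// Builds an [`InsertExecutor`] from an `Insert` plan node: resolves the target table,
    /// the indices of the user-provided columns, and the default-value expressions
    /// (sorted by column index).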
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        inputs: Vec<BoxedExecutor>,
    ) -> Result<BoxedExecutor> {
        let [child]: [_; 1] = inputs.try_into().unwrap();

        let insert_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::Insert
        )?;

        let table_id = TableId::new(insert_node.table_id);
        let column_indices = insert_node
            .column_indices
            .iter()
            .map(|&i| i as usize)
            .collect();
        let sorted_default_columns = if let Some(default_columns) = &insert_node.default_columns {
            let mut default_columns = default_columns
                .get_default_columns()
                .iter()
                .cloned()
                .map(|IndexAndExpr { index: i, expr: e }| {
                    Ok((
                        i as usize,
                        build_from_prost(&e.context("expression is None")?)
                            .context("failed to build expression")?,
                    ))
                })
                .collect::<Result<Vec<_>>>()?;
            default_columns.sort_unstable_by_key(|(i, _)| *i);
            default_columns
        } else {
            vec![]
        };

        Ok(Box::new(Self::new(
            table_id,
            insert_node.table_version_id,
            source.context().dml_manager(),
            child,
            source.context().get_config().developer.chunk_size,
            source.plan_node().get_identity().clone(),
            column_indices,
            sorted_default_columns,
            insert_node.row_id_index.as_ref().map(|index| *index as _),
            insert_node.returning,
            insert_node.session_id,
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use foyer::Hint;
    use futures::StreamExt;
    use risingwave_common::array::{Array, ArrayImpl, I32Array, StructArray};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, StructType};
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::CachePolicy;
    use risingwave_storage::hummock::test_utils::*;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    #[tokio::test]
    async fn test_insert_executor() -> Result<()> {
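        // End-to-end check: run an `InsertExecutor` over a mock child, verify the
        // `TxnMsg::Begin` / `Data` / `End` sequence observed by a registered DML reader,
        // and confirm that nothing is written to the state store directly.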
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        // Make struct field
        let struct_field = Field::unnamed(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]).into(),
        );

        // Schema for mock executor.
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field.clone());
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Schema of the table
        let mut schema = schema_test_utils::ii();
        schema.fields.push(struct_field);
        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column

        let row_id_index = Some(3);

        let col1 = Arc::new(I32Array::from_iter([1, 3, 5, 7, 9]).into());
        let col2 = Arc::new(I32Array::from_iter([2, 4, 6, 8, 10]).into());
        let array = StructArray::new(
            StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
            vec![
                I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
            ],
            [true, false, false, false, false].into_iter().collect(),
        );
        let col3 = Arc::new(array.into());
        let data_chunk: DataChunk = DataChunk::new(vec![col1, col2, col3], 5);
        mock_executor.add(data_chunk.clone());

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Insert
        let insert_executor = Box::new(InsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            1024,
            "InsertExecutor".to_owned(),
            vec![0, 1, 2], // Ignoring insertion order
            vec![],
            row_id_index,
            false,
            0,
        ));
        let handle = tokio::spawn(async move {
            let mut stream = insert_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // inserted rows
            );
        });

        // Read
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(1), Some(3), Some(5), Some(7), Some(9)]
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                vec![Some(2), Some(4), Some(6), Some(8), Some(10)]
            );

            let array: ArrayImpl = StructArray::new(
                StructType::unnamed(vec![DataType::Int32, DataType::Int32, DataType::Int32]),
                vec![
                    I32Array::from_iter([Some(1), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(2), None, None, None, None]).into_ref(),
                    I32Array::from_iter([Some(3), None, None, None, None]).into_ref(),
                ],
                [true, false, false, false, false].into_iter().collect(),
            )
            .into();
            assert_eq!(*chunk.columns()[2], array);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..));
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(
                full_range,
                epoch,
                None,
                ReadOptions {
                    cache_policy: CachePolicy::Fill(Hint::Normal),
                    ..Default::default()
                },
            )
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }
}