risingwave_batch/executor/fast_insert.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk};
18use risingwave_common::catalog::{TableId, TableVersionId};
19use risingwave_common::transaction::transaction_id::TxnId;
20use risingwave_dml::dml_manager::DmlManagerRef;
21use risingwave_pb::task_service::FastInsertRequest;
22
23use crate::error::{BatchError, Result};
24
/// A fast insert executor specially designed for non-pgwire inserts such as websockets and webhooks.
pub struct FastInsertExecutor {
    /// Target table id.
    table_id: TableId,
    /// Version of the target table; passed together with `table_id` when
    /// resolving the table's DML handle.
    table_version_id: TableVersionId,
    /// DML manager used to obtain a write handle for the target table and to
    /// generate the transaction id.
    dml_manager: DmlManagerRef,

    /// Index at which a generated row-id column is inserted when the user did
    /// not specify a primary key; `None` if the table has an explicit one.
    row_id_index: Option<usize>,
    /// Transaction id allocated from the DML manager at construction time.
    txn_id: TxnId,
    /// Id of the originating request; used instead of a session id to select
    /// a write handle.
    request_id: u32,
}
36
37impl FastInsertExecutor {
38    pub fn build(
39        dml_manager: DmlManagerRef,
40        insert_req: FastInsertRequest,
41    ) -> Result<(FastInsertExecutor, DataChunk)> {
42        let table_id = insert_req.table_id;
43        let data_chunk_pb = insert_req
44            .data_chunk
45            .expect("no data_chunk found in fast insert node");
46        let data_chunk = DataChunk::from_protobuf(&data_chunk_pb)?;
47
48        Ok((
49            FastInsertExecutor::new(
50                table_id,
51                insert_req.table_version_id,
52                dml_manager,
53                insert_req.row_id_index.as_ref().map(|index| *index as _),
54                insert_req.request_id,
55            ),
56            data_chunk,
57        ))
58    }
59
60    #[allow(clippy::too_many_arguments)]
61    pub fn new(
62        table_id: TableId,
63        table_version_id: TableVersionId,
64        dml_manager: DmlManagerRef,
65        row_id_index: Option<usize>,
66        request_id: u32,
67    ) -> Self {
68        let txn_id = dml_manager.gen_txn_id();
69        Self {
70            table_id,
71            table_version_id,
72            dml_manager,
73            row_id_index,
74            txn_id,
75            request_id,
76        }
77    }
78}
79
impl FastInsertExecutor {
    /// Writes `data_chunk_to_insert` into the target table through the DML
    /// manager as a single transaction.
    ///
    /// If `wait_for_persistence` is true, the future resolves only after the
    /// DML layer fires its persistence notifier (see
    /// `end_wait_persistence`); otherwise it resolves once the transaction
    /// is ended.
    pub async fn do_execute(
        self,
        data_chunk_to_insert: DataChunk,
        wait_for_persistence: bool,
    ) -> Result<()> {
        // Resolve the table's DML handle; fails if the table/version is gone.
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;
        // instead of session id, we use request id here to select a write handle.
        let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?;

        write_handle.begin()?;

        // Transform the data chunk to a stream chunk, then write to the source.
        // Return the returning chunk.
        // NOTE: the closure mutably captures `write_handle`, so it must be
        // fully awaited before `end`/`end_wait_persistence` below.
        let write_txn_data = |chunk: DataChunk| async {
            let cap = chunk.capacity();
            let (mut columns, vis) = chunk.into_parts();

            // If the user does not specify the primary key, then we need to add a column as the
            // primary key.
            if let Some(row_id_index) = self.row_id_index {
                // All-NULL placeholder row ids — presumably filled in by the
                // row-id generator downstream; confirm in the DML pipeline.
                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
                columns.insert(row_id_index, Arc::new(row_id_col.into()))
            }

            // Every row is a plain insert; the original visibility bitmap is kept.
            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);

            // Schema check is debug-only to keep the hot path cheap in release.
            #[cfg(debug_assertions)]
            table_dml_handle.check_chunk_schema(&stream_chunk);

            write_handle.write_chunk(stream_chunk).await?;

            Result::Ok(())
        };
        write_txn_data(data_chunk_to_insert).await?;
        if wait_for_persistence {
            // End the txn and additionally await the persistence notification.
            write_handle
                .end_wait_persistence()
                .map_err(BatchError::from)?
                .await
                .map_err(BatchError::from)
        } else {
            write_handle.end().await.map_err(Into::into)
        }
    }
}
128
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, JsonbArrayBuilder};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, Schema,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, JsonbVal};
    use risingwave_dml::dml_manager::DmlManager;
    use serde_json::json;

    use super::*;
    use crate::risingwave_common::array::ArrayBuilder;
    use crate::risingwave_common::types::Scalar;
    use crate::*;

    /// End-to-end test: inserts one JSONB row through `FastInsertExecutor`
    /// (with `wait_for_persistence = true`) and checks that the reader side
    /// observes a `Begin`, the data chunk, and an `End` carrying a
    /// persistence notifier.
    #[tokio::test]
    async fn test_fast_insert() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
        // Schema of the table
        let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column

        // The generated row-id column goes after the JSONB column.
        let row_id_index = Some(1);

        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);

        let mut header_map = HashMap::new();
        header_map.insert("data".to_owned(), "value1".to_owned());

        let json_value = json!(header_map);
        let jsonb_val = JsonbVal::from(json_value);
        builder.append(Some(jsonb_val.as_scalar_ref()));

        // Use builder to obtain a single (List) column DataChunk
        let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1);

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Insert
        let insert_executor = Box::new(FastInsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            row_id_index,
            0,
        ));
        // Run the executor concurrently: it blocks on persistence, which is
        // only signalled after the reader drains the messages below.
        let handle = tokio::spawn(async move {
            insert_executor.do_execute(data_chunk, true).await.unwrap();
        });

        // Read
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            // JSONB column + generated row-id column.
            assert_eq!(chunk.columns().len(), 2);
            let array = chunk.columns()[0].as_jsonb().iter().collect::<Vec<_>>();
            assert_eq!(JsonbVal::from(array[0].unwrap()), jsonb_val);
        });

        // Simulate the DmlExecutor: fire the persistence notifier after try_wait_epoch succeeds.
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            persistence_notifier.send(()).unwrap();
        });

        // The executor only finishes once the notifier above has fired.
        handle.await.unwrap();

        Ok(())
    }
}
218}