risingwave_batch/executor/
fast_insert.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk};
18use risingwave_common::catalog::{TableId, TableVersionId};
19use risingwave_common::transaction::transaction_id::TxnId;
20use risingwave_dml::dml_manager::DmlManagerRef;
21use risingwave_pb::task_service::FastInsertRequest;
22
23use crate::error::{BatchError, Result};
24
25pub(crate) fn inject_optional_row_id_column(
26    chunk: StreamChunk,
27    row_id_index: Option<usize>,
28) -> StreamChunk {
29    let Some(row_id_index) = row_id_index else {
30        return chunk;
31    };
32
33    let cap = chunk.data_chunk().capacity();
34    let (ops, mut columns, vis) = chunk.into_inner();
35    let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
36    columns.insert(row_id_index, Arc::new(row_id_col.into()));
37    StreamChunk::with_visibility(ops, columns, vis)
38}
39
40/// A fast insert executor spacially designed for non-pgwire inserts such as websockets and webhooks.
41pub struct FastInsertExecutor {
42    /// Target table id.
43    table_id: TableId,
44    table_version_id: TableVersionId,
45    dml_manager: DmlManagerRef,
46
47    row_id_index: Option<usize>,
48    txn_id: TxnId,
49    request_id: u32,
50}
51
52impl FastInsertExecutor {
53    pub fn build(
54        dml_manager: DmlManagerRef,
55        insert_req: FastInsertRequest,
56    ) -> Result<(FastInsertExecutor, DataChunk)> {
57        let table_id = insert_req.table_id;
58        let data_chunk_pb = insert_req
59            .data_chunk
60            .expect("no data_chunk found in fast insert node");
61        let data_chunk = DataChunk::from_protobuf(&data_chunk_pb)?;
62
63        Ok((
64            FastInsertExecutor::new(
65                table_id,
66                insert_req.table_version_id,
67                dml_manager,
68                insert_req.row_id_index.as_ref().map(|index| *index as _),
69                insert_req.request_id,
70            ),
71            data_chunk,
72        ))
73    }
74
75    pub fn new(
76        table_id: TableId,
77        table_version_id: TableVersionId,
78        dml_manager: DmlManagerRef,
79        row_id_index: Option<usize>,
80        request_id: u32,
81    ) -> Self {
82        let txn_id = dml_manager.gen_txn_id();
83        Self {
84            table_id,
85            table_version_id,
86            dml_manager,
87            row_id_index,
88            txn_id,
89            request_id,
90        }
91    }
92}
93
94impl FastInsertExecutor {
95    pub async fn do_execute(
96        self,
97        data_chunk_to_insert: DataChunk,
98        wait_for_persistence: bool,
99    ) -> Result<()> {
100        let table_dml_handle = self
101            .dml_manager
102            .table_dml_handle(self.table_id, self.table_version_id)?;
103        // instead of session id, we use request id here to select a write handle.
104        let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?;
105
106        write_handle.begin()?;
107
108        // Transform the data chunk to a stream chunk, then write to the source.
109        // Return the returning chunk.
110        let write_txn_data = |chunk: DataChunk| async {
111            let stream_chunk = inject_optional_row_id_column(
112                StreamChunk::from_parts(vec![Op::Insert; chunk.capacity()], chunk),
113                self.row_id_index,
114            );
115
116            #[cfg(debug_assertions)]
117            table_dml_handle.check_chunk_schema(&stream_chunk);
118
119            write_handle.write_chunk(stream_chunk).await?;
120
121            Result::Ok(())
122        };
123        write_txn_data(data_chunk_to_insert).await?;
124        if wait_for_persistence {
125            write_handle
126                .end_wait_persistence()
127                .map_err(BatchError::from)?
128                .await
129                .map_err(BatchError::from)
130        } else {
131            write_handle.end().await.map_err(Into::into)
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::{Array, JsonbArrayBuilder};
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, Field, INITIAL_TABLE_VERSION_ID, Schema,
    };
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::{DataType, JsonbVal};
    use risingwave_dml::dml_manager::DmlManager;
    use serde_json::json;

    use super::*;
    use crate::risingwave_common::array::ArrayBuilder;
    use crate::risingwave_common::types::Scalar;
    use crate::*;

    /// Inserts a single JSONB row through `FastInsertExecutor` with persistence waiting
    /// enabled and checks the `Begin` / `Data` / `End` messages seen by the table reader.
    #[tokio::test]
    async fn test_fast_insert() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        // Table schema: one JSONB payload column followed by the serial row-id column.
        let schema = Schema::new(vec![
            Field::unnamed(DataType::Jsonb),
            Field::unnamed(DataType::Serial),
        ]);

        // The JSON object `{"data": "value1"}` is the only value inserted.
        let payload: HashMap<String, String> =
            HashMap::from([("data".to_owned(), "value1".to_owned())]);
        let expected_jsonb = JsonbVal::from(json!(payload));

        let mut jsonb_builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);
        jsonb_builder.append(Some(expected_jsonb.as_scalar_ref()));

        // A one-row data chunk with a single JSONB column.
        let chunk_to_insert = DataChunk::new(vec![jsonb_builder.finish().into_ref()], 1);

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(idx, field)| {
                ColumnDesc::unnamed(ColumnId::new(idx as _), field.data_type.clone())
            })
            .collect_vec();
        // The `Arc<TableDmlHandle>` must stay bound to a variable for the whole test, or it
        // will be dropped due to the `Weak` reference in `DmlManager`.
        let dml_handle = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut txn_stream = dml_handle.stream_reader().into_stream();

        // Insert, with the row-id column injected at index 1.
        let executor = Box::new(FastInsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Some(1),
            0,
        ));
        let insert_task = tokio::spawn(async move {
            executor.do_execute(chunk_to_insert, true).await.unwrap();
        });

        // Read
        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(chunk.columns().len(), 2);
            let jsonb_values = chunk.columns()[0].as_jsonb().iter().collect::<Vec<_>>();
            assert_eq!(JsonbVal::from(jsonb_values[0].unwrap()), expected_jsonb);
        });

        // Simulate the DmlExecutor: fire the persistence notifier after try_wait_epoch succeeds.
        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::End(_, Some(persistence_notifier)) => {
            persistence_notifier.send(()).unwrap();
        });

        insert_task.await.unwrap();

        Ok(())
    }
}