risingwave_batch/executor/
fast_insert.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk};
19use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
20use risingwave_common::transaction::transaction_id::TxnId;
21use risingwave_common::types::DataType;
22use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH};
23use risingwave_dml::dml_manager::DmlManagerRef;
24use risingwave_pb::task_service::FastInsertRequest;
25
26use crate::error::Result;
27
28/// A fast insert executor spacially designed for non-pgwire inserts such as websockets and webhooks.
29pub struct FastInsertExecutor {
30    /// Target table id.
31    table_id: TableId,
32    table_version_id: TableVersionId,
33    dml_manager: DmlManagerRef,
34    column_indices: Vec<usize>,
35
36    row_id_index: Option<usize>,
37    txn_id: TxnId,
38    request_id: u32,
39}
40
41impl FastInsertExecutor {
42    pub fn build(
43        dml_manager: DmlManagerRef,
44        insert_req: FastInsertRequest,
45    ) -> Result<(FastInsertExecutor, DataChunk)> {
46        let table_id = TableId::new(insert_req.table_id);
47        let column_indices = insert_req
48            .column_indices
49            .iter()
50            .map(|&i| i as usize)
51            .collect();
52        let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
53        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column
54        let data_chunk_pb = insert_req
55            .data_chunk
56            .expect("no data_chunk found in fast insert node");
57
58        Ok((
59            FastInsertExecutor::new(
60                table_id,
61                insert_req.table_version_id,
62                dml_manager,
63                column_indices,
64                insert_req.row_id_index.as_ref().map(|index| *index as _),
65                insert_req.request_id,
66            ),
67            DataChunk::from_protobuf(&data_chunk_pb)?,
68        ))
69    }
70
71    #[allow(clippy::too_many_arguments)]
72    pub fn new(
73        table_id: TableId,
74        table_version_id: TableVersionId,
75        dml_manager: DmlManagerRef,
76        column_indices: Vec<usize>,
77        row_id_index: Option<usize>,
78        request_id: u32,
79    ) -> Self {
80        let txn_id = dml_manager.gen_txn_id();
81        Self {
82            table_id,
83            table_version_id,
84            dml_manager,
85            column_indices,
86            row_id_index,
87            txn_id,
88            request_id,
89        }
90    }
91}
92
93impl FastInsertExecutor {
94    pub async fn do_execute(
95        self,
96        data_chunk_to_insert: DataChunk,
97        returning_epoch: bool,
98    ) -> Result<Epoch> {
99        let table_dml_handle = self
100            .dml_manager
101            .table_dml_handle(self.table_id, self.table_version_id)?;
102        // instead of session id, we use request id here to select a write handle.
103        let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?;
104
105        write_handle.begin()?;
106
107        // Transform the data chunk to a stream chunk, then write to the source.
108        // Return the returning chunk.
109        let write_txn_data = |chunk: DataChunk| async {
110            let cap = chunk.capacity();
111            let (mut columns, vis) = chunk.into_parts();
112
113            let mut ordered_columns = self
114                .column_indices
115                .iter()
116                .enumerate()
117                .map(|(i, idx)| (*idx, columns[i].clone()))
118                .collect_vec();
119
120            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
121            columns = ordered_columns
122                .into_iter()
123                .map(|(_, column)| column)
124                .collect_vec();
125
126            // If the user does not specify the primary key, then we need to add a column as the
127            // primary key.
128            if let Some(row_id_index) = self.row_id_index {
129                let row_id_col = SerialArray::from_iter(std::iter::repeat_n(None, cap));
130                columns.insert(row_id_index, Arc::new(row_id_col.into()))
131            }
132
133            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);
134
135            #[cfg(debug_assertions)]
136            table_dml_handle.check_chunk_schema(&stream_chunk);
137
138            write_handle.write_chunk(stream_chunk).await?;
139
140            Result::Ok(())
141        };
142        write_txn_data(data_chunk_to_insert).await?;
143        if returning_epoch {
144            write_handle.end_returning_epoch().await.map_err(Into::into)
145        } else {
146            write_handle.end().await?;
147            // the returned epoch is invalid and should not be used.
148            Ok(Epoch(INVALID_EPOCH))
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use risingwave_common::array::{Array, JsonbArrayBuilder};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID};
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::JsonbVal;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::hummock::test_utils::StateStoreReadTestExt;
    use risingwave_storage::memory::MemoryStateStore;
    use risingwave_storage::store::ReadOptions;
    use serde_json::json;

    use super::*;
    use crate::risingwave_common::array::ArrayBuilder;
    use crate::risingwave_common::types::Scalar;
    use crate::*;

    /// Inserts a single jsonb row through `FastInsertExecutor` and checks that the table's
    /// DML reader observes `Begin`, a two-column `Data` chunk carrying the payload, and an
    /// `End` whose notifier delivers the epoch back to the executor.
    #[tokio::test]
    async fn test_fast_insert() -> Result<()> {
        let expected_epoch = Epoch::now();
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();

        // Schema of the table: the jsonb payload column followed by the row_id column.
        let schema = Schema::new(vec![
            Field::unnamed(DataType::Jsonb),
            Field::unnamed(DataType::Serial),
        ]);
        let row_id_index = Some(1);

        // Build a one-row, single-column (jsonb) data chunk holding `{"data": "value1"}`.
        let payload = HashMap::from([("data".to_owned(), "value1".to_owned())]);
        let jsonb_val = JsonbVal::from(json!(payload));

        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);
        builder.append(Some(jsonb_val.as_scalar_ref()));
        let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1);

        // Create the table.
        let table_id = TableId::new(0);

        // Create the reader.
        let column_descs: Vec<_> = schema
            .fields
            .iter()
            .enumerate()
            .map(|(pos, field)| {
                ColumnDesc::unnamed(ColumnId::new(pos as _), field.data_type.clone())
            })
            .collect();
        // The `Arc<TableDmlHandle>` must stay bound to a variable for the whole test; the
        // `DmlManager` only keeps a `Weak` reference, so the handle would be dropped otherwise.
        let table_handle = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut txn_stream = table_handle.stream_reader().into_stream();

        // Insert on a separate task so this task can act as the reader.
        let executor = FastInsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            vec![0], // Ignoring insertion order
            row_id_index,
            0,
        );
        let insert_task = tokio::spawn(async move {
            let epoch_received = executor.do_execute(data_chunk, true).await.unwrap();
            assert_eq!(expected_epoch, epoch_received);
        });

        // Read back the transaction messages.
        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            // The payload column plus the row_id column added by the executor.
            assert_eq!(chunk.columns().len(), 2);
            let rows = chunk.columns()[0].as_jsonb().iter().collect::<Vec<_>>();
            assert_eq!(JsonbVal::from(rows[0].unwrap()), jsonb_val);
        });

        assert_matches!(txn_stream.next().await.unwrap()?, TxnMsg::End(_, Some(epoch_notifier)) => {
            epoch_notifier.send(expected_epoch).unwrap();
        });

        // The standalone in-memory store is never handed to the executor, so a full scan at
        // the maximum epoch must still come back empty.
        let read_epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(full_range, read_epoch, None, ReadOptions::default())
            .await?;
        assert!(store_content.is_empty());

        insert_task.await.unwrap();

        Ok(())
    }
}