risingwave_batch/executor/fast_insert.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::iter::repeat;
use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH};
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_pb::task_service::FastInsertRequest;

use crate::error::Result;

/// A fast insert executor specially designed for non-pgwire inserts such as websockets and webhooks.
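///
/// Expected call pattern (a sketch; the `FastInsertRequest` and `DmlManagerRef` are
/// assumed to come from the surrounding task-service handler):
///
/// ```ignore
/// let (executor, chunk) = FastInsertExecutor::build(dml_manager, request)?;
/// let epoch = executor.do_execute(chunk, /* returning_epoch */ true).await?;
/// ```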
pub struct FastInsertExecutor {
    /// Target table id.
    table_id: TableId,
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    column_indices: Vec<usize>,

    row_id_index: Option<usize>,
    txn_id: TxnId,
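    /// Identifies the originating request; used in place of a session id when
    /// selecting a write handle (see `do_execute`).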
    request_id: u32,
}

impl FastInsertExecutor {
    pub fn build(
        dml_manager: DmlManagerRef,
        insert_req: FastInsertRequest,
    ) -> Result<(FastInsertExecutor, DataChunk)> {
        let table_id = TableId::new(insert_req.table_id);
        let column_indices = insert_req
            .column_indices
            .iter()
            .map(|&i| i as usize)
            .collect();
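        // The fixed schema of a fast-insert table: one jsonb payload column plus a
        // trailing serial row-id column.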
        let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column
        let data_chunk_pb = insert_req
            .data_chunk
            .expect("no data_chunk found in fast insert node");

        Ok((
            FastInsertExecutor::new(
                table_id,
                insert_req.table_version_id,
                dml_manager,
                column_indices,
                insert_req.row_id_index.map(|index| index as usize),
                insert_req.request_id,
            ),
            DataChunk::from_protobuf(&data_chunk_pb)?,
        ))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn new(
        table_id: TableId,
        table_version_id: TableVersionId,
        dml_manager: DmlManagerRef,
        column_indices: Vec<usize>,
        row_id_index: Option<usize>,
        request_id: u32,
    ) -> Self {
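        // Each executor instance runs its own single-shot DML transaction.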
        let txn_id = dml_manager.gen_txn_id();
        Self {
            table_id,
            table_version_id,
            dml_manager,
            column_indices,
            row_id_index,
            txn_id,
            request_id,
        }
    }
}

impl FastInsertExecutor {
    pub async fn do_execute(
        self,
        data_chunk_to_insert: DataChunk,
        returning_epoch: bool,
    ) -> Result<Epoch> {
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;
        // Instead of a session id, use the request id to select a write handle.
        let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?;

        write_handle.begin()?;

        // Transform the data chunk into a stream chunk, then write it to the source.
        let write_txn_data = |chunk: DataChunk| async {
            let cap = chunk.capacity();
            let (mut columns, vis) = chunk.into_parts();

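            // `columns` arrive in the request's order; pair each with its target
            // table-column index, then sort to restore the table's column order.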
            let mut ordered_columns = self
                .column_indices
                .iter()
                .enumerate()
                .map(|(i, idx)| (*idx, columns[i].clone()))
                .collect_vec();

            ordered_columns.sort_unstable_by_key(|(idx, _)| *idx);
            columns = ordered_columns
                .into_iter()
                .map(|(_, column)| column)
                .collect_vec();

            // If the user does not specify the primary key, add a placeholder row-id
            // column; the actual row ids are filled in downstream.
            if let Some(row_id_index) = self.row_id_index {
                let row_id_col = SerialArray::from_iter(repeat(None).take(cap));
                columns.insert(row_id_index, Arc::new(row_id_col.into()))
            }

            let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis);

            #[cfg(debug_assertions)]
            table_dml_handle.check_chunk_schema(&stream_chunk);

            write_handle.write_chunk(stream_chunk).await?;

            Result::Ok(())
        };
        write_txn_data(data_chunk_to_insert).await?;
        if returning_epoch {
            write_handle.end_returning_epoch().await.map_err(Into::into)
        } else {
            write_handle.end().await?;
            // The returned epoch is invalid and should not be used.
            Ok(Epoch(INVALID_EPOCH))
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::ops::Bound;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use risingwave_common::array::{Array, JsonbArrayBuilder};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID};
    use risingwave_common::transaction::transaction_message::TxnMsg;
    use risingwave_common::types::JsonbVal;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_storage::memory::MemoryStateStore;
    use risingwave_storage::store::{ReadOptions, StateStoreReadExt};
    use serde_json::json;

    use super::*;
    use risingwave_common::array::ArrayBuilder;
    use risingwave_common::types::Scalar;
    use crate::*;

    #[tokio::test]
    async fn test_fast_insert() -> Result<()> {
        let epoch = Epoch::now();
        let dml_manager = Arc::new(DmlManager::for_test());
        let store = MemoryStateStore::new();
        // Schema of the table
        let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
        schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column

        let row_id_index = Some(1);

        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);

        let mut header_map = HashMap::new();
        header_map.insert("data".to_owned(), "value1".to_owned());

        let json_value = json!(header_map);
        let jsonb_val = JsonbVal::from(json_value);
        builder.append(Some(jsonb_val.as_scalar_ref()));

        // Use the builder to obtain a single (jsonb) column DataChunk.
        let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1);

        // The target table id.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Insert
        let insert_executor = Box::new(FastInsertExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            vec![0], // The single payload column maps to table column 0.
            row_id_index,
            0, // request_id
        ));
        let handle = tokio::spawn(async move {
            let epoch_received = insert_executor.do_execute(data_chunk, true).await.unwrap();
            assert_eq!(epoch, epoch_received);
        });

        // Read
        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_));

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => {
            assert_eq!(chunk.columns().len(), 2);
            let array = chunk.columns()[0].as_jsonb().iter().collect::<Vec<_>>();
            assert_eq!(JsonbVal::from(array[0].unwrap()), jsonb_val);
        });

        assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(epoch_notifier)) => {
            epoch_notifier.send(epoch).unwrap();
        });
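        // The state store should stay empty: fast-insert data flows through the DML
        // channel consumed by `reader` above, not directly into storage.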
        let epoch = u64::MAX;
        let full_range = (Bound::Unbounded, Bound::Unbounded);
        let store_content = store
            .scan(full_range, epoch, None, ReadOptions::default())
            .await?;
        assert!(store_content.is_empty());

        handle.await.unwrap();

        Ok(())
    }
}