// risingwave_batch_executors/executor/update.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use futures_async_stream::try_stream;
16use itertools::Itertools;
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::array::{
19    Array, ArrayBuilder, DataChunk, PrimitiveArrayBuilder, StreamChunk, StreamChunkBuilder,
20};
21use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::TxnId;
23use risingwave_common::types::DataType;
24use risingwave_common::util::iter_util::ZipEqDebug;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_expr::expr::{BoxedExpression, build_from_prost};
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
/// [`UpdateExecutor`] implements table update with values from its child executor and given
/// expressions.
// Note: multiple `UPDATE`s in a single epoch, or concurrent `UPDATE`s may lead to conflicting
// records. This is validated and filtered on the first `Materialize`.
pub struct UpdateExecutor {
    /// Target table id.
    table_id: TableId,
    /// Version of the target table; checked when acquiring the DML handle so that
    /// an update planned against an outdated schema fails instead of corrupting data.
    table_version_id: TableVersionId,
    /// DML manager used to obtain a write handle for the target table.
    dml_manager: DmlManagerRef,
    /// Child executor producing the rows to be updated.
    child: BoxedExecutor,
    /// Expressions evaluating to the *old* values of all table columns.
    old_exprs: Vec<BoxedExpression>,
    /// Expressions evaluating to the *new* values of all table columns.
    new_exprs: Vec<BoxedExpression>,
    /// Output chunk size; rounded up to an even number in `new` so that the
    /// `UpdateDelete`/`UpdateInsert` op pair of a single row never straddles a chunk.
    chunk_size: usize,
    /// Output schema: the table schema when `returning`, otherwise a single `Int64`
    /// column carrying the number of updated rows.
    schema: Schema,
    /// Identity string for this executor instance (used in explain/debug output).
    identity: String,
    /// Whether to yield the updated rows (`UPDATE ... RETURNING`).
    returning: bool,
    /// Transaction id generated by the DML manager in `new`.
    txn_id: TxnId,
    /// Session id forwarded to the write handle.
    session_id: u32,
    /// If true, write plain `Insert` records instead of `Update` (delete+insert) pairs.
    upsert: bool,
    /// If true, block at the end of the transaction until the data is persisted.
    wait_for_persistence: bool,
}
55
56impl UpdateExecutor {
57    #[expect(clippy::too_many_arguments)]
58    pub fn new(
59        table_id: TableId,
60        table_version_id: TableVersionId,
61        dml_manager: DmlManagerRef,
62        child: BoxedExecutor,
63        old_exprs: Vec<BoxedExpression>,
64        new_exprs: Vec<BoxedExpression>,
65        chunk_size: usize,
66        identity: String,
67        returning: bool,
68        session_id: u32,
69        upsert: bool,
70        wait_for_persistence: bool,
71    ) -> Self {
72        let chunk_size = chunk_size.next_multiple_of(2);
73        let table_schema = child.schema().clone();
74        let txn_id = dml_manager.gen_txn_id();
75
76        Self {
77            table_id,
78            table_version_id,
79            dml_manager,
80            child,
81            old_exprs,
82            new_exprs,
83            chunk_size,
84            schema: if returning {
85                table_schema
86            } else {
87                Schema {
88                    fields: vec![Field::unnamed(DataType::Int64)],
89                }
90            },
91            identity,
92            returning,
93            txn_id,
94            session_id,
95            upsert,
96            wait_for_persistence,
97        }
98    }
99}
100
101impl Executor for UpdateExecutor {
102    fn schema(&self) -> &Schema {
103        &self.schema
104    }
105
106    fn identity(&self) -> &str {
107        &self.identity
108    }
109
110    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
111        self.do_execute()
112    }
113}
114
impl UpdateExecutor {
    /// Drives the update: evaluates old/new column values for every input row,
    /// writes the resulting records to the table's DML channel inside a single
    /// transaction, and yields either the updated rows (`RETURNING`) or a final
    /// one-row chunk with the update count.
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        // Acquire the DML handle; this also checks `table_version_id` against
        // the current table version.
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;

        let data_types = table_dml_handle
            .column_descs()
            .iter()
            .map(|c| c.data_type.clone())
            .collect_vec();

        // Both expression lists must produce exactly the table's column types;
        // a mismatch is a planner bug, hence `assert` rather than a user error.
        assert_eq!(
            data_types,
            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );
        assert_eq!(
            data_types,
            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );

        // `chunk_size` was rounded up to an even number in `new`, so an
        // `Update` record (two ops) can never be split across chunks here.
        let mut builder = StreamChunkBuilder::new(self.chunk_size, data_types);

        let mut write_handle: risingwave_dml::WriteHandle =
            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
        // Open the transaction; everything until `end`/`end_wait_persistence`
        // below belongs to it.
        write_handle.begin()?;

        // Writes one chunk of update records to the table through the DML handle.
        let write_txn_data = |chunk: StreamChunk| async {
            if cfg!(debug_assertions) {
                table_dml_handle.check_chunk_schema(&chunk);
            }
            write_handle.write_chunk(chunk).await
        };

        let mut rows_updated = 0;

        #[for_await]
        for input in self.child.execute() {
            let input = input?;

            // Evaluate the old (current) values of all columns for this chunk.
            let old_data_chunk = {
                let mut columns = Vec::with_capacity(self.old_exprs.len());
                for expr in &self.old_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            // Evaluate the new values of all columns for this chunk.
            let updated_data_chunk = {
                let mut columns = Vec::with_capacity(self.new_exprs.len());
                for expr in &self.new_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            // For `RETURNING`, yield the new values of all matched rows — including
            // rows whose values turn out unchanged below.
            if self.returning {
                yield updated_data_chunk.clone();
            }

            for (row_delete, row_insert) in
                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
            {
                // Counted before the no-op check: unchanged rows still count as
                // updated, matching SQL's affected-row semantics.
                rows_updated += 1;
                // If row_delete == row_insert, we don't need to do an actual update.
                if row_delete == row_insert {
                    continue;
                }
                let chunk = if self.upsert {
                    // In upsert mode, we only write the new row.
                    builder.append_record(Record::Insert {
                        new_row: row_insert,
                    })
                } else {
                    // Note: we've banned updating the primary key when binding `UPDATE` statement.
                    // So we can safely use `Update` op.
                    builder.append_record(Record::Update {
                        old_row: row_delete,
                        new_row: row_insert,
                    })
                };

                // `append_record` returns a full chunk once `chunk_size` is reached.
                if let Some(chunk) = chunk {
                    write_txn_data(chunk).await?;
                }
            }
        }

        // Flush the trailing partial chunk, if any.
        if let Some(chunk) = builder.take() {
            write_txn_data(chunk).await?;
        }
        // Close the transaction, optionally waiting until the data is persisted.
        if self.wait_for_persistence {
            write_handle.end_wait_persistence()?.await?;
        } else {
            write_handle.end().await?;
        }

        // Without `RETURNING`, emit a single-row `Int64` chunk with the update count.
        if !self.returning {
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_updated as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);

            yield ret_chunk
        }
    }
}
232
233impl BoxedExecutorBuilder for UpdateExecutor {
234    async fn new_boxed_executor(
235        source: &ExecutorBuilder<'_>,
236        inputs: Vec<BoxedExecutor>,
237    ) -> Result<BoxedExecutor> {
238        let [child]: [_; 1] = inputs.try_into().unwrap();
239
240        let update_node = try_match_expand!(
241            source.plan_node().get_node_body().unwrap(),
242            NodeBody::Update
243        )?;
244
245        let table_id = update_node.table_id;
246
247        let old_exprs: Vec<_> = update_node
248            .get_old_exprs()
249            .iter()
250            .map(build_from_prost)
251            .try_collect()?;
252
253        let new_exprs: Vec<_> = update_node
254            .get_new_exprs()
255            .iter()
256            .map(build_from_prost)
257            .try_collect()?;
258
259        Ok(Box::new(Self::new(
260            table_id,
261            update_node.table_version_id,
262            source.context().dml_manager(),
263            child,
264            old_exprs,
265            new_exprs,
266            source.context().get_config().developer.chunk_size,
267            source.plan_node().get_identity().clone(),
268            update_node.returning,
269            update_node.session_id,
270            update_node.upsert,
271            update_node.wait_for_persistence,
272        )))
273    }
274}
275
#[cfg(test)]
// NOTE(review): `#[cfg(any())]` is unconditionally false, so this module is
// compiled out. The test looks stale relative to the current API: the
// `UpdateExecutor::new` call below passes a single `exprs` vector and a
// `vec![0, 1]` argument, while `new` now takes separate `old_exprs`/`new_exprs`
// plus `upsert` and `wait_for_persistence`; `Op` is also referenced without an
// import. It needs rewriting before it can be re-enabled — verify against the
// current constructor signature.
#[cfg(any())]
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        // Schema for mock executor.
        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Schema of the table
        let schema = schema_test_utils::ii();

        mock_executor.add(DataChunk::from_pretty(
            "i  i
             1  2
             3  4
             5  6
             7  8
             9 10",
        ));

        // Update expressions, will swap two columns.
        let exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Update
        // NOTE(review): argument list does not match the current `UpdateExecutor::new`
        // signature — see the module-level note.
        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            exprs,
            5,
            "UpdateExecutor".to_string(),
            false,
            vec![0, 1],
            0,
        ));

        let handle = tokio::spawn(async move {
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // updated rows
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();

        // Read
        // As we set the chunk size to 5, we'll get 2 chunks. Note that the update records for one
        // row cannot be cut into two chunks, so the first chunk will actually have 6 rows.
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2]) // -1, +2, -3, +4, ...
                    .map(Some)
                    .collect_vec()
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1]) // -2, +1, -4, +3, ...
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}