risingwave_batch_executors/executor/update.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::{
    Array, ArrayBuilder, DataChunk, PrimitiveArrayBuilder, StreamChunk, StreamChunkBuilder,
};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};

/// [`UpdateExecutor`] implements table update with values from its child executor and given
/// expressions.
// Note: multiple `UPDATE`s in a single epoch or concurrent `UPDATE`s may produce conflicting
// records; these are validated and filtered out by the first `Materialize` operator.
pub struct UpdateExecutor {
    /// Target table id.
    table_id: TableId,
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    child: BoxedExecutor,
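    /// Expressions computing the old values of all columns, evaluated on the child's output.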
    old_exprs: Vec<BoxedExpression>,
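    /// Expressions computing the new values of all columns.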
    new_exprs: Vec<BoxedExpression>,
    chunk_size: usize,
    schema: Schema,
    identity: String,
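    /// Whether to yield the updated rows (`UPDATE ... RETURNING`) instead of the affected-row
    /// count.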
    returning: bool,
    txn_id: TxnId,
    session_id: u32,
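    /// Whether to write plain `Insert` records for the new rows instead of `Update` pairs.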
    upsert: bool,
}

impl UpdateExecutor {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        table_id: TableId,
        table_version_id: TableVersionId,
        dml_manager: DmlManagerRef,
        child: BoxedExecutor,
        old_exprs: Vec<BoxedExpression>,
        new_exprs: Vec<BoxedExpression>,
        chunk_size: usize,
        identity: String,
        returning: bool,
        session_id: u32,
        upsert: bool,
    ) -> Self {
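        // Round the chunk size up to an even number so that the `UpdateDelete` and
        // `UpdateInsert` rows of a single update record always land in the same chunk.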
        let chunk_size = chunk_size.next_multiple_of(2);
        let table_schema = child.schema().clone();
        let txn_id = dml_manager.gen_txn_id();

        Self {
            table_id,
            table_version_id,
            dml_manager,
            child,
            old_exprs,
            new_exprs,
            chunk_size,
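            // With `RETURNING`, output the updated rows; otherwise, a single affected-row count.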
            schema: if returning {
                table_schema
            } else {
                Schema {
                    fields: vec![Field::unnamed(DataType::Int64)],
                }
            },
            identity,
            returning,
            txn_id,
            session_id,
            upsert,
        }
    }
}

impl Executor for UpdateExecutor {
    fn schema(&self) -> &Schema {
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        self.do_execute()
    }
}

impl UpdateExecutor {
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;

        let data_types = table_dml_handle
            .column_descs()
            .iter()
            .map(|c| c.data_type.clone())
            .collect_vec();

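        // Both expression lists must produce exactly the table's columns, in order.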
        assert_eq!(
            data_types,
            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );
        assert_eq!(
            data_types,
            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );

        let mut builder = StreamChunkBuilder::new(self.chunk_size, data_types);

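        // Open a write handle on the table and begin a DML transaction.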
        let mut write_handle: risingwave_dml::WriteHandle =
            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
        write_handle.begin()?;

        // Write a stream chunk to the table through the DML write handle.
        let write_txn_data = |chunk: StreamChunk| async {
            if cfg!(debug_assertions) {
                table_dml_handle.check_chunk_schema(&chunk);
            }
            write_handle.write_chunk(chunk).await
        };

        let mut rows_updated = 0;

        #[for_await]
        for input in self.child.execute() {
            let input = input?;

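            // Evaluate the old values of all columns on the input chunk.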
            let old_data_chunk = {
                let mut columns = Vec::with_capacity(self.old_exprs.len());
                for expr in &self.old_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

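            // Evaluate the new values of all columns on the same input.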
            let updated_data_chunk = {
                let mut columns = Vec::with_capacity(self.new_exprs.len());
                for expr in &self.new_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            if self.returning {
                yield updated_data_chunk.clone();
            }

            for (row_delete, row_insert) in
                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
            {
                rows_updated += 1;
                // If `row_delete == row_insert`, there's no need to do an actual update.
                if row_delete == row_insert {
                    continue;
                }
                let chunk = if self.upsert {
                    // In upsert mode, we only write the new row.
                    builder.append_record(Record::Insert {
                        new_row: row_insert,
                    })
                } else {
                    // Note: we've banned updating the primary key when binding the `UPDATE`
                    // statement, so we can safely use the `Update` op here.
                    builder.append_record(Record::Update {
                        old_row: row_delete,
                        new_row: row_insert,
                    })
                };

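                // The builder returns a finished chunk once `chunk_size` rows are buffered.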
                if let Some(chunk) = chunk {
                    write_txn_data(chunk).await?;
                }
            }
        }

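        // Flush the remaining buffered rows, then commit the DML transaction.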
        if let Some(chunk) = builder.take() {
            write_txn_data(chunk).await?;
        }
        write_handle.end().await?;

        // Report the number of updated rows if `RETURNING` is not specified.
        if !self.returning {
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_updated as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);

            yield ret_chunk
        }
    }
}

impl BoxedExecutorBuilder for UpdateExecutor {
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        inputs: Vec<BoxedExecutor>,
    ) -> Result<BoxedExecutor> {
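        // The update executor expects exactly one child, which produces the rows to update.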
        let [child]: [_; 1] = inputs.try_into().unwrap();

        let update_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::Update
        )?;

        let table_id = TableId::new(update_node.table_id);

        let old_exprs: Vec<_> = update_node
            .get_old_exprs()
            .iter()
            .map(build_from_prost)
            .try_collect()?;

        let new_exprs: Vec<_> = update_node
            .get_new_exprs()
            .iter()
            .map(build_from_prost)
            .try_collect()?;

        Ok(Box::new(Self::new(
            table_id,
            update_node.table_version_id,
            source.context().dml_manager(),
            child,
            old_exprs,
            new_exprs,
            source.context().get_config().developer.chunk_size,
            source.plan_node().get_identity().clone(),
            update_node.returning,
            update_node.session_id,
            update_node.upsert,
        )))
    }
}

#[cfg(test)]
#[cfg(any())] // NOTE: `cfg(any())` is never true, so this module is currently compiled out.
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::array::Op;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        // Schema for mock executor.
        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Schema of the table
        let schema = schema_test_utils::ii();

        mock_executor.add(DataChunk::from_pretty(
            "i  i
             1  2
             3  4
             5  6
             7  8
             9 10",
        ));

        // Old-value expressions read the original columns; new-value expressions swap them.
        let old_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 0)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 1)),
        ];
        let new_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Update
        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            old_exprs,
            new_exprs,
            5,
            "UpdateExecutor".to_string(),
            false, // returning
            0,     // session_id
            false, // upsert
        ));

        let handle = tokio::spawn(async move {
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // updated rows
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();

        // Read
        // The chunk size of 5 is rounded up to 6 in the executor, and the two ops of one update
        // record are never split across chunks, so we'll get 2 chunks: the first with 6 rows
        // (3 updates) and the second with 4 rows (2 updates).
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2]) // -1, +2, -3, +4, ...
                    .map(Some)
                    .collect_vec()
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1]) // -2, +1, -4, +3, ...
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}