risingwave_batch_executors/executor/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use futures_async_stream::try_stream;
16use itertools::Itertools;
17use risingwave_common::array::{
18    Array, ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, StreamChunk,
19};
20use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
21use risingwave_common::transaction::transaction_id::TxnId;
22use risingwave_common::types::DataType;
23use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
24use risingwave_common::util::iter_util::ZipEqDebug;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_expr::expr::{BoxedExpression, build_from_prost};
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
/// [`UpdateExecutor`] implements table update with values from its child executor and given
/// expressions.
///
/// For every row produced by the child, `old_exprs` compute the values to retract and
/// `new_exprs` compute the values to insert; rows whose values actually change are written to
/// the target table through a write handle obtained from the DML manager, as
/// `UpdateDelete`/`UpdateInsert` pairs.
// Note: multiple `UPDATE`s in a single epoch, or concurrent `UPDATE`s may lead to conflicting
// records. This is validated and filtered on the first `Materialize`.
pub struct UpdateExecutor {
    /// Target table id.
    table_id: TableId,
    /// Version of the target table; used together with `table_id` to look up its DML handle.
    table_version_id: TableVersionId,
    /// Provides the table DML handle and the transaction id for this update.
    dml_manager: DmlManagerRef,
    /// Input executor whose chunks are evaluated by `old_exprs` and `new_exprs`.
    child: BoxedExecutor,
    /// One expression per table column, producing the values to retract (`UpdateDelete`).
    old_exprs: Vec<BoxedExpression>,
    /// One expression per table column, producing the values to insert (`UpdateInsert`).
    new_exprs: Vec<BoxedExpression>,
    /// Capacity of the chunks written to the table. It is rounded up to an even number on
    /// construction so that an update pair is never split across two chunks.
    chunk_size: usize,
    /// Output schema: the child's schema with `RETURNING`, otherwise a single `Int64` row count.
    schema: Schema,
    /// Identity string reported by [`Executor::identity`].
    identity: String,
    /// Whether to emit the new row values (`RETURNING`) instead of a row count.
    returning: bool,
    /// Transaction id allocated from `dml_manager` on construction.
    txn_id: TxnId,
    /// Session id passed to the table's write handle.
    session_id: u32,
}
53
54impl UpdateExecutor {
55    #[allow(clippy::too_many_arguments)]
56    pub fn new(
57        table_id: TableId,
58        table_version_id: TableVersionId,
59        dml_manager: DmlManagerRef,
60        child: BoxedExecutor,
61        old_exprs: Vec<BoxedExpression>,
62        new_exprs: Vec<BoxedExpression>,
63        chunk_size: usize,
64        identity: String,
65        returning: bool,
66        session_id: u32,
67    ) -> Self {
68        let chunk_size = chunk_size.next_multiple_of(2);
69        let table_schema = child.schema().clone();
70        let txn_id = dml_manager.gen_txn_id();
71
72        Self {
73            table_id,
74            table_version_id,
75            dml_manager,
76            child,
77            old_exprs,
78            new_exprs,
79            chunk_size,
80            schema: if returning {
81                table_schema
82            } else {
83                Schema {
84                    fields: vec![Field::unnamed(DataType::Int64)],
85                }
86            },
87            identity,
88            returning,
89            txn_id,
90            session_id,
91        }
92    }
93}
94
95impl Executor for UpdateExecutor {
96    fn schema(&self) -> &Schema {
97        &self.schema
98    }
99
100    fn identity(&self) -> &str {
101        &self.identity
102    }
103
104    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
105        self.do_execute()
106    }
107}
108
109impl UpdateExecutor {
110    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
111    async fn do_execute(self: Box<Self>) {
112        let table_dml_handle = self
113            .dml_manager
114            .table_dml_handle(self.table_id, self.table_version_id)?;
115
116        let data_types = table_dml_handle
117            .column_descs()
118            .iter()
119            .map(|c| c.data_type.clone())
120            .collect_vec();
121
122        assert_eq!(
123            data_types,
124            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
125            "bad update schema"
126        );
127        assert_eq!(
128            data_types,
129            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
130            "bad update schema"
131        );
132
133        let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
134
135        let mut write_handle: risingwave_dml::WriteHandle =
136            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
137        write_handle.begin()?;
138
139        // Transform the data chunk to a stream chunk, then write to the source.
140        let write_txn_data = |chunk: DataChunk| async {
141            // TODO: if the primary key is updated, we should use plain `+,-` instead of `U+,U-`.
142            let ops = [Op::UpdateDelete, Op::UpdateInsert]
143                .into_iter()
144                .cycle()
145                .take(chunk.capacity())
146                .collect_vec();
147            let stream_chunk = StreamChunk::from_parts(ops, chunk);
148
149            #[cfg(debug_assertions)]
150            table_dml_handle.check_chunk_schema(&stream_chunk);
151
152            write_handle.write_chunk(stream_chunk).await
153        };
154
155        let mut rows_updated = 0;
156
157        #[for_await]
158        for input in self.child.execute() {
159            let input = input?;
160
161            let old_data_chunk = {
162                let mut columns = Vec::with_capacity(self.old_exprs.len());
163                for expr in &self.old_exprs {
164                    let column = expr.eval(&input).await?;
165                    columns.push(column);
166                }
167
168                DataChunk::new(columns, input.visibility().clone())
169            };
170
171            let updated_data_chunk = {
172                let mut columns = Vec::with_capacity(self.new_exprs.len());
173                for expr in &self.new_exprs {
174                    let column = expr.eval(&input).await?;
175                    columns.push(column);
176                }
177
178                DataChunk::new(columns, input.visibility().clone())
179            };
180
181            if self.returning {
182                yield updated_data_chunk.clone();
183            }
184
185            for (row_delete, row_insert) in
186                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
187            {
188                rows_updated += 1;
189                // If row_delete == row_insert, we don't need to do a actual update
190                if row_delete != row_insert {
191                    let None = builder.append_one_row(row_delete) else {
192                        unreachable!(
193                            "no chunk should be yielded when appending the deleted row as the chunk size is always even"
194                        );
195                    };
196                    if let Some(chunk) = builder.append_one_row(row_insert) {
197                        write_txn_data(chunk).await?;
198                    }
199                }
200            }
201        }
202
203        if let Some(chunk) = builder.consume_all() {
204            write_txn_data(chunk).await?;
205        }
206        write_handle.end().await?;
207
208        // Create ret value
209        if !self.returning {
210            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
211            array_builder.append(Some(rows_updated as i64));
212
213            let array = array_builder.finish();
214            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);
215
216            yield ret_chunk
217        }
218    }
219}
220
221impl BoxedExecutorBuilder for UpdateExecutor {
222    async fn new_boxed_executor(
223        source: &ExecutorBuilder<'_>,
224        inputs: Vec<BoxedExecutor>,
225    ) -> Result<BoxedExecutor> {
226        let [child]: [_; 1] = inputs.try_into().unwrap();
227
228        let update_node = try_match_expand!(
229            source.plan_node().get_node_body().unwrap(),
230            NodeBody::Update
231        )?;
232
233        let table_id = TableId::new(update_node.table_id);
234
235        let old_exprs: Vec<_> = update_node
236            .get_old_exprs()
237            .iter()
238            .map(build_from_prost)
239            .try_collect()?;
240
241        let new_exprs: Vec<_> = update_node
242            .get_new_exprs()
243            .iter()
244            .map(build_from_prost)
245            .try_collect()?;
246
247        Ok(Box::new(Self::new(
248            table_id,
249            update_node.table_version_id,
250            source.context().dml_manager(),
251            child,
252            old_exprs,
253            new_exprs,
254            source.context().get_config().developer.chunk_size,
255            source.plan_node().get_identity().clone(),
256            update_node.returning,
257            update_node.session_id,
258        )))
259    }
260}
261
#[cfg(test)]
// NOTE(review): `#[cfg(any())]` compiles this module out entirely; the reason is not visible
// here — confirm the test utilities still match before re-enabling.
#[cfg(any())]
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Swaps the two columns of every row and checks both the reported row count and the
    /// `UpdateDelete`/`UpdateInsert` records read back from the table's DML reader.
    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        // Schema for mock executor.
        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());

        // Schema of the table
        let schema = schema_test_utils::ii();

        mock_executor.add(DataChunk::from_pretty(
            "i  i
             1  2
             3  4
             5  6
             7  8
             9 10",
        ));

        // Old expressions read the current values unchanged.
        let old_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 0)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 1)),
        ];
        // New expressions swap the two columns.
        let new_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        // Create the table.
        let table_id = TableId::new(0);

        // Create reader
        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // We must create a variable to hold this `Arc<TableDmlHandle>` here, or it will be dropped
        // due to the `Weak` reference in `DmlManager`.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        // Update
        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            old_exprs,
            new_exprs,
            5,
            "UpdateExecutor".to_string(),
            false,
            0,
        ));

        let handle = tokio::spawn(async move {
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)] // updated rows
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();

        // Read
        // The chunk size 5 is rounded up to 6 by `UpdateExecutor::new`, so each chunk holds 3
        // update pairs: the 5 updated rows arrive in 2 chunks (6 records, then 4 records).
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2]) // -1, +2, -3, +4, ...
                    .map(Some)
                    .collect_vec()
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1]) // -2, +1, -4, +3, ...
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}