// risingwave_batch_executors/executor/update.rs
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{
    Array, ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, StreamChunk,
};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};
33
34pub struct UpdateExecutor {
39    table_id: TableId,
41    table_version_id: TableVersionId,
42    dml_manager: DmlManagerRef,
43    child: BoxedExecutor,
44    old_exprs: Vec<BoxedExpression>,
45    new_exprs: Vec<BoxedExpression>,
46    chunk_size: usize,
47    schema: Schema,
48    identity: String,
49    returning: bool,
50    txn_id: TxnId,
51    session_id: u32,
52}
53
54impl UpdateExecutor {
55    #[allow(clippy::too_many_arguments)]
56    pub fn new(
57        table_id: TableId,
58        table_version_id: TableVersionId,
59        dml_manager: DmlManagerRef,
60        child: BoxedExecutor,
61        old_exprs: Vec<BoxedExpression>,
62        new_exprs: Vec<BoxedExpression>,
63        chunk_size: usize,
64        identity: String,
65        returning: bool,
66        session_id: u32,
67    ) -> Self {
68        let chunk_size = chunk_size.next_multiple_of(2);
69        let table_schema = child.schema().clone();
70        let txn_id = dml_manager.gen_txn_id();
71
72        Self {
73            table_id,
74            table_version_id,
75            dml_manager,
76            child,
77            old_exprs,
78            new_exprs,
79            chunk_size,
80            schema: if returning {
81                table_schema
82            } else {
83                Schema {
84                    fields: vec![Field::unnamed(DataType::Int64)],
85                }
86            },
87            identity,
88            returning,
89            txn_id,
90            session_id,
91        }
92    }
93}
94
95impl Executor for UpdateExecutor {
96    fn schema(&self) -> &Schema {
97        &self.schema
98    }
99
100    fn identity(&self) -> &str {
101        &self.identity
102    }
103
104    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
105        self.do_execute()
106    }
107}
108
109impl UpdateExecutor {
110    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
111    async fn do_execute(self: Box<Self>) {
112        let table_dml_handle = self
113            .dml_manager
114            .table_dml_handle(self.table_id, self.table_version_id)?;
115
116        let data_types = table_dml_handle
117            .column_descs()
118            .iter()
119            .map(|c| c.data_type.clone())
120            .collect_vec();
121
122        assert_eq!(
123            data_types,
124            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
125            "bad update schema"
126        );
127        assert_eq!(
128            data_types,
129            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
130            "bad update schema"
131        );
132
133        let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
134
135        let mut write_handle: risingwave_dml::WriteHandle =
136            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
137        write_handle.begin()?;
138
139        let write_txn_data = |chunk: DataChunk| async {
141            let ops = [Op::UpdateDelete, Op::UpdateInsert]
144                .into_iter()
145                .cycle()
146                .take(chunk.capacity())
147                .collect_vec();
148            let stream_chunk = StreamChunk::from_parts(ops, chunk);
149
150            #[cfg(debug_assertions)]
151            table_dml_handle.check_chunk_schema(&stream_chunk);
152
153            write_handle.write_chunk(stream_chunk).await
154        };
155
156        let mut rows_updated = 0;
157
158        #[for_await]
159        for input in self.child.execute() {
160            let input = input?;
161
162            let old_data_chunk = {
163                let mut columns = Vec::with_capacity(self.old_exprs.len());
164                for expr in &self.old_exprs {
165                    let column = expr.eval(&input).await?;
166                    columns.push(column);
167                }
168
169                DataChunk::new(columns, input.visibility().clone())
170            };
171
172            let updated_data_chunk = {
173                let mut columns = Vec::with_capacity(self.new_exprs.len());
174                for expr in &self.new_exprs {
175                    let column = expr.eval(&input).await?;
176                    columns.push(column);
177                }
178
179                DataChunk::new(columns, input.visibility().clone())
180            };
181
182            if self.returning {
183                yield updated_data_chunk.clone();
184            }
185
186            for (row_delete, row_insert) in
187                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
188            {
189                rows_updated += 1;
190                if row_delete != row_insert {
192                    let None = builder.append_one_row(row_delete) else {
193                        unreachable!(
194                            "no chunk should be yielded when appending the deleted row as the chunk size is always even"
195                        );
196                    };
197                    if let Some(chunk) = builder.append_one_row(row_insert) {
198                        write_txn_data(chunk).await?;
199                    }
200                }
201            }
202        }
203
204        if let Some(chunk) = builder.consume_all() {
205            write_txn_data(chunk).await?;
206        }
207        write_handle.end().await?;
208
209        if !self.returning {
211            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
212            array_builder.append(Some(rows_updated as i64));
213
214            let array = array_builder.finish();
215            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);
216
217            yield ret_chunk
218        }
219    }
220}
221
222impl BoxedExecutorBuilder for UpdateExecutor {
223    async fn new_boxed_executor(
224        source: &ExecutorBuilder<'_>,
225        inputs: Vec<BoxedExecutor>,
226    ) -> Result<BoxedExecutor> {
227        let [child]: [_; 1] = inputs.try_into().unwrap();
228
229        let update_node = try_match_expand!(
230            source.plan_node().get_node_body().unwrap(),
231            NodeBody::Update
232        )?;
233
234        let table_id = TableId::new(update_node.table_id);
235
236        let old_exprs: Vec<_> = update_node
237            .get_old_exprs()
238            .iter()
239            .map(build_from_prost)
240            .try_collect()?;
241
242        let new_exprs: Vec<_> = update_node
243            .get_new_exprs()
244            .iter()
245            .map(build_from_prost)
246            .try_collect()?;
247
248        Ok(Box::new(Self::new(
249            table_id,
250            update_node.table_version_id,
251            source.context().dml_manager(),
252            child,
253            old_exprs,
254            new_exprs,
255            source.context().get_config().developer.chunk_size,
256            source.plan_node().get_identity().clone(),
257            update_node.returning,
258            update_node.session_id,
259        )))
260    }
261}
262
263#[cfg(test)]
264#[cfg(any())]
265mod tests {
266    use std::sync::Arc;
267
268    use futures::StreamExt;
269    use risingwave_common::catalog::{
270        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
271    };
272    use risingwave_common::test_prelude::DataChunkTestExt;
273    use risingwave_dml::dml_manager::DmlManager;
274    use risingwave_expr::expr::InputRefExpression;
275
276    use super::*;
277    use crate::executor::test_utils::MockExecutor;
278    use crate::*;
279
280    #[tokio::test]
281    async fn test_update_executor() -> Result<()> {
282        let dml_manager = Arc::new(DmlManager::for_test());
283
284        let schema = schema_test_utils::ii();
286        let mut mock_executor = MockExecutor::new(schema.clone());
287
288        let schema = schema_test_utils::ii();
290
291        mock_executor.add(DataChunk::from_pretty(
292            "i  i
293             1  2
294             3  4
295             5  6
296             7  8
297             9 10",
298        ));
299
300        let exprs = vec![
302            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
303            Box::new(InputRefExpression::new(DataType::Int32, 0)),
304        ];
305
306        let table_id = TableId::new(0);
308
309        let column_descs = schema
311            .fields
312            .iter()
313            .enumerate()
314            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
315            .collect_vec();
316        let reader = dml_manager
319            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
320            .unwrap();
321        let mut reader = reader.stream_reader().into_stream();
322
323        let update_executor = Box::new(UpdateExecutor::new(
325            table_id,
326            INITIAL_TABLE_VERSION_ID,
327            dml_manager,
328            Box::new(mock_executor),
329            exprs,
330            5,
331            "UpdateExecutor".to_string(),
332            false,
333            vec![0, 1],
334            0,
335        ));
336
337        let handle = tokio::spawn(async move {
338            let fields = &update_executor.schema().fields;
339            assert_eq!(fields[0].data_type, DataType::Int64);
340
341            let mut stream = update_executor.execute();
342            let result = stream.next().await.unwrap().unwrap();
343
344            assert_eq!(
345                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
346                vec![Some(5)] );
348        });
349
350        reader.next().await.unwrap()?.into_begin().unwrap();
351
352        for updated_rows in [1..=3, 4..=5] {
356            let txn_msg = reader.next().await.unwrap()?;
357            let chunk = txn_msg.as_stream_chunk().unwrap();
358            assert_eq!(
359                chunk.ops().chunks(2).collect_vec(),
360                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
361            );
362
363            assert_eq!(
364                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
365                updated_rows
366                    .clone()
367                    .flat_map(|i| [i * 2 - 1, i * 2]) .map(Some)
369                    .collect_vec()
370            );
371
372            assert_eq!(
373                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
374                updated_rows
375                    .clone()
376                    .flat_map(|i| [i * 2, i * 2 - 1]) .map(Some)
378                    .collect_vec()
379            );
380        }
381
382        reader.next().await.unwrap()?.into_end().unwrap();
383
384        handle.await.unwrap();
385
386        Ok(())
387    }
388}