// risingwave_batch_executors/executor/update.rs

use futures_async_stream::try_stream;
16use itertools::Itertools;
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::array::{
19 Array, ArrayBuilder, DataChunk, PrimitiveArrayBuilder, StreamChunk, StreamChunkBuilder,
20};
21use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
22use risingwave_common::transaction::transaction_id::TxnId;
23use risingwave_common::types::DataType;
24use risingwave_common::util::iter_util::ZipEqDebug;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_expr::expr::{BoxedExpression, build_from_prost};
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
/// Batch executor for `UPDATE` statements: reads rows from its child, evaluates the
/// old/new projections, and writes the resulting records into the table's DML channel.
pub struct UpdateExecutor {
    /// Id of the table being updated.
    table_id: TableId,
    /// Version of the table, checked when acquiring the DML handle.
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    /// Child executor producing the rows to be updated.
    child: BoxedExecutor,
    /// Expressions projecting each input row to its *current* (to-be-deleted) values.
    old_exprs: Vec<BoxedExpression>,
    /// Expressions projecting each input row to its *updated* (to-be-inserted) values.
    new_exprs: Vec<BoxedExpression>,
    /// Capacity of each written stream chunk; rounded up to an even number in `new`.
    chunk_size: usize,
    /// Output schema: the table schema when `returning`, else one `Int64` row count.
    schema: Schema,
    identity: String,
    /// Whether this is an `UPDATE ... RETURNING` statement.
    returning: bool,
    /// Transaction id generated from the DML manager in `new`.
    txn_id: TxnId,
    session_id: u32,
    /// When true, only `Insert` records are written instead of delete/insert pairs.
    upsert: bool,
    /// When true, `do_execute` waits for the written data to be persisted before ending.
    wait_for_persistence: bool,
}
55
56impl UpdateExecutor {
57 #[expect(clippy::too_many_arguments)]
58 pub fn new(
59 table_id: TableId,
60 table_version_id: TableVersionId,
61 dml_manager: DmlManagerRef,
62 child: BoxedExecutor,
63 old_exprs: Vec<BoxedExpression>,
64 new_exprs: Vec<BoxedExpression>,
65 chunk_size: usize,
66 identity: String,
67 returning: bool,
68 session_id: u32,
69 upsert: bool,
70 wait_for_persistence: bool,
71 ) -> Self {
72 let chunk_size = chunk_size.next_multiple_of(2);
73 let table_schema = child.schema().clone();
74 let txn_id = dml_manager.gen_txn_id();
75
76 Self {
77 table_id,
78 table_version_id,
79 dml_manager,
80 child,
81 old_exprs,
82 new_exprs,
83 chunk_size,
84 schema: if returning {
85 table_schema
86 } else {
87 Schema {
88 fields: vec![Field::unnamed(DataType::Int64)],
89 }
90 },
91 identity,
92 returning,
93 txn_id,
94 session_id,
95 upsert,
96 wait_for_persistence,
97 }
98 }
99}
100
impl Executor for UpdateExecutor {
    fn schema(&self) -> &Schema {
        // Table schema when `RETURNING`, otherwise a single `Int64` row-count column
        // (decided in `new`).
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        self.do_execute()
    }
}
114
impl UpdateExecutor {
    /// Streams the update: evaluates old/new projections per input chunk, writes the
    /// resulting records through the table's DML channel, and yields either the updated
    /// rows (`RETURNING`) or a single row count.
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        // Acquire the DML handle; fails if the table/version cannot be resolved.
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;

        let data_types = table_dml_handle
            .column_descs()
            .iter()
            .map(|c| c.data_type.clone())
            .collect_vec();

        // Both projections must produce exactly the table's column types; a mismatch
        // is a planner bug, hence assertions rather than recoverable errors.
        assert_eq!(
            data_types,
            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );
        assert_eq!(
            data_types,
            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );

        // `chunk_size` is even (see `new`), so a delete/insert record pair always
        // lands in the same chunk.
        let mut builder = StreamChunkBuilder::new(self.chunk_size, data_types);

        let mut write_handle: risingwave_dml::WriteHandle =
            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
        write_handle.begin()?;

        // Writes one stream chunk into the DML channel, validating its schema in
        // debug builds only.
        let write_txn_data = |chunk: StreamChunk| async {
            if cfg!(debug_assertions) {
                table_dml_handle.check_chunk_schema(&chunk);
            }
            write_handle.write_chunk(chunk).await
        };

        let mut rows_updated = 0;

        #[for_await]
        for input in self.child.execute() {
            let input = input?;

            // Evaluate the old-value projection over the input chunk.
            let old_data_chunk = {
                let mut columns = Vec::with_capacity(self.old_exprs.len());
                for expr in &self.old_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            // Evaluate the new-value projection over the same input chunk.
            let updated_data_chunk = {
                let mut columns = Vec::with_capacity(self.new_exprs.len());
                for expr in &self.new_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            // With `RETURNING`, yield the updated rows before writing them out.
            if self.returning {
                yield updated_data_chunk.clone();
            }

            for (row_delete, row_insert) in
                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
            {
                // Counted before the no-op check: the reported row count includes
                // rows whose values did not actually change.
                rows_updated += 1;
                // Skip writing any record when the update is a no-op for this row.
                if row_delete == row_insert {
                    continue;
                }
                let chunk = if self.upsert {
                    // Upsert mode writes only the new row — presumably the table's
                    // conflict handling replaces the old row by primary key; confirm
                    // against the DML/storage layer.
                    builder.append_record(Record::Insert {
                        new_row: row_insert,
                    })
                } else {
                    builder.append_record(Record::Update {
                        old_row: row_delete,
                        new_row: row_insert,
                    })
                };

                // The builder hands back a full chunk once capacity is reached.
                if let Some(chunk) = chunk {
                    write_txn_data(chunk).await?;
                }
            }
        }

        // Flush any remaining buffered records.
        if let Some(chunk) = builder.take() {
            write_txn_data(chunk).await?;
        }
        if self.wait_for_persistence {
            // End the transaction and wait until the written data is persisted.
            write_handle.end_wait_persistence()?.await?;
        } else {
            write_handle.end().await?;
        }

        // Without `RETURNING`, emit a single `Int64` cell with the affected row count.
        if !self.returning {
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_updated as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);

            yield ret_chunk
        }
    }
}
232
233impl BoxedExecutorBuilder for UpdateExecutor {
234 async fn new_boxed_executor(
235 source: &ExecutorBuilder<'_>,
236 inputs: Vec<BoxedExecutor>,
237 ) -> Result<BoxedExecutor> {
238 let [child]: [_; 1] = inputs.try_into().unwrap();
239
240 let update_node = try_match_expand!(
241 source.plan_node().get_node_body().unwrap(),
242 NodeBody::Update
243 )?;
244
245 let table_id = update_node.table_id;
246
247 let old_exprs: Vec<_> = update_node
248 .get_old_exprs()
249 .iter()
250 .map(build_from_prost)
251 .try_collect()?;
252
253 let new_exprs: Vec<_> = update_node
254 .get_new_exprs()
255 .iter()
256 .map(build_from_prost)
257 .try_collect()?;
258
259 Ok(Box::new(Self::new(
260 table_id,
261 update_node.table_version_id,
262 source.context().dml_manager(),
263 child,
264 old_exprs,
265 new_exprs,
266 source.context().get_config().developer.chunk_size,
267 source.plan_node().get_identity().clone(),
268 update_node.returning,
269 update_node.session_id,
270 update_node.upsert,
271 update_node.wait_for_persistence,
272 )))
273 }
274}
275
#[cfg(test)]
#[cfg(any())] // NOTE(review): `cfg(any())` is always false — this module is compiled out.
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    // NOTE(review): this test is stale relative to the current `UpdateExecutor::new`
    // signature (it passes a single `exprs` list and `vec![0, 1]` instead of
    // `old_exprs`/`new_exprs`, `session_id`, `upsert`, `wait_for_persistence`),
    // which is presumably why the module is disabled above. It will not compile
    // as-is if re-enabled.
    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());

        // NOTE(review): redundant re-construction shadowing the schema above.
        let schema = schema_test_utils::ii();

        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 2
             3 4
             5 6
             7 8
             9 10",
        ));

        // Swap the two columns: the "updated" value of each row is (col1, col0).
        let exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        let table_id = TableId::new(0);

        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // Register a reader so the chunks written by the executor can be observed.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            exprs,
            5,
            "UpdateExecutor".to_string(),
            false,
            vec![0, 1],
            0,
        ));

        let handle = tokio::spawn(async move {
            // Non-returning update: the schema is a single Int64 row count.
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            // All 5 input rows are reported as updated.
            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)]
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();

        // chunk_size = 5 is rounded up to 6 ops (3 row pairs), so the 5 updates arrive
        // as two chunks: rows 1..=3 then rows 4..=5.
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            // Each updated row becomes an (UpdateDelete, UpdateInsert) op pair.
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            // Column 0 interleaves old value (2i - 1) and new value (2i) per row.
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2])
                    .map(Some)
                    .collect_vec()
            );

            // Column 1 interleaves old value (2i) and new value (2i - 1) per row.
            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1])
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}