risingwave_batch_executors/executor/update.rs

use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::{
    Array, ArrayBuilder, DataChunk, PrimitiveArrayBuilder, StreamChunk, StreamChunkBuilder,
};
use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;

use crate::error::{BatchError, Result};
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
};
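
/// [`UpdateExecutor`] implements `UPDATE`: for every row from its child it evaluates the
/// old-value and new-value expressions and writes the resulting records to the target table
/// through the DML manager, optionally returning the updated rows.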
pub struct UpdateExecutor {
    table_id: TableId,
    table_version_id: TableVersionId,
    dml_manager: DmlManagerRef,
    child: BoxedExecutor,
    old_exprs: Vec<BoxedExpression>,
    new_exprs: Vec<BoxedExpression>,
    chunk_size: usize,
    schema: Schema,
    identity: String,
    returning: bool,
    txn_id: TxnId,
    session_id: u32,
    upsert: bool,
}

impl UpdateExecutor {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        table_id: TableId,
        table_version_id: TableVersionId,
        dml_manager: DmlManagerRef,
        child: BoxedExecutor,
        old_exprs: Vec<BoxedExpression>,
        new_exprs: Vec<BoxedExpression>,
        chunk_size: usize,
        identity: String,
        returning: bool,
        session_id: u32,
        upsert: bool,
    ) -> Self {
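        // A plain update emits an `UpdateDelete`/`UpdateInsert` pair per row, so round the
        // chunk size up to an even number to avoid splitting a pair across chunks.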
        let chunk_size = chunk_size.next_multiple_of(2);
        let table_schema = child.schema().clone();
        let txn_id = dml_manager.gen_txn_id();

        Self {
            table_id,
            table_version_id,
            dml_manager,
            child,
            old_exprs,
            new_exprs,
            chunk_size,
            schema: if returning {
                table_schema
            } else {
                Schema {
                    fields: vec![Field::unnamed(DataType::Int64)],
                }
            },
            identity,
            returning,
            txn_id,
            session_id,
            upsert,
        }
    }
}

impl Executor for UpdateExecutor {
    fn schema(&self) -> &Schema {
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        self.do_execute()
    }
}

impl UpdateExecutor {
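    /// Evaluates the old and new expressions for each input chunk, streams the resulting
    /// records to the table within a single DML transaction, and yields either the updated
    /// rows (for `RETURNING`) or a one-row chunk carrying the number of rows updated.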
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        let table_dml_handle = self
            .dml_manager
            .table_dml_handle(self.table_id, self.table_version_id)?;

        let data_types = table_dml_handle
            .column_descs()
            .iter()
            .map(|c| c.data_type.clone())
            .collect_vec();

        assert_eq!(
            data_types,
            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );
        assert_eq!(
            data_types,
            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
            "bad update schema"
        );

        let mut builder = StreamChunkBuilder::new(self.chunk_size, data_types);

        let mut write_handle: risingwave_dml::WriteHandle =
            table_dml_handle.write_handle(self.session_id, self.txn_id)?;
        write_handle.begin()?;
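
        // Writes a chunk of update data into the open DML transaction; in debug builds it
        // also checks that the chunk schema matches the table.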
        let write_txn_data = |chunk: StreamChunk| async {
            if cfg!(debug_assertions) {
                table_dml_handle.check_chunk_schema(&chunk);
            }
            write_handle.write_chunk(chunk).await
        };

        let mut rows_updated = 0;

        #[for_await]
        for input in self.child.execute() {
            let input = input?;

            let old_data_chunk = {
                let mut columns = Vec::with_capacity(self.old_exprs.len());
                for expr in &self.old_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            let updated_data_chunk = {
                let mut columns = Vec::with_capacity(self.new_exprs.len());
                for expr in &self.new_exprs {
                    let column = expr.eval(&input).await?;
                    columns.push(column);
                }

                DataChunk::new(columns, input.visibility().clone())
            };

            if self.returning {
                yield updated_data_chunk.clone();
            }

            for (row_delete, row_insert) in
                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
            {
                rows_updated += 1;
                if row_delete == row_insert {
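                    // The row is unchanged; skip emitting a no-op update.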
                    continue;
                }
                let chunk = if self.upsert {
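                    // In upsert mode, only an `Insert` record for the new row is emitted;
                    // the old row is expected to be superseded by primary key rather than
                    // explicitly deleted here.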
                    builder.append_record(Record::Insert {
                        new_row: row_insert,
                    })
                } else {
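                    // Regular update: emit an `Update` record carrying both rows, which
                    // expands into an `UpdateDelete`/`UpdateInsert` pair.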
                    builder.append_record(Record::Update {
                        old_row: row_delete,
                        new_row: row_insert,
                    })
                };

                if let Some(chunk) = chunk {
                    write_txn_data(chunk).await?;
                }
            }
        }

        if let Some(chunk) = builder.take() {
            write_txn_data(chunk).await?;
        }
        write_handle.end().await?;

        if !self.returning {
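            // Without `RETURNING`, report the number of rows updated as a single Int64 row.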
            let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
            array_builder.append(Some(rows_updated as i64));

            let array = array_builder.finish();
            let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);

            yield ret_chunk
        }
    }
}

impl BoxedExecutorBuilder for UpdateExecutor {
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        inputs: Vec<BoxedExecutor>,
    ) -> Result<BoxedExecutor> {
        let [child]: [_; 1] = inputs.try_into().unwrap();

        let update_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::Update
        )?;
        let table_id = TableId::new(update_node.table_id);

        let old_exprs: Vec<_> = update_node
            .get_old_exprs()
            .iter()
            .map(build_from_prost)
            .try_collect()?;

        let new_exprs: Vec<_> = update_node
            .get_new_exprs()
            .iter()
            .map(build_from_prost)
            .try_collect()?;

        Ok(Box::new(Self::new(
            table_id,
            update_node.table_version_id,
            source.context().dml_manager(),
            child,
            old_exprs,
            new_exprs,
            source.context().get_config().developer.chunk_size,
            source.plan_node().get_identity().clone(),
            update_node.returning,
            update_node.session_id,
            update_node.upsert,
        )))
    }
}

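// The tests below are compiled out: `#[cfg(any())]` is never satisfied. They predate the
// split of the update expressions into `old_exprs`/`new_exprs` and no longer match the
// current `UpdateExecutor::new` signature.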
#[cfg(test)]
#[cfg(any())]
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());
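
        // Schema of the mock child executor.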
        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());
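
        // Schema of the target table, used below to derive its column descriptors.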
        let schema = schema_test_utils::ii();

        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 2
             3 4
             5 6
             7 8
             9 10",
        ));

        let exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        let table_id = TableId::new(0);

        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            exprs,
            5,
            "UpdateExecutor".to_string(),
            false,
            vec![0, 1],
            0,
        ));

        let handle = tokio::spawn(async move {
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)]
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();
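
        // With a chunk size of 5 rounded up to 6, the 5 update pairs (10 rows) are expected
        // to arrive as a chunk of 3 pairs followed by a chunk of 2 pairs.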
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2])
                    .map(Some)
                    .collect_vec()
            );

            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1])
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}