risingwave_batch_executors/executor/
update.rs1use futures_async_stream::try_stream;
16use itertools::Itertools;
17use risingwave_common::array::{
18 Array, ArrayBuilder, DataChunk, Op, PrimitiveArrayBuilder, StreamChunk,
19};
20use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId};
21use risingwave_common::transaction::transaction_id::TxnId;
22use risingwave_common::types::DataType;
23use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
24use risingwave_common::util::iter_util::ZipEqDebug;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_expr::expr::{BoxedExpression, build_from_prost};
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
34pub struct UpdateExecutor {
39 table_id: TableId,
41 table_version_id: TableVersionId,
42 dml_manager: DmlManagerRef,
43 child: BoxedExecutor,
44 old_exprs: Vec<BoxedExpression>,
45 new_exprs: Vec<BoxedExpression>,
46 chunk_size: usize,
47 schema: Schema,
48 identity: String,
49 returning: bool,
50 txn_id: TxnId,
51 session_id: u32,
52}
53
54impl UpdateExecutor {
55 #[allow(clippy::too_many_arguments)]
56 pub fn new(
57 table_id: TableId,
58 table_version_id: TableVersionId,
59 dml_manager: DmlManagerRef,
60 child: BoxedExecutor,
61 old_exprs: Vec<BoxedExpression>,
62 new_exprs: Vec<BoxedExpression>,
63 chunk_size: usize,
64 identity: String,
65 returning: bool,
66 session_id: u32,
67 ) -> Self {
68 let chunk_size = chunk_size.next_multiple_of(2);
69 let table_schema = child.schema().clone();
70 let txn_id = dml_manager.gen_txn_id();
71
72 Self {
73 table_id,
74 table_version_id,
75 dml_manager,
76 child,
77 old_exprs,
78 new_exprs,
79 chunk_size,
80 schema: if returning {
81 table_schema
82 } else {
83 Schema {
84 fields: vec![Field::unnamed(DataType::Int64)],
85 }
86 },
87 identity,
88 returning,
89 txn_id,
90 session_id,
91 }
92 }
93}
94
95impl Executor for UpdateExecutor {
96 fn schema(&self) -> &Schema {
97 &self.schema
98 }
99
100 fn identity(&self) -> &str {
101 &self.identity
102 }
103
104 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
105 self.do_execute()
106 }
107}
108
109impl UpdateExecutor {
110 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
111 async fn do_execute(self: Box<Self>) {
112 let table_dml_handle = self
113 .dml_manager
114 .table_dml_handle(self.table_id, self.table_version_id)?;
115
116 let data_types = table_dml_handle
117 .column_descs()
118 .iter()
119 .map(|c| c.data_type.clone())
120 .collect_vec();
121
122 assert_eq!(
123 data_types,
124 self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
125 "bad update schema"
126 );
127 assert_eq!(
128 data_types,
129 self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
130 "bad update schema"
131 );
132
133 let mut builder = DataChunkBuilder::new(data_types, self.chunk_size);
134
135 let mut write_handle: risingwave_dml::WriteHandle =
136 table_dml_handle.write_handle(self.session_id, self.txn_id)?;
137 write_handle.begin()?;
138
139 let write_txn_data = |chunk: DataChunk| async {
141 let ops = [Op::UpdateDelete, Op::UpdateInsert]
143 .into_iter()
144 .cycle()
145 .take(chunk.capacity())
146 .collect_vec();
147 let stream_chunk = StreamChunk::from_parts(ops, chunk);
148
149 #[cfg(debug_assertions)]
150 table_dml_handle.check_chunk_schema(&stream_chunk);
151
152 write_handle.write_chunk(stream_chunk).await
153 };
154
155 let mut rows_updated = 0;
156
157 #[for_await]
158 for input in self.child.execute() {
159 let input = input?;
160
161 let old_data_chunk = {
162 let mut columns = Vec::with_capacity(self.old_exprs.len());
163 for expr in &self.old_exprs {
164 let column = expr.eval(&input).await?;
165 columns.push(column);
166 }
167
168 DataChunk::new(columns, input.visibility().clone())
169 };
170
171 let updated_data_chunk = {
172 let mut columns = Vec::with_capacity(self.new_exprs.len());
173 for expr in &self.new_exprs {
174 let column = expr.eval(&input).await?;
175 columns.push(column);
176 }
177
178 DataChunk::new(columns, input.visibility().clone())
179 };
180
181 if self.returning {
182 yield updated_data_chunk.clone();
183 }
184
185 for (row_delete, row_insert) in
186 (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
187 {
188 rows_updated += 1;
189 if row_delete != row_insert {
191 let None = builder.append_one_row(row_delete) else {
192 unreachable!(
193 "no chunk should be yielded when appending the deleted row as the chunk size is always even"
194 );
195 };
196 if let Some(chunk) = builder.append_one_row(row_insert) {
197 write_txn_data(chunk).await?;
198 }
199 }
200 }
201 }
202
203 if let Some(chunk) = builder.consume_all() {
204 write_txn_data(chunk).await?;
205 }
206 write_handle.end().await?;
207
208 if !self.returning {
210 let mut array_builder = PrimitiveArrayBuilder::<i64>::new(1);
211 array_builder.append(Some(rows_updated as i64));
212
213 let array = array_builder.finish();
214 let ret_chunk = DataChunk::new(vec![array.into_ref()], 1);
215
216 yield ret_chunk
217 }
218 }
219}
220
221impl BoxedExecutorBuilder for UpdateExecutor {
222 async fn new_boxed_executor(
223 source: &ExecutorBuilder<'_>,
224 inputs: Vec<BoxedExecutor>,
225 ) -> Result<BoxedExecutor> {
226 let [child]: [_; 1] = inputs.try_into().unwrap();
227
228 let update_node = try_match_expand!(
229 source.plan_node().get_node_body().unwrap(),
230 NodeBody::Update
231 )?;
232
233 let table_id = TableId::new(update_node.table_id);
234
235 let old_exprs: Vec<_> = update_node
236 .get_old_exprs()
237 .iter()
238 .map(build_from_prost)
239 .try_collect()?;
240
241 let new_exprs: Vec<_> = update_node
242 .get_new_exprs()
243 .iter()
244 .map(build_from_prost)
245 .try_collect()?;
246
247 Ok(Box::new(Self::new(
248 table_id,
249 update_node.table_version_id,
250 source.context().dml_manager(),
251 child,
252 old_exprs,
253 new_exprs,
254 source.context().get_config().developer.chunk_size,
255 source.plan_node().get_identity().clone(),
256 update_node.returning,
257 update_node.session_id,
258 )))
259 }
260}
261
// NOTE(review): this module is compiled out by `#[cfg(any())]`. The reason for the gate is not
// visible in this file; the test body has been brought in line with the current
// `UpdateExecutor::new` signature so it can be re-enabled once that is confirmed.
#[cfg(test)]
#[cfg(any())]
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use risingwave_common::catalog::{
        ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID, schema_test_utils,
    };
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_dml::dml_manager::DmlManager;
    use risingwave_expr::expr::InputRefExpression;

    use super::*;
    use crate::executor::test_utils::MockExecutor;
    use crate::*;

    /// Swaps the two columns of a five-row table and checks both the reported row count and the
    /// `UpdateDelete` / `UpdateInsert` chunks observed by a DML reader.
    #[tokio::test]
    async fn test_update_executor() -> Result<()> {
        let dml_manager = Arc::new(DmlManager::for_test());

        // Schema shared by the mocked child and the target table.
        let schema = schema_test_utils::ii();
        let mut mock_executor = MockExecutor::new(schema.clone());

        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 2
             3 4
             5 6
             7 8
             9 10",
        ));

        // Pre-update image: the input columns unchanged.
        let old_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 0)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 1)),
        ];
        // Post-update image: the two columns swapped.
        let new_exprs = vec![
            Box::new(InputRefExpression::new(DataType::Int32, 1)) as BoxedExpression,
            Box::new(InputRefExpression::new(DataType::Int32, 0)),
        ];

        let table_id = TableId::new(0);

        let column_descs = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone()))
            .collect_vec();
        // Register a reader so the chunks written by the executor can be observed.
        let reader = dml_manager
            .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs)
            .unwrap();
        let mut reader = reader.stream_reader().into_stream();

        let update_executor = Box::new(UpdateExecutor::new(
            table_id,
            INITIAL_TABLE_VERSION_ID,
            dml_manager,
            Box::new(mock_executor),
            old_exprs,
            new_exprs,
            5,
            "UpdateExecutor".to_string(),
            false,
            0,
        ));

        let handle = tokio::spawn(async move {
            let fields = &update_executor.schema().fields;
            assert_eq!(fields[0].data_type, DataType::Int64);

            let mut stream = update_executor.execute();
            let result = stream.next().await.unwrap().unwrap();

            // All five rows are reported as updated.
            assert_eq!(
                result.column_at(0).as_int64().iter().collect::<Vec<_>>(),
                vec![Some(5)]
            );
        });

        reader.next().await.unwrap()?.into_begin().unwrap();

        // The requested chunk size of 5 is rounded up to 6 by the executor, so the first chunk
        // carries three delete/insert pairs and the final flush carries the remaining two.
        for updated_rows in [1..=3, 4..=5] {
            let txn_msg = reader.next().await.unwrap()?;
            let chunk = txn_msg.as_stream_chunk().unwrap();
            assert_eq!(
                chunk.ops().chunks(2).collect_vec(),
                vec![&[Op::UpdateDelete, Op::UpdateInsert]; updated_rows.clone().count()]
            );

            // Column 0: old value followed by new value for each updated row.
            assert_eq!(
                chunk.columns()[0].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2 - 1, i * 2])
                    .map(Some)
                    .collect_vec()
            );

            // Column 1: old value followed by new value for each updated row.
            assert_eq!(
                chunk.columns()[1].as_int32().iter().collect::<Vec<_>>(),
                updated_rows
                    .clone()
                    .flat_map(|i| [i * 2, i * 2 - 1])
                    .map(Some)
                    .collect_vec()
            );
        }

        reader.next().await.unwrap()?.into_end().unwrap();

        handle.await.unwrap();

        Ok(())
    }
}