risingwave_frontend/optimizer/plan_node/
batch_update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::catalog::Schema;
16use risingwave_pb::batch_plan::UpdateNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18
19use super::batch::prelude::*;
20use super::utils::impl_distill_by_unit;
21use super::{
22    ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch, generic,
23};
24use crate::error::Result;
25use crate::expr::{Expr, ExprRewriter, ExprVisitor};
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::plan_node::{ToLocalBatch, utils};
28use crate::optimizer::plan_visitor::DistributedDmlVisitor;
29use crate::optimizer::property::{Distribution, Order, RequiredDist};
30
31/// `BatchUpdate` implements [`super::LogicalUpdate`]
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33pub struct BatchUpdate {
34    pub base: PlanBase<Batch>,
35    pub core: generic::Update<PlanRef>,
36}
37
38impl BatchUpdate {
39    pub fn new(core: generic::Update<PlanRef>, schema: Schema) -> Self {
40        let ctx = core.input.ctx();
41        let base =
42            PlanBase::new_batch(ctx, schema, core.input.distribution().clone(), Order::any());
43        Self { base, core }
44    }
45}
46
47impl PlanTreeNodeUnary for BatchUpdate {
48    fn input(&self) -> PlanRef {
49        self.core.input.clone()
50    }
51
52    fn clone_with_input(&self, input: PlanRef) -> Self {
53        let mut core = self.core.clone();
54        core.input = input;
55        Self::new(core, self.schema().clone())
56    }
57}
58
59impl_plan_tree_node_for_unary! { BatchUpdate }
60impl_distill_by_unit!(BatchUpdate, core, "BatchUpdate");
61
62impl ToDistributedBatch for BatchUpdate {
63    fn to_distributed(&self) -> Result<PlanRef> {
64        if DistributedDmlVisitor::dml_should_run_in_distributed(self.input()) {
65            // Add an hash shuffle between the update and its input.
66            let new_input = RequiredDist::PhysicalDist(Distribution::HashShard(
67                (0..self.input().schema().len()).collect(),
68            ))
69            .enforce_if_not_satisfies(self.input().to_distributed()?, &Order::any())?;
70            let new_update: PlanRef = self.clone_with_input(new_input).into();
71            if self.core.returning {
72                Ok(new_update)
73            } else {
74                utils::sum_affected_row(new_update)
75            }
76        } else {
77            let new_input = RequiredDist::single()
78                .enforce_if_not_satisfies(self.input().to_distributed()?, &Order::any())?;
79            Ok(self.clone_with_input(new_input).into())
80        }
81    }
82}
83
84impl ToBatchPb for BatchUpdate {
85    fn to_batch_prost_body(&self) -> NodeBody {
86        let old_exprs = (self.core.old_exprs)
87            .iter()
88            .map(|x| x.to_expr_proto())
89            .collect();
90        let new_exprs = (self.core.new_exprs)
91            .iter()
92            .map(|x| x.to_expr_proto())
93            .collect();
94
95        NodeBody::Update(UpdateNode {
96            table_id: self.core.table_id.table_id(),
97            table_version_id: self.core.table_version_id,
98            returning: self.core.returning,
99            old_exprs,
100            new_exprs,
101            session_id: self.base.ctx().session_ctx().session_id().0 as u32,
102        })
103    }
104}
105
106impl ToLocalBatch for BatchUpdate {
107    fn to_local(&self) -> Result<PlanRef> {
108        let new_input = RequiredDist::single()
109            .enforce_if_not_satisfies(self.input().to_local()?, &Order::any())?;
110        Ok(self.clone_with_input(new_input).into())
111    }
112}
113
114impl ExprRewritable for BatchUpdate {
115    fn has_rewritable_expr(&self) -> bool {
116        true
117    }
118
119    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
120        let mut core = self.core.clone();
121        core.rewrite_exprs(r);
122        Self::new(core, self.schema().clone()).into()
123    }
124}
125
126impl ExprVisitable for BatchUpdate {
127    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
128        self.core.visit_exprs(v);
129    }
130}