risingwave_frontend/planner/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use risingwave_common::types::{DataType, Scalar};
17use risingwave_pb::expr::expr_node::Type;
18
19use super::Planner;
20use crate::binder::{BoundUpdate, UpdateProject};
21use crate::error::Result;
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{LogicalProject, LogicalUpdate, generic};
25use crate::optimizer::property::{Order, RequiredDist};
26use crate::optimizer::{PlanRef, PlanRoot};
27
28impl Planner {
29    pub(super) fn plan_update(&mut self, update: BoundUpdate) -> Result<PlanRoot> {
30        let returning = !update.returning_list.is_empty();
31
32        let scan = self.plan_base_table(&update.table)?;
33        let input = if let Some(expr) = update.selection {
34            self.plan_where(scan, expr)?
35        } else {
36            scan
37        };
38        let old_schema_len = input.schema().len();
39
40        // Extend table scan with updated columns.
41        let with_new: PlanRef = {
42            let mut plan = input;
43
44            let mut exprs: Vec<ExprImpl> = plan
45                .schema()
46                .data_types()
47                .into_iter()
48                .enumerate()
49                .map(|(index, data_type)| InputRef::new(index, data_type).into())
50                .collect();
51
52            exprs.extend(update.exprs);
53
54            // Substitute subqueries into `LogicalApply`s.
55            if exprs.iter().any(|e| e.has_subquery()) {
56                (plan, exprs) = self.substitute_subqueries(plan, exprs)?;
57            }
58
59            LogicalProject::new(plan, exprs).into()
60        };
61
62        let mut olds = Vec::new();
63        let mut news = Vec::new();
64
65        for (i, col) in update.table.table_catalog.columns().iter().enumerate() {
66            // Skip generated columns and system columns.
67            if !col.can_dml() {
68                continue;
69            }
70            let data_type = col.data_type();
71
72            let old: ExprImpl = InputRef::new(i, data_type.clone()).into();
73
74            let mut new: ExprImpl = match (update.projects.get(&i)).map(|p| p.offset(old_schema_len)) {
75                Some(UpdateProject::Simple(j)) => InputRef::new(j, data_type.clone()).into(),
76                Some(UpdateProject::Composite(j, field)) => FunctionCall::new_unchecked(
77                    Type::Field,
78                    vec![
79                        InputRef::new(j, with_new.schema().data_types()[j].clone()).into(), // struct
80                        Literal::new(Some((field as i32).to_scalar_value()), DataType::Int32)
81                            .into(),
82                    ],
83                    data_type.clone(),
84                )
85                .into(),
86
87                None => old.clone(),
88            };
89            if !col.nullable() {
90                new = FunctionCall::new_unchecked(
91                    ExprType::CheckNotNull,
92                    vec![
93                        new,
94                        ExprImpl::literal_varchar(col.name().to_owned()),
95                        ExprImpl::literal_varchar(update.table_name.clone()),
96                    ],
97                    data_type.clone(),
98                )
99                .into();
100            }
101
102            olds.push(old);
103            news.push(new);
104        }
105
106        let mut plan: PlanRef = LogicalUpdate::from(generic::Update::new(
107            with_new,
108            update.table_name.clone(),
109            update.table_id,
110            update.table_version_id,
111            olds,
112            news,
113            returning,
114        ))
115        .into();
116
117        if returning {
118            plan = LogicalProject::create(plan, update.returning_list);
119        }
120
121        // For update, frontend will only schedule one task so do not need this to be single.
122        let dist = RequiredDist::Any;
123        let mut out_fields = FixedBitSet::with_capacity(plan.schema().len());
124        out_fields.insert_range(..);
125        let out_names = if returning {
126            update.returning_schema.expect("If returning list is not empty, should provide returning schema in BoundDelete.").names()
127        } else {
128            plan.schema().names()
129        };
130
131        let root = PlanRoot::new_with_logical_plan(plan, dist, Order::any(), out_fields, out_names);
132        Ok(root)
133    }
134}