risingwave_frontend/planner/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use risingwave_common::types::{DataType, Scalar};
17use risingwave_pb::expr::expr_node::Type;
18
19use super::Planner;
20use crate::binder::{BoundUpdate, UpdateProject};
21use crate::error::Result;
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{
25    LogicalPlanRef as PlanRef, LogicalProject, LogicalUpdate, generic,
26};
27use crate::optimizer::property::{Order, RequiredDist};
28use crate::optimizer::{LogicalPlanRoot, PlanRoot};
29
30impl Planner {
31    pub(super) fn plan_update(&mut self, update: BoundUpdate) -> Result<LogicalPlanRoot> {
32        let returning = !update.returning_list.is_empty();
33
34        let scan = self.plan_base_table(&update.table)?;
35        let input = if let Some(expr) = update.selection {
36            self.plan_where(scan, expr)?
37        } else {
38            scan
39        };
40        let old_schema_len = input.schema().len();
41
42        // Extend table scan with updated columns.
43        let with_new: PlanRef = {
44            let mut plan = input;
45
46            let mut exprs: Vec<ExprImpl> = plan
47                .schema()
48                .data_types()
49                .into_iter()
50                .enumerate()
51                .map(|(index, data_type)| InputRef::new(index, data_type).into())
52                .collect();
53
54            exprs.extend(update.exprs);
55
56            // Substitute subqueries into `LogicalApply`s.
57            if exprs.iter().any(|e| e.has_subquery()) {
58                (plan, exprs) = self.substitute_subqueries_in_cross_join_way(plan, exprs)?;
59            }
60
61            LogicalProject::new(plan, exprs).into()
62        };
63
64        let mut olds = Vec::new();
65        let mut news = Vec::new();
66
67        for (i, col) in update.table.table_catalog.columns().iter().enumerate() {
68            // Skip generated columns and system columns.
69            if !col.can_dml() {
70                continue;
71            }
72            let data_type = col.data_type();
73
74            let old: ExprImpl = InputRef::new(i, data_type.clone()).into();
75
76            let mut new: ExprImpl = match (update.projects.get(&i)).map(|p| p.offset(old_schema_len)) {
77                Some(UpdateProject::Simple(j)) => InputRef::new(j, data_type.clone()).into(),
78                Some(UpdateProject::Composite(j, field)) => FunctionCall::new_unchecked(
79                    Type::Field,
80                    vec![
81                        InputRef::new(j, with_new.schema().data_types()[j].clone()).into(), // struct
82                        Literal::new(Some((field as i32).to_scalar_value()), DataType::Int32)
83                            .into(),
84                    ],
85                    data_type.clone(),
86                )
87                .into(),
88
89                None => old.clone(),
90            };
91            if !col.nullable() {
92                new = FunctionCall::new_unchecked(
93                    ExprType::CheckNotNull,
94                    vec![
95                        new,
96                        ExprImpl::literal_varchar(col.name().to_owned()),
97                        ExprImpl::literal_varchar(update.table_name.clone()),
98                    ],
99                    data_type.clone(),
100                )
101                .into();
102            }
103
104            olds.push(old);
105            news.push(new);
106        }
107
108        let mut plan: PlanRef = LogicalUpdate::from(generic::Update::new(
109            with_new,
110            update.table_name.clone(),
111            update.table_id,
112            update.table_version_id,
113            olds,
114            news,
115            returning,
116        ))
117        .into();
118
119        if returning {
120            plan = LogicalProject::create(plan, update.returning_list);
121        }
122
123        // For update, frontend will only schedule one task so do not need this to be single.
124        let dist = RequiredDist::Any;
125        let mut out_fields = FixedBitSet::with_capacity(plan.schema().len());
126        out_fields.insert_range(..);
127        let out_names = if returning {
128            update.returning_schema.expect("If returning list is not empty, should provide returning schema in BoundDelete.").names()
129        } else {
130            plan.schema().names()
131        };
132
133        let root = PlanRoot::new_with_logical_plan(plan, dist, Order::any(), out_fields, out_names);
134        Ok(root)
135    }
136}