risingwave_frontend/binder/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{Schema, TableVersionId};
21use risingwave_common::types::StructType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::user::grant_privilege::PbObject;
24use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
25
26use super::statement::RewriteExprsRecursive;
27use super::{Binder, BoundBaseTable};
28use crate::TableCatalog;
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError, bail_bind_error, bind_error};
31use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
32use crate::user::UserId;
33
/// Tells where the new value of an updated column comes from, expressed as a projection over the
/// `exprs` of a `BoundUpdate`.
#[derive(Debug, Clone, Copy)]
pub enum UpdateProject {
    /// The new value is the whole expression stored at this position of `exprs`.
    Simple(usize),
    /// The new value is field number `.1` of the struct-returning expression stored at position
    /// `.0` of `exprs`.
    Composite(usize, usize),
}

impl UpdateProject {
    /// Returns this projection with its `exprs` position shifted forward by `i`.
    ///
    /// For `Composite`, only the expression position moves; the struct field position is kept.
    pub fn offset(self, i: usize) -> Self {
        use UpdateProject::{Composite, Simple};

        match self {
            Simple(at) => Simple(at + i),
            Composite(at, field) => Composite(at + i, field),
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct BoundUpdate {
55    /// Id of the table to perform updating.
56    pub table_id: TableId,
57
58    /// Version id of the table.
59    pub table_version_id: TableVersionId,
60
61    /// Name of the table to perform updating.
62    pub table_name: String,
63
64    /// Owner of the table to perform updating.
65    pub owner: UserId,
66
67    /// Used for scanning the records to update with the `selection`.
68    pub table: BoundBaseTable,
69
70    pub selection: Option<ExprImpl>,
71
72    /// Expression used to evaluate the new values for the columns.
73    pub exprs: Vec<ExprImpl>,
74
75    /// Mapping from the index of the column to be updated, to the index of the expression in `exprs`.
76    ///
77    /// By constructing two `Project` nodes with `exprs` and `projects`, we can get the new values.
78    pub projects: HashMap<usize, UpdateProject>,
79
80    // used for the 'RETURNING" keyword to indicate the returning items and schema
81    // if the list is empty and the schema is None, the output schema will be a INT64 as the
82    // affected row cnt
83    pub returning_list: Vec<ExprImpl>,
84
85    pub returning_schema: Option<Schema>,
86}
87
88impl RewriteExprsRecursive for BoundUpdate {
89    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
90        self.selection =
91            std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
92
93        let new_exprs = std::mem::take(&mut self.exprs)
94            .into_iter()
95            .map(|expr| rewriter.rewrite_expr(expr))
96            .collect::<Vec<_>>();
97        self.exprs = new_exprs;
98
99        let new_returning_list = std::mem::take(&mut self.returning_list)
100            .into_iter()
101            .map(|expr| rewriter.rewrite_expr(expr))
102            .collect::<Vec<_>>();
103        self.returning_list = new_returning_list;
104    }
105}
106
107fn get_col_referenced_by_generated_pk(table_catalog: &TableCatalog) -> Result<FixedBitSet> {
108    let column_num = table_catalog.columns().len();
109    let pk_col_id = table_catalog.pk_column_ids();
110    let mut bitset = FixedBitSet::with_capacity(column_num);
111
112    let generated_pk_col_exprs = table_catalog
113        .columns
114        .iter()
115        .filter(|c| pk_col_id.contains(&c.column_id()))
116        .flat_map(|c| c.generated_expr());
117    for expr_node in generated_pk_col_exprs {
118        let expr = ExprImpl::from_expr_proto(expr_node)?;
119        bitset.union_with(&expr.collect_input_refs(column_num));
120    }
121    Ok(bitset)
122}
123
124impl Binder {
125    pub(super) fn bind_update(
126        &mut self,
127        name: ObjectName,
128        assignments: Vec<Assignment>,
129        selection: Option<Expr>,
130        returning_items: Vec<SelectItem>,
131    ) -> Result<BoundUpdate> {
132        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
133        let table = self.bind_table(schema_name.as_deref(), &table_name)?;
134
135        let table_catalog = &table.table_catalog;
136        Self::check_for_dml(table_catalog, false)?;
137        self.check_privilege(
138            PbObject::TableId(table_catalog.id.table_id),
139            table_catalog.database_id,
140            AclMode::Update,
141            table_catalog.owner,
142        )?;
143
144        let default_columns_from_catalog =
145            table_catalog.default_columns().collect::<BTreeMap<_, _>>();
146        if !returning_items.is_empty() && table_catalog.has_generated_column() {
147            return Err(RwError::from(ErrorCode::BindError(
148                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
149            )));
150        }
151
152        let table_id = table_catalog.id;
153        let owner = table_catalog.owner;
154        let table_version_id = table_catalog.version_id().expect("table must be versioned");
155        let cols_refed_by_generated_pk = get_col_referenced_by_generated_pk(table_catalog)?;
156
157        let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?;
158
159        let mut exprs = Vec::new();
160        let mut projects = HashMap::new();
161
162        macro_rules! record {
163            ($id:expr, $project:expr) => {
164                let id_index = $id.as_input_ref().unwrap().index;
165                projects
166                    .try_insert(id_index, $project)
167                    .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
168            };
169        }
170
171        for Assignment { id, value } in assignments {
172            let ids: Vec<_> = id
173                .iter()
174                .map(|id| self.bind_expr(Expr::Identifier(id.clone())))
175                .try_collect()?;
176
177            match (ids.as_slice(), value) {
178                // `SET col1 = DEFAULT`, `SET (col1, col2, ...) = DEFAULT`
179                (ids, AssignmentValue::Default) => {
180                    for id in ids {
181                        let id_index = id.as_input_ref().unwrap().index;
182                        let expr = default_columns_from_catalog
183                            .get(&id_index)
184                            .cloned()
185                            .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
186
187                        exprs.push(expr);
188                        record!(id, UpdateProject::Simple(exprs.len() - 1));
189                    }
190                }
191
192                // `SET col1 = expr`
193                ([id], AssignmentValue::Expr(expr)) => {
194                    let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?;
195                    exprs.push(expr);
196                    record!(id, UpdateProject::Simple(exprs.len() - 1));
197                }
198                // `SET (col1, col2, ...) = (val1, val2, ...)`
199                (ids, AssignmentValue::Expr(Expr::Row(values))) => {
200                    if ids.len() != values.len() {
201                        bail_bind_error!("number of columns does not match number of values");
202                    }
203
204                    for (id, value) in ids.iter().zip_eq_fast(values) {
205                        let expr = self.bind_expr(value)?.cast_assign(id.return_type())?;
206                        exprs.push(expr);
207                        record!(id, UpdateProject::Simple(exprs.len() - 1));
208                    }
209                }
210                // `SET (col1, col2, ...) = (SELECT ...)`
211                (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
212                    let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?;
213
214                    if expr.return_type().as_struct().len() != ids.len() {
215                        bail_bind_error!("number of columns does not match number of values");
216                    }
217
218                    let target_type = StructType::new(
219                        id.iter()
220                            .zip_eq_fast(ids)
221                            .map(|(id, expr)| (id.real_value(), expr.return_type())),
222                    )
223                    .into();
224                    let expr = expr.cast_assign(target_type)?;
225
226                    exprs.push(expr);
227
228                    for (i, id) in ids.iter().enumerate() {
229                        record!(id, UpdateProject::Composite(exprs.len() - 1, i));
230                    }
231                }
232
233                (_ids, AssignmentValue::Expr(_expr)) => {
234                    bail_bind_error!(
235                        "source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"
236                    );
237                }
238            }
239        }
240
241        // Check whether updating these columns is allowed.
242        for &id_index in projects.keys() {
243            if (table.table_catalog.pk())
244                .iter()
245                .any(|k| k.column_index == id_index)
246            {
247                return Err(ErrorCode::BindError(
248                    "update modifying the PK column is unsupported".to_owned(),
249                )
250                .into());
251            }
252            if (table.table_catalog.generated_col_idxes()).contains(&id_index) {
253                return Err(ErrorCode::BindError(
254                    "update modifying the generated column is unsupported".to_owned(),
255                )
256                .into());
257            }
258            if cols_refed_by_generated_pk.contains(id_index) {
259                return Err(ErrorCode::BindError(
260                    "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
261                )
262                .into());
263            }
264
265            let col = &table.table_catalog.columns()[id_index];
266            if !col.can_dml() {
267                bail_bind_error!("update modifying column `{}` is unsupported", col.name());
268            }
269        }
270
271        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
272        let returning = !returning_list.is_empty();
273
274        Ok(BoundUpdate {
275            table_id,
276            table_version_id,
277            table_name,
278            owner,
279            table,
280            selection,
281            projects,
282            exprs,
283            returning_list,
284            returning_schema: if returning {
285                Some(Schema { fields })
286            } else {
287                None
288            },
289        })
290    }
291}