risingwave_frontend/binder/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{Schema, TableVersionId};
21use risingwave_common::types::StructType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::user::grant_privilege::PbObject;
24use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
25
26use super::statement::RewriteExprsRecursive;
27use super::{Binder, BoundBaseTable};
28use crate::TableCatalog;
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError, bail_bind_error, bind_error};
31use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
32use crate::handler::privilege::ObjectCheckItem;
33use crate::user::UserId;
34
/// Describes where, among the `exprs` of a `BoundUpdate`, the new value of an updated column
/// comes from.
#[derive(Debug, Clone, Copy)]
pub enum UpdateProject {
    /// Use the expression at the given index in `exprs`.
    Simple(usize),
    /// Use the `i`-th field of the expression (returning a struct) at the given index in `exprs`.
    Composite(usize, usize),
}

impl UpdateProject {
    /// Returns this projection with its index into `exprs` shifted forward by `i`.
    ///
    /// For a `Composite` projection only the expression index moves; the field position inside
    /// the struct is kept as is.
    pub fn offset(self, i: usize) -> Self {
        match self {
            Self::Simple(expr_idx) => Self::Simple(expr_idx + i),
            Self::Composite(expr_idx, field_idx) => Self::Composite(expr_idx + i, field_idx),
        }
    }
}
53
/// The result of binding an `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct BoundUpdate {
    /// Id of the table to perform updating.
    pub table_id: TableId,

    /// Version id of the table.
    pub table_version_id: TableVersionId,

    /// Name of the table to perform updating.
    pub table_name: String,

    /// Owner of the table to perform updating.
    pub owner: UserId,

    /// Used for scanning the records to update with the `selection`.
    pub table: BoundBaseTable,

    /// The bound `WHERE` predicate selecting the rows to update, if the statement has one.
    pub selection: Option<ExprImpl>,

    /// Expressions used to evaluate the new values for the columns.
    pub exprs: Vec<ExprImpl>,

    /// Mapping from the index of each column to be updated to an [`UpdateProject`] that locates
    /// its new value in `exprs` (either a whole expression, or one field of a struct-returning
    /// expression).
    ///
    /// By constructing two `Project` nodes with `exprs` and `projects`, we can get the new values.
    pub projects: HashMap<usize, UpdateProject>,

    /// The bound items of the `RETURNING` clause, if any.
    ///
    /// If this list is empty and `returning_schema` is `None`, the output schema will be a single
    /// INT64 column holding the affected row count.
    pub returning_list: Vec<ExprImpl>,

    /// Output schema of the `RETURNING` clause; set only when `returning_list` is non-empty.
    pub returning_schema: Option<Schema>,
}
88
89impl RewriteExprsRecursive for BoundUpdate {
90    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
91        self.selection =
92            std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
93
94        let new_exprs = std::mem::take(&mut self.exprs)
95            .into_iter()
96            .map(|expr| rewriter.rewrite_expr(expr))
97            .collect::<Vec<_>>();
98        self.exprs = new_exprs;
99
100        let new_returning_list = std::mem::take(&mut self.returning_list)
101            .into_iter()
102            .map(|expr| rewriter.rewrite_expr(expr))
103            .collect::<Vec<_>>();
104        self.returning_list = new_returning_list;
105    }
106}
107
108fn get_col_referenced_by_generated_pk(table_catalog: &TableCatalog) -> Result<FixedBitSet> {
109    let column_num = table_catalog.columns().len();
110    let pk_col_id = table_catalog.pk_column_ids();
111    let mut bitset = FixedBitSet::with_capacity(column_num);
112
113    let generated_pk_col_exprs = table_catalog
114        .columns
115        .iter()
116        .filter(|c| pk_col_id.contains(&c.column_id()))
117        .flat_map(|c| c.generated_expr());
118    for expr_node in generated_pk_col_exprs {
119        let expr = ExprImpl::from_expr_proto(expr_node)?;
120        bitset.union_with(&expr.collect_input_refs(column_num));
121    }
122    Ok(bitset)
123}
124
125impl Binder {
126    pub(super) fn bind_update(
127        &mut self,
128        name: ObjectName,
129        assignments: Vec<Assignment>,
130        selection: Option<Expr>,
131        returning_items: Vec<SelectItem>,
132    ) -> Result<BoundUpdate> {
133        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
134        let table = self.bind_table(schema_name.as_deref(), &table_name)?;
135
136        let table_catalog = &table.table_catalog;
137        Self::check_for_dml(table_catalog, false)?;
138        self.check_privilege(
139            ObjectCheckItem::new(
140                table_catalog.owner,
141                AclMode::Update,
142                table_name.clone(),
143                PbObject::TableId(table_catalog.id.table_id),
144            ),
145            table_catalog.database_id,
146        )?;
147
148        let default_columns_from_catalog =
149            table_catalog.default_columns().collect::<BTreeMap<_, _>>();
150        if !returning_items.is_empty() && table_catalog.has_generated_column() {
151            return Err(RwError::from(ErrorCode::BindError(
152                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
153            )));
154        }
155
156        let table_id = table_catalog.id;
157        let owner = table_catalog.owner;
158        let table_version_id = table_catalog.version_id().expect("table must be versioned");
159        let cols_refed_by_generated_pk = get_col_referenced_by_generated_pk(table_catalog)?;
160
161        let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?;
162
163        let mut exprs = Vec::new();
164        let mut projects = HashMap::new();
165
166        macro_rules! record {
167            ($id:expr, $project:expr) => {
168                let id_index = $id.as_input_ref().unwrap().index;
169                projects
170                    .try_insert(id_index, $project)
171                    .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
172            };
173        }
174
175        for Assignment { id, value } in assignments {
176            let ids: Vec<_> = id
177                .iter()
178                .map(|id| self.bind_expr(Expr::Identifier(id.clone())))
179                .try_collect()?;
180
181            match (ids.as_slice(), value) {
182                // `SET col1 = DEFAULT`, `SET (col1, col2, ...) = DEFAULT`
183                (ids, AssignmentValue::Default) => {
184                    for id in ids {
185                        let id_index = id.as_input_ref().unwrap().index;
186                        let expr = default_columns_from_catalog
187                            .get(&id_index)
188                            .cloned()
189                            .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
190
191                        exprs.push(expr);
192                        record!(id, UpdateProject::Simple(exprs.len() - 1));
193                    }
194                }
195
196                // `SET col1 = expr`
197                ([id], AssignmentValue::Expr(expr)) => {
198                    let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?;
199                    exprs.push(expr);
200                    record!(id, UpdateProject::Simple(exprs.len() - 1));
201                }
202                // `SET (col1, col2, ...) = (val1, val2, ...)`
203                (ids, AssignmentValue::Expr(Expr::Row(values))) => {
204                    if ids.len() != values.len() {
205                        bail_bind_error!("number of columns does not match number of values");
206                    }
207
208                    for (id, value) in ids.iter().zip_eq_fast(values) {
209                        let expr = self.bind_expr(value)?.cast_assign(id.return_type())?;
210                        exprs.push(expr);
211                        record!(id, UpdateProject::Simple(exprs.len() - 1));
212                    }
213                }
214                // `SET (col1, col2, ...) = (SELECT ...)`
215                (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
216                    let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?;
217
218                    if expr.return_type().as_struct().len() != ids.len() {
219                        bail_bind_error!("number of columns does not match number of values");
220                    }
221
222                    let target_type = StructType::new(
223                        id.iter()
224                            .zip_eq_fast(ids)
225                            .map(|(id, expr)| (id.real_value(), expr.return_type())),
226                    )
227                    .into();
228                    let expr = expr.cast_assign(target_type)?;
229
230                    exprs.push(expr);
231
232                    for (i, id) in ids.iter().enumerate() {
233                        record!(id, UpdateProject::Composite(exprs.len() - 1, i));
234                    }
235                }
236
237                (_ids, AssignmentValue::Expr(_expr)) => {
238                    bail_bind_error!(
239                        "source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"
240                    );
241                }
242            }
243        }
244
245        // Check whether updating these columns is allowed.
246        for &id_index in projects.keys() {
247            if (table.table_catalog.pk())
248                .iter()
249                .any(|k| k.column_index == id_index)
250            {
251                return Err(ErrorCode::BindError(
252                    "update modifying the PK column is unsupported".to_owned(),
253                )
254                .into());
255            }
256            if (table.table_catalog.generated_col_idxes()).contains(&id_index) {
257                return Err(ErrorCode::BindError(
258                    "update modifying the generated column is unsupported".to_owned(),
259                )
260                .into());
261            }
262            if cols_refed_by_generated_pk.contains(id_index) {
263                return Err(ErrorCode::BindError(
264                    "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
265                )
266                .into());
267            }
268
269            let col = &table.table_catalog.columns()[id_index];
270            if !col.can_dml() {
271                bail_bind_error!("update modifying column `{}` is unsupported", col.name());
272            }
273        }
274
275        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
276        let returning = !returning_list.is_empty();
277
278        Ok(BoundUpdate {
279            table_id,
280            table_version_id,
281            table_name,
282            owner,
283            table,
284            selection,
285            projects,
286            exprs,
287            returning_list,
288            returning_schema: if returning {
289                Some(Schema { fields })
290            } else {
291                None
292            },
293        })
294    }
295}