risingwave_frontend/binder/
update.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{Schema, TableVersionId};
21use risingwave_common::types::StructType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
24
25use super::statement::RewriteExprsRecursive;
26use super::{Binder, BoundBaseTable};
27use crate::TableCatalog;
28use crate::catalog::TableId;
29use crate::error::{ErrorCode, Result, RwError, bail_bind_error, bind_error};
30use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
31use crate::handler::privilege::ObjectCheckItem;
32use crate::user::UserId;
33
/// Describes where, inside the `exprs` of a `BoundUpdate`, the new value of an updated column
/// is taken from.
#[derive(Debug, Clone, Copy)]
pub enum UpdateProject {
    /// The new value is the whole expression stored at this index of `exprs`.
    Simple(usize),
    /// The new value is field number `.1` of the struct-returning expression stored at
    /// index `.0` of `exprs`.
    Composite(usize, usize),
}

impl UpdateProject {
    /// Returns this projection with its expression index shifted forward by `i`.
    ///
    /// Only the position in `exprs` moves; the field number of an
    /// [`UpdateProject::Composite`] is kept as-is.
    pub fn offset(self, i: usize) -> Self {
        match self {
            Self::Simple(expr_idx) => Self::Simple(expr_idx + i),
            Self::Composite(expr_idx, field_idx) => Self::Composite(expr_idx + i, field_idx),
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct BoundUpdate {
55    /// Id of the table to perform updating.
56    pub table_id: TableId,
57
58    /// Version id of the table.
59    pub table_version_id: TableVersionId,
60
61    /// Name of the table to perform updating.
62    pub table_name: String,
63
64    /// Owner of the table to perform updating.
65    pub owner: UserId,
66
67    /// Used for scanning the records to update with the `selection`.
68    pub table: BoundBaseTable,
69
70    pub selection: Option<ExprImpl>,
71
72    /// Expression used to evaluate the new values for the columns.
73    pub exprs: Vec<ExprImpl>,
74
75    /// Mapping from the index of the column to be updated, to the index of the expression in `exprs`.
76    ///
77    /// By constructing two `Project` nodes with `exprs` and `projects`, we can get the new values.
78    pub projects: HashMap<usize, UpdateProject>,
79
80    // used for the 'RETURNING" keyword to indicate the returning items and schema
81    // if the list is empty and the schema is None, the output schema will be a INT64 as the
82    // affected row cnt
83    pub returning_list: Vec<ExprImpl>,
84
85    pub returning_schema: Option<Schema>,
86}
87
88impl RewriteExprsRecursive for BoundUpdate {
89    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
90        self.selection =
91            std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
92
93        let new_exprs = std::mem::take(&mut self.exprs)
94            .into_iter()
95            .map(|expr| rewriter.rewrite_expr(expr))
96            .collect::<Vec<_>>();
97        self.exprs = new_exprs;
98
99        let new_returning_list = std::mem::take(&mut self.returning_list)
100            .into_iter()
101            .map(|expr| rewriter.rewrite_expr(expr))
102            .collect::<Vec<_>>();
103        self.returning_list = new_returning_list;
104    }
105}
106
107fn get_col_referenced_by_generated_pk(table_catalog: &TableCatalog) -> Result<FixedBitSet> {
108    let column_num = table_catalog.columns().len();
109    let pk_col_id = table_catalog.pk_column_ids();
110    let mut bitset = FixedBitSet::with_capacity(column_num);
111
112    let generated_pk_col_exprs = table_catalog
113        .columns
114        .iter()
115        .filter(|c| pk_col_id.contains(&c.column_id()))
116        .flat_map(|c| c.generated_expr());
117    for expr_node in generated_pk_col_exprs {
118        let expr = ExprImpl::from_expr_proto(expr_node)?;
119        bitset.union_with(&expr.collect_input_refs(column_num));
120    }
121    Ok(bitset)
122}
123
124impl Binder {
125    pub(super) fn bind_update(
126        &mut self,
127        name: ObjectName,
128        assignments: Vec<Assignment>,
129        selection: Option<Expr>,
130        returning_items: Vec<SelectItem>,
131    ) -> Result<BoundUpdate> {
132        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
133        let table = self.bind_table(schema_name.as_deref(), &table_name)?;
134
135        let table_catalog = &table.table_catalog;
136        Self::check_for_dml(table_catalog, false)?;
137        self.check_privilege(
138            ObjectCheckItem::new(
139                table_catalog.owner,
140                AclMode::Update,
141                table_name.clone(),
142                table_catalog.id,
143            ),
144            table_catalog.database_id,
145        )?;
146
147        let default_columns_from_catalog =
148            table_catalog.default_columns().collect::<BTreeMap<_, _>>();
149        if !returning_items.is_empty() && table_catalog.has_generated_column() {
150            return Err(RwError::from(ErrorCode::BindError(
151                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
152            )));
153        }
154
155        let table_id = table_catalog.id;
156        let owner = table_catalog.owner;
157        let table_version_id = table_catalog.version_id().expect("table must be versioned");
158        let cols_refed_by_generated_pk = get_col_referenced_by_generated_pk(table_catalog)?;
159
160        let selection = selection.map(|expr| self.bind_expr(&expr)).transpose()?;
161
162        let mut exprs = Vec::new();
163        let mut projects = HashMap::new();
164
165        macro_rules! record {
166            ($id:expr, $project:expr) => {
167                let id_index = $id.as_input_ref().unwrap().index;
168                projects
169                    .try_insert(id_index, $project)
170                    .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
171            };
172        }
173
174        for Assignment { id, value } in assignments {
175            let ids: Vec<_> = id
176                .iter()
177                .map(|id| self.bind_expr(&Expr::Identifier(id.clone())))
178                .try_collect()?;
179
180            match (ids.as_slice(), value) {
181                // `SET col1 = DEFAULT`, `SET (col1, col2, ...) = DEFAULT`
182                (ids, AssignmentValue::Default) => {
183                    for id in ids {
184                        let id_index = id.as_input_ref().unwrap().index;
185                        let expr = default_columns_from_catalog
186                            .get(&id_index)
187                            .cloned()
188                            .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
189
190                        exprs.push(expr);
191                        record!(id, UpdateProject::Simple(exprs.len() - 1));
192                    }
193                }
194
195                // `SET col1 = expr`
196                ([id], AssignmentValue::Expr(expr)) => {
197                    let expr = self.bind_expr(&expr)?.cast_assign(&id.return_type())?;
198                    exprs.push(expr);
199                    record!(id, UpdateProject::Simple(exprs.len() - 1));
200                }
201                // `SET (col1, col2, ...) = (val1, val2, ...)`
202                (ids, AssignmentValue::Expr(Expr::Row(values))) => {
203                    if ids.len() != values.len() {
204                        bail_bind_error!("number of columns does not match number of values");
205                    }
206
207                    for (id, value) in ids.iter().zip_eq_fast(values) {
208                        let expr = self.bind_expr(&value)?.cast_assign(&id.return_type())?;
209                        exprs.push(expr);
210                        record!(id, UpdateProject::Simple(exprs.len() - 1));
211                    }
212                }
213                // `SET (col1, col2, ...) = (SELECT ...)`
214                (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
215                    let expr = self.bind_subquery_expr(&subquery, SubqueryKind::UpdateSet)?;
216
217                    if expr.return_type().as_struct().len() != ids.len() {
218                        bail_bind_error!("number of columns does not match number of values");
219                    }
220
221                    let target_type = StructType::new(
222                        id.iter()
223                            .zip_eq_fast(ids)
224                            .map(|(id, expr)| (id.real_value(), expr.return_type())),
225                    )
226                    .into();
227                    let expr = expr.cast_assign(&target_type)?;
228
229                    exprs.push(expr);
230
231                    for (i, id) in ids.iter().enumerate() {
232                        record!(id, UpdateProject::Composite(exprs.len() - 1, i));
233                    }
234                }
235
236                (_ids, AssignmentValue::Expr(_expr)) => {
237                    bail_bind_error!(
238                        "source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"
239                    );
240                }
241            }
242        }
243
244        // Check whether updating these columns is allowed.
245        for &id_index in projects.keys() {
246            if (table.table_catalog.pk())
247                .iter()
248                .any(|k| k.column_index == id_index)
249            {
250                return Err(ErrorCode::BindError(
251                    "update modifying the PK column is unsupported".to_owned(),
252                )
253                .into());
254            }
255            if (table.table_catalog.generated_col_idxes()).contains(&id_index) {
256                return Err(ErrorCode::BindError(
257                    "update modifying the generated column is unsupported".to_owned(),
258                )
259                .into());
260            }
261            if cols_refed_by_generated_pk.contains(id_index) {
262                return Err(ErrorCode::BindError(
263                    "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
264                )
265                .into());
266            }
267
268            let col = &table.table_catalog.columns()[id_index];
269            if !col.can_dml() {
270                bail_bind_error!("update modifying column `{}` is unsupported", col.name());
271            }
272        }
273
274        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
275        let returning = !returning_list.is_empty();
276
277        Ok(BoundUpdate {
278            table_id,
279            table_version_id,
280            table_name,
281            owner,
282            table,
283            selection,
284            projects,
285            exprs,
286            returning_list,
287            returning_schema: if returning {
288                Some(Schema { fields })
289            } else {
290                None
291            },
292        })
293    }
294}