risingwave_frontend/binder/
update.rs1use std::collections::{BTreeMap, HashMap};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{Schema, TableVersionId};
21use risingwave_common::types::StructType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::user::grant_privilege::PbObject;
24use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
25
26use super::statement::RewriteExprsRecursive;
27use super::{Binder, BoundBaseTable};
28use crate::TableCatalog;
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError, bail_bind_error, bind_error};
31use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
32use crate::user::UserId;
33
/// Locates the new value of one updated column inside the `exprs` list of a [`BoundUpdate`].
#[derive(Debug, Clone, Copy)]
pub enum UpdateProject {
    /// The new value is the entire expression stored at this index of `exprs`.
    Simple(usize),
    /// The new value is a field of a struct-typed expression: `.0` is the index of that
    /// expression in `exprs`, `.1` is the field position within the struct (as built when
    /// binding `SET (a, b, ...) = (SELECT ...)`).
    Composite(usize, usize),
}

impl UpdateProject {
    /// Returns this projection with its `exprs` index moved forward by `i`.
    ///
    /// For [`UpdateProject::Composite`] only the expression index changes; the field position
    /// inside the struct is left as-is.
    pub fn offset(self, i: usize) -> Self {
        let shift = |expr_index: usize| expr_index + i;
        match self {
            Self::Simple(expr_index) => Self::Simple(shift(expr_index)),
            Self::Composite(expr_index, field) => Self::Composite(shift(expr_index), field),
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct BoundUpdate {
55 pub table_id: TableId,
57
58 pub table_version_id: TableVersionId,
60
61 pub table_name: String,
63
64 pub owner: UserId,
66
67 pub table: BoundBaseTable,
69
70 pub selection: Option<ExprImpl>,
71
72 pub exprs: Vec<ExprImpl>,
74
75 pub projects: HashMap<usize, UpdateProject>,
79
80 pub returning_list: Vec<ExprImpl>,
84
85 pub returning_schema: Option<Schema>,
86}
87
88impl RewriteExprsRecursive for BoundUpdate {
89 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
90 self.selection =
91 std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
92
93 let new_exprs = std::mem::take(&mut self.exprs)
94 .into_iter()
95 .map(|expr| rewriter.rewrite_expr(expr))
96 .collect::<Vec<_>>();
97 self.exprs = new_exprs;
98
99 let new_returning_list = std::mem::take(&mut self.returning_list)
100 .into_iter()
101 .map(|expr| rewriter.rewrite_expr(expr))
102 .collect::<Vec<_>>();
103 self.returning_list = new_returning_list;
104 }
105}
106
107fn get_col_referenced_by_generated_pk(table_catalog: &TableCatalog) -> Result<FixedBitSet> {
108 let column_num = table_catalog.columns().len();
109 let pk_col_id = table_catalog.pk_column_ids();
110 let mut bitset = FixedBitSet::with_capacity(column_num);
111
112 let generated_pk_col_exprs = table_catalog
113 .columns
114 .iter()
115 .filter(|c| pk_col_id.contains(&c.column_id()))
116 .flat_map(|c| c.generated_expr());
117 for expr_node in generated_pk_col_exprs {
118 let expr = ExprImpl::from_expr_proto(expr_node)?;
119 bitset.union_with(&expr.collect_input_refs(column_num));
120 }
121 Ok(bitset)
122}
123
124impl Binder {
125 pub(super) fn bind_update(
126 &mut self,
127 name: ObjectName,
128 assignments: Vec<Assignment>,
129 selection: Option<Expr>,
130 returning_items: Vec<SelectItem>,
131 ) -> Result<BoundUpdate> {
132 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
133 let table = self.bind_table(schema_name.as_deref(), &table_name)?;
134
135 let table_catalog = &table.table_catalog;
136 Self::check_for_dml(table_catalog, false)?;
137 self.check_privilege(
138 PbObject::TableId(table_catalog.id.table_id),
139 table_catalog.database_id,
140 AclMode::Update,
141 table_catalog.owner,
142 )?;
143
144 let default_columns_from_catalog =
145 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
146 if !returning_items.is_empty() && table_catalog.has_generated_column() {
147 return Err(RwError::from(ErrorCode::BindError(
148 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
149 )));
150 }
151
152 let table_id = table_catalog.id;
153 let owner = table_catalog.owner;
154 let table_version_id = table_catalog.version_id().expect("table must be versioned");
155 let cols_refed_by_generated_pk = get_col_referenced_by_generated_pk(table_catalog)?;
156
157 let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?;
158
159 let mut exprs = Vec::new();
160 let mut projects = HashMap::new();
161
162 macro_rules! record {
163 ($id:expr, $project:expr) => {
164 let id_index = $id.as_input_ref().unwrap().index;
165 projects
166 .try_insert(id_index, $project)
167 .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
168 };
169 }
170
171 for Assignment { id, value } in assignments {
172 let ids: Vec<_> = id
173 .iter()
174 .map(|id| self.bind_expr(Expr::Identifier(id.clone())))
175 .try_collect()?;
176
177 match (ids.as_slice(), value) {
178 (ids, AssignmentValue::Default) => {
180 for id in ids {
181 let id_index = id.as_input_ref().unwrap().index;
182 let expr = default_columns_from_catalog
183 .get(&id_index)
184 .cloned()
185 .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
186
187 exprs.push(expr);
188 record!(id, UpdateProject::Simple(exprs.len() - 1));
189 }
190 }
191
192 ([id], AssignmentValue::Expr(expr)) => {
194 let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?;
195 exprs.push(expr);
196 record!(id, UpdateProject::Simple(exprs.len() - 1));
197 }
198 (ids, AssignmentValue::Expr(Expr::Row(values))) => {
200 if ids.len() != values.len() {
201 bail_bind_error!("number of columns does not match number of values");
202 }
203
204 for (id, value) in ids.iter().zip_eq_fast(values) {
205 let expr = self.bind_expr(value)?.cast_assign(id.return_type())?;
206 exprs.push(expr);
207 record!(id, UpdateProject::Simple(exprs.len() - 1));
208 }
209 }
210 (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
212 let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?;
213
214 if expr.return_type().as_struct().len() != ids.len() {
215 bail_bind_error!("number of columns does not match number of values");
216 }
217
218 let target_type = StructType::new(
219 id.iter()
220 .zip_eq_fast(ids)
221 .map(|(id, expr)| (id.real_value(), expr.return_type())),
222 )
223 .into();
224 let expr = expr.cast_assign(target_type)?;
225
226 exprs.push(expr);
227
228 for (i, id) in ids.iter().enumerate() {
229 record!(id, UpdateProject::Composite(exprs.len() - 1, i));
230 }
231 }
232
233 (_ids, AssignmentValue::Expr(_expr)) => {
234 bail_bind_error!(
235 "source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"
236 );
237 }
238 }
239 }
240
241 for &id_index in projects.keys() {
243 if (table.table_catalog.pk())
244 .iter()
245 .any(|k| k.column_index == id_index)
246 {
247 return Err(ErrorCode::BindError(
248 "update modifying the PK column is unsupported".to_owned(),
249 )
250 .into());
251 }
252 if (table.table_catalog.generated_col_idxes()).contains(&id_index) {
253 return Err(ErrorCode::BindError(
254 "update modifying the generated column is unsupported".to_owned(),
255 )
256 .into());
257 }
258 if cols_refed_by_generated_pk.contains(id_index) {
259 return Err(ErrorCode::BindError(
260 "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
261 )
262 .into());
263 }
264
265 let col = &table.table_catalog.columns()[id_index];
266 if !col.can_dml() {
267 bail_bind_error!("update modifying column `{}` is unsupported", col.name());
268 }
269 }
270
271 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
272 let returning = !returning_list.is_empty();
273
274 Ok(BoundUpdate {
275 table_id,
276 table_version_id,
277 table_name,
278 owner,
279 table,
280 selection,
281 projects,
282 exprs,
283 returning_list,
284 returning_schema: if returning {
285 Some(Schema { fields })
286 } else {
287 None
288 },
289 })
290 }
291}