risingwave_frontend/binder/
update.rs1use std::collections::{BTreeMap, HashMap};
16
17use fixedbitset::FixedBitSet;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{Schema, TableVersionId};
21use risingwave_common::types::StructType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::user::grant_privilege::PbObject;
24use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
25
26use super::statement::RewriteExprsRecursive;
27use super::{Binder, BoundBaseTable};
28use crate::TableCatalog;
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError, bail_bind_error, bind_error};
31use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
32use crate::handler::privilege::ObjectCheckItem;
33use crate::user::UserId;
34
/// Where the new value of one assigned column is read from, expressed relative to the
/// list of expressions bound for an `UPDATE`'s `SET` clause (`BoundUpdate::exprs`).
#[derive(Clone, Copy, Debug)]
pub enum UpdateProject {
    /// The new value is the whole expression at this index.
    Simple(usize),
    /// The new value is field `.1` of the struct-typed expression at index `.0`; this is
    /// what a `SET (col1, col2, ...) = (SELECT ...)` assignment produces.
    Composite(usize, usize),
}

impl UpdateProject {
    /// Returns this projection with its expression index shifted by `i`.
    ///
    /// Only the expression index moves; the field index of a `Composite` is kept as is.
    pub fn offset(self, i: usize) -> Self {
        match self {
            Self::Simple(expr_idx) => Self::Simple(expr_idx + i),
            Self::Composite(expr_idx, field_idx) => Self::Composite(expr_idx + i, field_idx),
        }
    }
}
53
54#[derive(Debug, Clone)]
55pub struct BoundUpdate {
56 pub table_id: TableId,
58
59 pub table_version_id: TableVersionId,
61
62 pub table_name: String,
64
65 pub owner: UserId,
67
68 pub table: BoundBaseTable,
70
71 pub selection: Option<ExprImpl>,
72
73 pub exprs: Vec<ExprImpl>,
75
76 pub projects: HashMap<usize, UpdateProject>,
80
81 pub returning_list: Vec<ExprImpl>,
85
86 pub returning_schema: Option<Schema>,
87}
88
89impl RewriteExprsRecursive for BoundUpdate {
90 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
91 self.selection =
92 std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
93
94 let new_exprs = std::mem::take(&mut self.exprs)
95 .into_iter()
96 .map(|expr| rewriter.rewrite_expr(expr))
97 .collect::<Vec<_>>();
98 self.exprs = new_exprs;
99
100 let new_returning_list = std::mem::take(&mut self.returning_list)
101 .into_iter()
102 .map(|expr| rewriter.rewrite_expr(expr))
103 .collect::<Vec<_>>();
104 self.returning_list = new_returning_list;
105 }
106}
107
108fn get_col_referenced_by_generated_pk(table_catalog: &TableCatalog) -> Result<FixedBitSet> {
109 let column_num = table_catalog.columns().len();
110 let pk_col_id = table_catalog.pk_column_ids();
111 let mut bitset = FixedBitSet::with_capacity(column_num);
112
113 let generated_pk_col_exprs = table_catalog
114 .columns
115 .iter()
116 .filter(|c| pk_col_id.contains(&c.column_id()))
117 .flat_map(|c| c.generated_expr());
118 for expr_node in generated_pk_col_exprs {
119 let expr = ExprImpl::from_expr_proto(expr_node)?;
120 bitset.union_with(&expr.collect_input_refs(column_num));
121 }
122 Ok(bitset)
123}
124
125impl Binder {
126 pub(super) fn bind_update(
127 &mut self,
128 name: ObjectName,
129 assignments: Vec<Assignment>,
130 selection: Option<Expr>,
131 returning_items: Vec<SelectItem>,
132 ) -> Result<BoundUpdate> {
133 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
134 let table = self.bind_table(schema_name.as_deref(), &table_name)?;
135
136 let table_catalog = &table.table_catalog;
137 Self::check_for_dml(table_catalog, false)?;
138 self.check_privilege(
139 ObjectCheckItem::new(
140 table_catalog.owner,
141 AclMode::Update,
142 table_name.clone(),
143 PbObject::TableId(table_catalog.id.table_id),
144 ),
145 table_catalog.database_id,
146 )?;
147
148 let default_columns_from_catalog =
149 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
150 if !returning_items.is_empty() && table_catalog.has_generated_column() {
151 return Err(RwError::from(ErrorCode::BindError(
152 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
153 )));
154 }
155
156 let table_id = table_catalog.id;
157 let owner = table_catalog.owner;
158 let table_version_id = table_catalog.version_id().expect("table must be versioned");
159 let cols_refed_by_generated_pk = get_col_referenced_by_generated_pk(table_catalog)?;
160
161 let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?;
162
163 let mut exprs = Vec::new();
164 let mut projects = HashMap::new();
165
166 macro_rules! record {
167 ($id:expr, $project:expr) => {
168 let id_index = $id.as_input_ref().unwrap().index;
169 projects
170 .try_insert(id_index, $project)
171 .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
172 };
173 }
174
175 for Assignment { id, value } in assignments {
176 let ids: Vec<_> = id
177 .iter()
178 .map(|id| self.bind_expr(Expr::Identifier(id.clone())))
179 .try_collect()?;
180
181 match (ids.as_slice(), value) {
182 (ids, AssignmentValue::Default) => {
184 for id in ids {
185 let id_index = id.as_input_ref().unwrap().index;
186 let expr = default_columns_from_catalog
187 .get(&id_index)
188 .cloned()
189 .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
190
191 exprs.push(expr);
192 record!(id, UpdateProject::Simple(exprs.len() - 1));
193 }
194 }
195
196 ([id], AssignmentValue::Expr(expr)) => {
198 let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?;
199 exprs.push(expr);
200 record!(id, UpdateProject::Simple(exprs.len() - 1));
201 }
202 (ids, AssignmentValue::Expr(Expr::Row(values))) => {
204 if ids.len() != values.len() {
205 bail_bind_error!("number of columns does not match number of values");
206 }
207
208 for (id, value) in ids.iter().zip_eq_fast(values) {
209 let expr = self.bind_expr(value)?.cast_assign(id.return_type())?;
210 exprs.push(expr);
211 record!(id, UpdateProject::Simple(exprs.len() - 1));
212 }
213 }
214 (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
216 let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?;
217
218 if expr.return_type().as_struct().len() != ids.len() {
219 bail_bind_error!("number of columns does not match number of values");
220 }
221
222 let target_type = StructType::new(
223 id.iter()
224 .zip_eq_fast(ids)
225 .map(|(id, expr)| (id.real_value(), expr.return_type())),
226 )
227 .into();
228 let expr = expr.cast_assign(target_type)?;
229
230 exprs.push(expr);
231
232 for (i, id) in ids.iter().enumerate() {
233 record!(id, UpdateProject::Composite(exprs.len() - 1, i));
234 }
235 }
236
237 (_ids, AssignmentValue::Expr(_expr)) => {
238 bail_bind_error!(
239 "source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"
240 );
241 }
242 }
243 }
244
245 for &id_index in projects.keys() {
247 if (table.table_catalog.pk())
248 .iter()
249 .any(|k| k.column_index == id_index)
250 {
251 return Err(ErrorCode::BindError(
252 "update modifying the PK column is unsupported".to_owned(),
253 )
254 .into());
255 }
256 if (table.table_catalog.generated_col_idxes()).contains(&id_index) {
257 return Err(ErrorCode::BindError(
258 "update modifying the generated column is unsupported".to_owned(),
259 )
260 .into());
261 }
262 if cols_refed_by_generated_pk.contains(id_index) {
263 return Err(ErrorCode::BindError(
264 "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
265 )
266 .into());
267 }
268
269 let col = &table.table_catalog.columns()[id_index];
270 if !col.can_dml() {
271 bail_bind_error!("update modifying column `{}` is unsupported", col.name());
272 }
273 }
274
275 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
276 let returning = !returning_list.is_empty();
277
278 Ok(BoundUpdate {
279 table_id,
280 table_version_id,
281 table_name,
282 owner,
283 table,
284 selection,
285 projects,
286 exprs,
287 returning_list,
288 returning_schema: if returning {
289 Some(Schema { fields })
290 } else {
291 None
292 },
293 })
294 }
295}