// risingwave_frontend/binder/insert.rs
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};

use anyhow::Context;
use itertools::Itertools;
use risingwave_common::acl::AclMode;
use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_pb::expr::expr_node::Type as ExprType;
use risingwave_pb::plan_common::DefaultColumnDesc;
use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};

use super::BoundQuery;
use super::statement::RewriteExprsRecursive;
use crate::binder::{Binder, Clause};
use crate::catalog::TableId;
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
use crate::handler::privilege::ObjectCheckItem;
use crate::user::UserId;
use crate::utils::ordinal;
37
38#[derive(Debug, Clone)]
39pub struct BoundInsert {
40 pub table_id: TableId,
42
43 pub table_version_id: TableVersionId,
45
46 pub table_name: String,
48
49 pub table_visible_columns: Vec<ColumnCatalog>,
52
53 pub owner: UserId,
55
56 pub row_id_index: Option<usize>,
59
60 pub column_indices: Vec<usize>,
65
66 pub default_columns: Vec<(usize, ExprImpl)>,
69
70 pub source: BoundQuery,
71
72 pub cast_exprs: Vec<ExprImpl>,
76
77 pub returning_list: Vec<ExprImpl>,
81
82 pub returning_schema: Option<Schema>,
83}
84
85impl RewriteExprsRecursive for BoundInsert {
86 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
87 self.source.rewrite_exprs_recursive(rewriter);
88
89 let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
90 .into_iter()
91 .map(|expr| rewriter.rewrite_expr(expr))
92 .collect::<Vec<_>>();
93 self.cast_exprs = new_cast_exprs;
94
95 let new_returning_list = std::mem::take(&mut self.returning_list)
96 .into_iter()
97 .map(|expr| rewriter.rewrite_expr(expr))
98 .collect::<Vec<_>>();
99 self.returning_list = new_returning_list;
100 }
101}
102
103impl Binder {
104 pub(super) fn bind_insert(
105 &mut self,
106 name: ObjectName,
107 cols_to_insert_by_user: Vec<Ident>,
108 source: Query,
109 returning_items: Vec<SelectItem>,
110 ) -> Result<BoundInsert> {
111 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
112 self.context.clause = Some(Clause::Insert);
114 let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
115 let table_catalog = &bound_table.table_catalog;
116 Self::check_for_dml(table_catalog, true)?;
117 self.check_privilege(
118 ObjectCheckItem::new(
119 table_catalog.owner,
120 AclMode::Insert,
121 table_name.clone(),
122 table_catalog.id,
123 ),
124 table_catalog.database_id,
125 )?;
126
127 let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
128 let table_id = table_catalog.id;
129 let owner = table_catalog.owner;
130 let table_version_id = table_catalog.version_id().expect("table must be versioned");
131 let table_visible_columns = table_catalog
132 .columns()
133 .iter()
134 .filter(|c| !c.is_hidden())
135 .cloned()
136 .collect_vec();
137 let (cols_to_insert_in_table, row_id_index) = table_catalog.columns_to_insert();
138 let cols_to_insert_in_table = cols_to_insert_in_table
139 .map(|(column, _)| column.clone())
140 .collect_vec();
141 let default_columns_from_catalog = cols_to_insert_in_table
143 .iter()
144 .enumerate()
145 .filter_map(|(idx, col)| {
146 if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
147 expr,
148 ..
149 })) = col.column_desc.generated_or_default_column.as_ref()
150 {
151 Some((
152 idx,
153 ExprImpl::from_expr_proto(expr.as_ref().unwrap())
154 .expect("expr in default columns corrupted"),
155 ))
156 } else {
157 None
158 }
159 })
160 .collect::<HashMap<_, _>>();
161
162 let generated_column_names = table_catalog
163 .generated_column_names()
164 .collect::<HashSet<_>>();
165
166 let check_generated_insert_violation = |bound_column_nums: Option<usize>| -> Result<()> {
167 let generated_column_name = if let Some(column_num) = bound_column_nums {
168 table_catalog
169 .first_generated_column()
170 .and_then(|(index, column_name)| {
171 (column_num > index).then_some(column_name.to_owned())
172 })
173 } else {
174 cols_to_insert_by_user
175 .iter()
176 .map(|col| col.real_value())
177 .find(|col_name| generated_column_names.contains(col_name.as_str()))
178 };
179
180 if let Some(column_name) = generated_column_name {
181 return Err(ErrorCode::InsertViolation(format!(
182 "cannot insert a non-DEFAULT value into column \"{0}\"\n DETAIL: Column \"{0}\" is a generated column.",
183 column_name
184 ))
185 .into());
186 }
187
188 Ok(())
189 };
190
191 if has_user_specified_columns {
192 check_generated_insert_violation(None)?;
193 }
194
195 if !generated_column_names.is_empty() && !returning_items.is_empty() {
196 return Err(RwError::from(ErrorCode::BindError(
197 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
198 )));
199 }
200
201 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
202 let is_returning = !returning_list.is_empty();
203
204 let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
205 &cols_to_insert_in_table,
206 &cols_to_insert_by_user,
207 &table_name,
208 )?;
209 let expected_types: Vec<DataType> = col_indices_to_insert
211 .iter()
212 .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
213 .collect();
214
215 let nullables: Vec<(bool, &str)> = col_indices_to_insert
217 .iter()
218 .map(|idx| {
219 (
220 cols_to_insert_in_table[*idx].nullable(),
221 cols_to_insert_in_table[*idx].name(),
222 )
223 })
224 .collect();
225
226 let bound_query;
253 let cast_exprs;
254 let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
255
256 let bound_column_nums = match source.as_simple_values() {
257 None => {
258 bound_query = self.bind_query(&source)?;
259 if !has_user_specified_columns {
260 check_generated_insert_violation(Some(bound_query.schema().len()))?;
261 }
262 let actual_types = bound_query.data_types();
263 let type_match = expected_types == actual_types;
264 cast_exprs = if all_nullable && type_match {
265 vec![]
266 } else {
267 let mut cast_exprs = actual_types
268 .into_iter()
269 .enumerate()
270 .map(|(i, t)| InputRef::new(i, t).into())
271 .collect();
272 if !type_match {
273 cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
274 }
275 if !all_nullable {
276 cast_exprs =
277 Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
278 }
279 cast_exprs
280 };
281 bound_query.schema().len()
282 }
283 Some(values) => {
284 let values_len = values
285 .0
286 .first()
287 .expect("values list should not be empty")
288 .len();
289 if !has_user_specified_columns {
290 check_generated_insert_violation(Some(values_len))?;
291 }
292 let mut values = self.bind_values(values, Some(&expected_types))?;
293 if !all_nullable {
296 values.rows = values
297 .rows
298 .into_iter()
299 .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
300 .try_collect()?;
301 }
302
303 bound_query = BoundQuery::with_values(values);
304 cast_exprs = vec![];
305 values_len
306 }
307 };
308
309 let num_target_cols = if has_user_specified_columns {
310 cols_to_insert_by_user.len()
311 } else {
312 cols_to_insert_in_table.len()
313 };
314
315 let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
316 std::cmp::Ordering::Equal => (None, default_column_indices),
317 std::cmp::Ordering::Greater => {
318 if has_user_specified_columns {
319 (
321 Some("INSERT has more target columns than expressions"),
322 vec![],
323 )
324 } else {
325 (None, col_indices_to_insert.split_off(bound_column_nums))
330 }
331 }
332 std::cmp::Ordering::Less => {
333 (
337 Some("INSERT has more expressions than target columns"),
338 vec![],
339 )
340 }
341 };
342 if let Some(msg) = err_msg {
343 return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
344 }
345
346 let default_columns = default_column_indices
347 .into_iter()
348 .map(|i| {
349 let column = &cols_to_insert_in_table[i];
350 let expr = default_columns_from_catalog
351 .get(&i)
352 .cloned()
353 .unwrap_or_else(|| ExprImpl::literal_null(column.data_type().clone()));
354
355 let expr = if column.nullable() {
356 expr
357 } else {
358 FunctionCall::new_unchecked(
359 ExprType::CheckNotNull,
360 vec![
361 expr,
362 ExprImpl::literal_varchar(column.name().to_owned()),
363 ExprImpl::literal_varchar(table_name.clone()),
364 ],
365 column.data_type().clone(),
366 )
367 .into()
368 };
369
370 (i, expr)
371 })
372 .collect_vec();
373
374 let insert = BoundInsert {
375 table_id,
376 table_version_id,
377 table_name,
378 table_visible_columns,
379 owner,
380 row_id_index,
381 column_indices: col_indices_to_insert,
382 default_columns,
383 source: bound_query,
384 cast_exprs,
385 returning_list,
386 returning_schema: if is_returning {
387 Some(Schema { fields })
388 } else {
389 None
390 },
391 };
392 Ok(insert)
393 }
394
395 pub(super) fn cast_on_insert(
398 expected_types: &[DataType],
399 exprs: Vec<ExprImpl>,
400 ) -> Result<Vec<ExprImpl>> {
401 let msg = match expected_types.len().cmp(&exprs.len()) {
402 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
403 _ => {
404 let expr_len = exprs.len();
405 return exprs
406 .into_iter()
407 .zip_eq_fast(expected_types.iter().take(expr_len))
408 .enumerate()
409 .map(|(i, (e, t))| {
410 let res = e.cast_assign(t);
411 if expr_len > 1 {
412 res.with_context(|| {
413 format!("failed to cast the {} column", ordinal(i + 1))
414 })
415 .map_err(Into::into)
416 } else {
417 res.map_err(Into::into)
418 }
419 })
420 .try_collect();
421 }
422 };
423 Err(ErrorCode::BindError(msg.into()).into())
424 }
425
426 pub(super) fn check_not_null(
428 nullables: &Vec<(bool, &str)>,
429 exprs: Vec<ExprImpl>,
430 table_name: &str,
431 ) -> Result<Vec<ExprImpl>> {
432 let msg = match nullables.len().cmp(&exprs.len()) {
433 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
434 _ => {
435 let expr_len = exprs.len();
436 return exprs
437 .into_iter()
438 .zip_eq_fast(nullables.iter().take(expr_len))
439 .map(|(expr, (nullable, col_name))| {
440 if !nullable {
441 let return_type = expr.return_type();
442 let check_not_null = FunctionCall::new_unchecked(
443 ExprType::CheckNotNull,
444 vec![
445 expr,
446 ExprImpl::literal_varchar((*col_name).to_owned()),
447 ExprImpl::literal_varchar(table_name.to_owned()),
448 ],
449 return_type,
450 );
451 Ok(check_not_null.into())
453 } else {
454 Ok(expr)
455 }
456 })
457 .try_collect();
458 }
459 };
460 Err(ErrorCode::BindError(msg.into()).into())
461 }
462}
463
464fn get_col_indices_to_insert(
484 cols_to_insert_in_table: &[ColumnCatalog],
485 cols_to_insert_by_user: &[Ident],
486 table_name: &str,
487) -> Result<(Vec<usize>, Vec<usize>)> {
488 if cols_to_insert_by_user.is_empty() {
489 return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
490 }
491
492 let mut col_indices_to_insert: Vec<usize> = Vec::new();
493
494 let col_name_to_idx: HashMap<String, usize> = cols_to_insert_in_table
496 .iter()
497 .enumerate()
498 .map(|(idx, col)| (col.name().to_owned(), idx))
499 .collect();
500
501 let mut seen = HashSet::with_capacity(cols_to_insert_by_user.len());
502 for col_name in cols_to_insert_by_user {
503 let col_name = col_name.real_value();
504 if !seen.insert(col_name.clone()) {
505 return Err(RwError::from(ErrorCode::BindError(
506 "Column specified more than once".to_owned(),
507 )));
508 }
509
510 let idx = col_name_to_idx.get(&col_name).ok_or_else(|| {
511 RwError::from(ErrorCode::BindError(format!(
512 "Column {} not found in table {}",
513 col_name, table_name
514 )))
515 })?;
516
517 col_indices_to_insert.push(*idx);
518 }
519
520 let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
522 cols_to_insert_in_table
523 .iter()
524 .enumerate()
525 .filter_map(|(idx, col)| {
526 let column_name = col.name();
527 (!seen.contains(column_name)).then_some(idx)
528 })
529 .collect()
530 } else {
531 vec![]
532 };
533
534 Ok((col_indices_to_insert, default_column_indices))
535}