risingwave_frontend/binder/
insert.rs1use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
25
26use super::BoundQuery;
27use super::statement::RewriteExprsRecursive;
28use crate::binder::{Binder, Clause};
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
32use crate::handler::privilege::ObjectCheckItem;
33use crate::user::UserId;
34use crate::utils::ordinal;
35
36#[derive(Debug, Clone)]
37pub struct BoundInsert {
38 pub table_id: TableId,
40
41 pub table_version_id: TableVersionId,
43
44 pub table_name: String,
46
47 pub table_visible_columns: Vec<ColumnCatalog>,
50
51 pub owner: UserId,
53
54 pub row_id_index: Option<usize>,
57
58 pub column_indices: Vec<usize>,
63
64 pub default_columns: Vec<(usize, ExprImpl)>,
67
68 pub source: BoundQuery,
69
70 pub cast_exprs: Vec<ExprImpl>,
74
75 pub returning_list: Vec<ExprImpl>,
79
80 pub returning_schema: Option<Schema>,
81}
82
83impl RewriteExprsRecursive for BoundInsert {
84 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
85 self.source.rewrite_exprs_recursive(rewriter);
86
87 let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
88 .into_iter()
89 .map(|expr| rewriter.rewrite_expr(expr))
90 .collect::<Vec<_>>();
91 self.cast_exprs = new_cast_exprs;
92
93 let new_returning_list = std::mem::take(&mut self.returning_list)
94 .into_iter()
95 .map(|expr| rewriter.rewrite_expr(expr))
96 .collect::<Vec<_>>();
97 self.returning_list = new_returning_list;
98 }
99}
100
101impl Binder {
102 pub(super) fn bind_insert(
103 &mut self,
104 name: ObjectName,
105 cols_to_insert_by_user: Vec<Ident>,
106 source: Query,
107 returning_items: Vec<SelectItem>,
108 ) -> Result<BoundInsert> {
109 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
110 self.context.clause = Some(Clause::Insert);
112 let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
113 let table_catalog = &bound_table.table_catalog;
114 Self::check_for_dml(table_catalog, true)?;
115 self.check_privilege(
116 ObjectCheckItem::new(
117 table_catalog.owner,
118 AclMode::Insert,
119 table_name.clone(),
120 table_catalog.id,
121 ),
122 table_catalog.database_id,
123 )?;
124
125 let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
126 let default_columns_from_catalog =
127 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
128 let table_id = table_catalog.id;
129 let owner = table_catalog.owner;
130 let table_version_id = table_catalog.version_id().expect("table must be versioned");
131 let table_visible_columns = table_catalog
132 .columns()
133 .iter()
134 .filter(|c| !c.is_hidden())
135 .cloned()
136 .collect_vec();
137 let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
138
139 let generated_column_names = table_catalog
140 .generated_column_names()
141 .collect::<HashSet<_>>();
142
143 let check_generated_insert_violation = |bound_column_nums: Option<usize>| -> Result<()> {
144 let generated_column_name = if let Some(column_num) = bound_column_nums {
145 table_catalog
146 .first_generated_column()
147 .and_then(|(index, column_name)| {
148 (column_num > index).then_some(column_name.to_owned())
149 })
150 } else {
151 cols_to_insert_by_user
152 .iter()
153 .map(|col| col.real_value())
154 .find(|col_name| generated_column_names.contains(col_name.as_str()))
155 };
156
157 if let Some(column_name) = generated_column_name {
158 return Err(ErrorCode::InsertViolation(format!(
159 "cannot insert a non-DEFAULT value into column \"{0}\"\n DETAIL: Column \"{0}\" is a generated column.",
160 column_name
161 ))
162 .into());
163 }
164
165 Ok(())
166 };
167
168 if has_user_specified_columns {
169 check_generated_insert_violation(None)?;
170 }
171
172 if !generated_column_names.is_empty() && !returning_items.is_empty() {
173 return Err(RwError::from(ErrorCode::BindError(
174 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
175 )));
176 }
177
178 let row_id_index = {
182 if let Some(row_id_index) = table_catalog.row_id_index {
183 let mut cnt = 0;
184 for col in table_catalog.columns().iter().take(row_id_index + 1) {
185 if col.is_generated() {
186 cnt += 1;
187 }
188 }
189 Some(row_id_index - cnt)
190 } else {
191 None
192 }
193 };
194
195 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
196 let is_returning = !returning_list.is_empty();
197
198 let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
199 &cols_to_insert_in_table,
200 &cols_to_insert_by_user,
201 &table_name,
202 )?;
203 let expected_types: Vec<DataType> = col_indices_to_insert
204 .iter()
205 .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
206 .collect();
207
208 let nullables: Vec<(bool, &str)> = col_indices_to_insert
209 .iter()
210 .map(|idx| {
211 (
212 cols_to_insert_in_table[*idx].nullable(),
213 cols_to_insert_in_table[*idx].name(),
214 )
215 })
216 .collect();
217
218 let bound_query;
245 let cast_exprs;
246 let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
247
248 let bound_column_nums = match source.as_simple_values() {
249 None => {
250 bound_query = self.bind_query(&source)?;
251 if !has_user_specified_columns {
252 check_generated_insert_violation(Some(bound_query.schema().len()))?;
253 }
254 let actual_types = bound_query.data_types();
255 let type_match = expected_types == actual_types;
256 cast_exprs = if all_nullable && type_match {
257 vec![]
258 } else {
259 let mut cast_exprs = actual_types
260 .into_iter()
261 .enumerate()
262 .map(|(i, t)| InputRef::new(i, t).into())
263 .collect();
264 if !type_match {
265 cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
266 }
267 if !all_nullable {
268 cast_exprs =
269 Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
270 }
271 cast_exprs
272 };
273 bound_query.schema().len()
274 }
275 Some(values) => {
276 let values_len = values
277 .0
278 .first()
279 .expect("values list should not be empty")
280 .len();
281 if !has_user_specified_columns {
282 check_generated_insert_violation(Some(values_len))?;
283 }
284 let mut values = self.bind_values(values, Some(&expected_types))?;
285 if !all_nullable {
288 values.rows = values
289 .rows
290 .into_iter()
291 .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
292 .try_collect()?;
293 }
294
295 bound_query = BoundQuery::with_values(values);
296 cast_exprs = vec![];
297 values_len
298 }
299 };
300
301 let num_target_cols = if has_user_specified_columns {
302 cols_to_insert_by_user.len()
303 } else {
304 cols_to_insert_in_table.len()
305 };
306
307 let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
308 std::cmp::Ordering::Equal => (None, default_column_indices),
309 std::cmp::Ordering::Greater => {
310 if has_user_specified_columns {
311 (
313 Some("INSERT has more target columns than expressions"),
314 vec![],
315 )
316 } else {
317 (None, col_indices_to_insert.split_off(bound_column_nums))
322 }
323 }
324 std::cmp::Ordering::Less => {
325 (
329 Some("INSERT has more expressions than target columns"),
330 vec![],
331 )
332 }
333 };
334 if let Some(msg) = err_msg {
335 return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
336 }
337
338 let default_columns = default_column_indices
339 .into_iter()
340 .map(|i| {
341 let column = &cols_to_insert_in_table[i];
342 let expr = default_columns_from_catalog
343 .get(&i)
344 .cloned()
345 .unwrap_or_else(|| ExprImpl::literal_null(column.data_type().clone()));
346
347 let expr = if column.nullable() {
348 expr
349 } else {
350 FunctionCall::new_unchecked(
351 ExprType::CheckNotNull,
352 vec![
353 expr,
354 ExprImpl::literal_varchar(column.name().to_owned()),
355 ExprImpl::literal_varchar(table_name.clone()),
356 ],
357 column.data_type().clone(),
358 )
359 .into()
360 };
361
362 (i, expr)
363 })
364 .collect_vec();
365
366 let insert = BoundInsert {
367 table_id,
368 table_version_id,
369 table_name,
370 table_visible_columns,
371 owner,
372 row_id_index,
373 column_indices: col_indices_to_insert,
374 default_columns,
375 source: bound_query,
376 cast_exprs,
377 returning_list,
378 returning_schema: if is_returning {
379 Some(Schema { fields })
380 } else {
381 None
382 },
383 };
384 Ok(insert)
385 }
386
387 pub(super) fn cast_on_insert(
390 expected_types: &[DataType],
391 exprs: Vec<ExprImpl>,
392 ) -> Result<Vec<ExprImpl>> {
393 let msg = match expected_types.len().cmp(&exprs.len()) {
394 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
395 _ => {
396 let expr_len = exprs.len();
397 return exprs
398 .into_iter()
399 .zip_eq_fast(expected_types.iter().take(expr_len))
400 .enumerate()
401 .map(|(i, (e, t))| {
402 let res = e.cast_assign(t);
403 if expr_len > 1 {
404 res.with_context(|| {
405 format!("failed to cast the {} column", ordinal(i + 1))
406 })
407 .map_err(Into::into)
408 } else {
409 res.map_err(Into::into)
410 }
411 })
412 .try_collect();
413 }
414 };
415 Err(ErrorCode::BindError(msg.into()).into())
416 }
417
418 pub(super) fn check_not_null(
420 nullables: &Vec<(bool, &str)>,
421 exprs: Vec<ExprImpl>,
422 table_name: &str,
423 ) -> Result<Vec<ExprImpl>> {
424 let msg = match nullables.len().cmp(&exprs.len()) {
425 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
426 _ => {
427 let expr_len = exprs.len();
428 return exprs
429 .into_iter()
430 .zip_eq_fast(nullables.iter().take(expr_len))
431 .map(|(expr, (nullable, col_name))| {
432 if !nullable {
433 let return_type = expr.return_type();
434 let check_not_null = FunctionCall::new_unchecked(
435 ExprType::CheckNotNull,
436 vec![
437 expr,
438 ExprImpl::literal_varchar((*col_name).to_owned()),
439 ExprImpl::literal_varchar(table_name.to_owned()),
440 ],
441 return_type,
442 );
443 Ok(check_not_null.into())
445 } else {
446 Ok(expr)
447 }
448 })
449 .try_collect();
450 }
451 };
452 Err(ErrorCode::BindError(msg.into()).into())
453 }
454}
455
456fn get_col_indices_to_insert(
462 cols_to_insert_in_table: &[ColumnCatalog],
463 cols_to_insert_by_user: &[Ident],
464 table_name: &str,
465) -> Result<(Vec<usize>, Vec<usize>)> {
466 if cols_to_insert_by_user.is_empty() {
467 return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
468 }
469
470 let mut col_indices_to_insert: Vec<usize> = Vec::new();
471
472 let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
473 for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
474 col_name_to_idx.insert(col.name().to_owned(), col_idx);
475 }
476
477 for col_name in cols_to_insert_by_user {
478 let col_name = &col_name.real_value();
479 match col_name_to_idx.get_mut(col_name) {
480 Some(value_ref) => {
481 if *value_ref == usize::MAX {
482 return Err(RwError::from(ErrorCode::BindError(
483 "Column specified more than once".to_owned(),
484 )));
485 }
486 col_indices_to_insert.push(*value_ref);
487 *value_ref = usize::MAX; }
490 None => {
491 return Err(RwError::from(ErrorCode::BindError(format!(
493 "Column {} not found in table {}",
494 col_name, table_name
495 ))));
496 }
497 }
498 }
499
500 let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
502 let mut cols = vec![];
503 for col in cols_to_insert_in_table {
504 if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
505 if *col_to_insert_idx != usize::MAX {
506 cols.push(*col_to_insert_idx);
507 }
508 } else {
509 unreachable!();
510 }
511 }
512 cols
513 } else {
514 vec![]
515 };
516
517 Ok((col_indices_to_insert, default_column_indices))
518}