risingwave_frontend/binder/
insert.rs1use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
25
26use super::BoundQuery;
27use super::statement::RewriteExprsRecursive;
28use crate::binder::{Binder, Clause};
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
32use crate::handler::privilege::ObjectCheckItem;
33use crate::user::UserId;
34use crate::utils::ordinal;
35
36#[derive(Debug, Clone)]
37pub struct BoundInsert {
38 pub table_id: TableId,
40
41 pub table_version_id: TableVersionId,
43
44 pub table_name: String,
46
47 pub table_visible_columns: Vec<ColumnCatalog>,
50
51 pub owner: UserId,
53
54 pub row_id_index: Option<usize>,
57
58 pub column_indices: Vec<usize>,
63
64 pub default_columns: Vec<(usize, ExprImpl)>,
67
68 pub source: BoundQuery,
69
70 pub cast_exprs: Vec<ExprImpl>,
74
75 pub returning_list: Vec<ExprImpl>,
79
80 pub returning_schema: Option<Schema>,
81}
82
83impl RewriteExprsRecursive for BoundInsert {
84 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
85 self.source.rewrite_exprs_recursive(rewriter);
86
87 let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
88 .into_iter()
89 .map(|expr| rewriter.rewrite_expr(expr))
90 .collect::<Vec<_>>();
91 self.cast_exprs = new_cast_exprs;
92
93 let new_returning_list = std::mem::take(&mut self.returning_list)
94 .into_iter()
95 .map(|expr| rewriter.rewrite_expr(expr))
96 .collect::<Vec<_>>();
97 self.returning_list = new_returning_list;
98 }
99}
100
101impl Binder {
102 pub(super) fn bind_insert(
103 &mut self,
104 name: ObjectName,
105 cols_to_insert_by_user: Vec<Ident>,
106 source: Query,
107 returning_items: Vec<SelectItem>,
108 ) -> Result<BoundInsert> {
109 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
110 self.context.clause = Some(Clause::Insert);
112 let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
113 let table_catalog = &bound_table.table_catalog;
114 Self::check_for_dml(table_catalog, true)?;
115 self.check_privilege(
116 ObjectCheckItem::new(
117 table_catalog.owner,
118 AclMode::Insert,
119 table_name.clone(),
120 table_catalog.id,
121 ),
122 table_catalog.database_id,
123 )?;
124
125 let default_columns_from_catalog =
126 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
127 let table_id = table_catalog.id;
128 let owner = table_catalog.owner;
129 let table_version_id = table_catalog.version_id().expect("table must be versioned");
130 let table_visible_columns = table_catalog
131 .columns()
132 .iter()
133 .filter(|c| !c.is_hidden())
134 .cloned()
135 .collect_vec();
136 let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
137
138 let generated_column_names = table_catalog
139 .generated_column_names()
140 .collect::<HashSet<_>>();
141 for col in &cols_to_insert_by_user {
142 let query_col_name = col.real_value();
143 if generated_column_names.contains(query_col_name.as_str()) {
144 return Err(RwError::from(ErrorCode::BindError(format!(
145 "cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.",
146 &query_col_name
147 ))));
148 }
149 }
150 if !generated_column_names.is_empty() && !returning_items.is_empty() {
151 return Err(RwError::from(ErrorCode::BindError(
152 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
153 )));
154 }
155
156 let row_id_index = {
160 if let Some(row_id_index) = table_catalog.row_id_index {
161 let mut cnt = 0;
162 for col in table_catalog.columns().iter().take(row_id_index + 1) {
163 if col.is_generated() {
164 cnt += 1;
165 }
166 }
167 Some(row_id_index - cnt)
168 } else {
169 None
170 }
171 };
172
173 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
174 let is_returning = !returning_list.is_empty();
175
176 let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
177 &cols_to_insert_in_table,
178 &cols_to_insert_by_user,
179 &table_name,
180 )?;
181 let expected_types: Vec<DataType> = col_indices_to_insert
182 .iter()
183 .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
184 .collect();
185
186 let nullables: Vec<(bool, &str)> = col_indices_to_insert
187 .iter()
188 .map(|idx| {
189 (
190 cols_to_insert_in_table[*idx].nullable(),
191 cols_to_insert_in_table[*idx].name(),
192 )
193 })
194 .collect();
195
196 let bound_query;
223 let cast_exprs;
224 let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
225
226 let bound_column_nums = match source.as_simple_values() {
227 None => {
228 bound_query = self.bind_query(&source)?;
229 let actual_types = bound_query.data_types();
230 let type_match = expected_types == actual_types;
231 cast_exprs = if all_nullable && type_match {
232 vec![]
233 } else {
234 let mut cast_exprs = actual_types
235 .into_iter()
236 .enumerate()
237 .map(|(i, t)| InputRef::new(i, t).into())
238 .collect();
239 if !type_match {
240 cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
241 }
242 if !all_nullable {
243 cast_exprs =
244 Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
245 }
246 cast_exprs
247 };
248 bound_query.schema().len()
249 }
250 Some(values) => {
251 let values_len = values
252 .0
253 .first()
254 .expect("values list should not be empty")
255 .len();
256 let mut values = self.bind_values(values, Some(&expected_types))?;
257 if !all_nullable {
260 values.rows = values
261 .rows
262 .into_iter()
263 .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
264 .try_collect()?;
265 }
266
267 bound_query = BoundQuery::with_values(values);
268 cast_exprs = vec![];
269 values_len
270 }
271 };
272
273 let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
274 let num_target_cols = if has_user_specified_columns {
275 cols_to_insert_by_user.len()
276 } else {
277 cols_to_insert_in_table.len()
278 };
279
280 let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
281 std::cmp::Ordering::Equal => (None, default_column_indices),
282 std::cmp::Ordering::Greater => {
283 if has_user_specified_columns {
284 (
286 Some("INSERT has more target columns than expressions"),
287 vec![],
288 )
289 } else {
290 (None, col_indices_to_insert.split_off(bound_column_nums))
295 }
296 }
297 std::cmp::Ordering::Less => {
298 (
302 Some("INSERT has more expressions than target columns"),
303 vec![],
304 )
305 }
306 };
307 if let Some(msg) = err_msg {
308 return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
309 }
310
311 let default_columns = default_column_indices
312 .into_iter()
313 .map(|i| {
314 let column = &cols_to_insert_in_table[i];
315 let expr = default_columns_from_catalog
316 .get(&i)
317 .cloned()
318 .unwrap_or_else(|| ExprImpl::literal_null(column.data_type().clone()));
319
320 let expr = if column.nullable() {
321 expr
322 } else {
323 FunctionCall::new_unchecked(
324 ExprType::CheckNotNull,
325 vec![
326 expr,
327 ExprImpl::literal_varchar(column.name().to_owned()),
328 ExprImpl::literal_varchar(table_name.clone()),
329 ],
330 column.data_type().clone(),
331 )
332 .into()
333 };
334
335 (i, expr)
336 })
337 .collect_vec();
338
339 let insert = BoundInsert {
340 table_id,
341 table_version_id,
342 table_name,
343 table_visible_columns,
344 owner,
345 row_id_index,
346 column_indices: col_indices_to_insert,
347 default_columns,
348 source: bound_query,
349 cast_exprs,
350 returning_list,
351 returning_schema: if is_returning {
352 Some(Schema { fields })
353 } else {
354 None
355 },
356 };
357 Ok(insert)
358 }
359
360 pub(super) fn cast_on_insert(
363 expected_types: &[DataType],
364 exprs: Vec<ExprImpl>,
365 ) -> Result<Vec<ExprImpl>> {
366 let msg = match expected_types.len().cmp(&exprs.len()) {
367 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
368 _ => {
369 let expr_len = exprs.len();
370 return exprs
371 .into_iter()
372 .zip_eq_fast(expected_types.iter().take(expr_len))
373 .enumerate()
374 .map(|(i, (e, t))| {
375 let res = e.cast_assign(t);
376 if expr_len > 1 {
377 res.with_context(|| {
378 format!("failed to cast the {} column", ordinal(i + 1))
379 })
380 .map_err(Into::into)
381 } else {
382 res.map_err(Into::into)
383 }
384 })
385 .try_collect();
386 }
387 };
388 Err(ErrorCode::BindError(msg.into()).into())
389 }
390
391 pub(super) fn check_not_null(
393 nullables: &Vec<(bool, &str)>,
394 exprs: Vec<ExprImpl>,
395 table_name: &str,
396 ) -> Result<Vec<ExprImpl>> {
397 let msg = match nullables.len().cmp(&exprs.len()) {
398 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
399 _ => {
400 let expr_len = exprs.len();
401 return exprs
402 .into_iter()
403 .zip_eq_fast(nullables.iter().take(expr_len))
404 .map(|(expr, (nullable, col_name))| {
405 if !nullable {
406 let return_type = expr.return_type();
407 let check_not_null = FunctionCall::new_unchecked(
408 ExprType::CheckNotNull,
409 vec![
410 expr,
411 ExprImpl::literal_varchar((*col_name).to_owned()),
412 ExprImpl::literal_varchar(table_name.to_owned()),
413 ],
414 return_type,
415 );
416 Ok(check_not_null.into())
418 } else {
419 Ok(expr)
420 }
421 })
422 .try_collect();
423 }
424 };
425 Err(ErrorCode::BindError(msg.into()).into())
426 }
427}
428
429fn get_col_indices_to_insert(
435 cols_to_insert_in_table: &[ColumnCatalog],
436 cols_to_insert_by_user: &[Ident],
437 table_name: &str,
438) -> Result<(Vec<usize>, Vec<usize>)> {
439 if cols_to_insert_by_user.is_empty() {
440 return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
441 }
442
443 let mut col_indices_to_insert: Vec<usize> = Vec::new();
444
445 let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
446 for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
447 col_name_to_idx.insert(col.name().to_owned(), col_idx);
448 }
449
450 for col_name in cols_to_insert_by_user {
451 let col_name = &col_name.real_value();
452 match col_name_to_idx.get_mut(col_name) {
453 Some(value_ref) => {
454 if *value_ref == usize::MAX {
455 return Err(RwError::from(ErrorCode::BindError(
456 "Column specified more than once".to_owned(),
457 )));
458 }
459 col_indices_to_insert.push(*value_ref);
460 *value_ref = usize::MAX; }
463 None => {
464 return Err(RwError::from(ErrorCode::BindError(format!(
466 "Column {} not found in table {}",
467 col_name, table_name
468 ))));
469 }
470 }
471 }
472
473 let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
475 let mut cols = vec![];
476 for col in cols_to_insert_in_table {
477 if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
478 if *col_to_insert_idx != usize::MAX {
479 cols.push(*col_to_insert_idx);
480 }
481 } else {
482 unreachable!();
483 }
484 }
485 cols
486 } else {
487 vec![]
488 };
489
490 Ok((col_indices_to_insert, default_column_indices))
491}