risingwave_frontend/binder/
insert.rs1use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_pb::user::grant_privilege::PbObject;
25use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
26
27use super::BoundQuery;
28use super::statement::RewriteExprsRecursive;
29use crate::binder::{Binder, Clause};
30use crate::catalog::TableId;
31use crate::error::{ErrorCode, Result, RwError};
32use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
33use crate::handler::privilege::ObjectCheckItem;
34use crate::user::UserId;
35use crate::utils::ordinal;
36
/// The result of binding an `INSERT` statement; built by `Binder::bind_insert`.
#[derive(Debug, Clone)]
pub struct BoundInsert {
    /// Id of the target table (copied from its catalog entry).
    pub table_id: TableId,

    /// Version id of the target table's catalog at bind time (the table must be versioned).
    pub table_version_id: TableVersionId,

    /// Name of the target table, as resolved separately from its schema name.
    pub table_name: String,

    /// All columns of the target table that are not hidden, in catalog order.
    pub table_visible_columns: Vec<ColumnCatalog>,

    /// Owner of the target table (copied from its catalog entry).
    pub owner: UserId,

    /// Catalog `row_id_index`, shifted left by the number of generated columns at or before
    /// it; `None` when the catalog has no `row_id_index`.
    pub row_id_index: Option<usize>,

    /// For each column produced by `source` (in order), the index of the target column within
    /// the table's `columns_to_insert()` list.
    pub column_indices: Vec<usize>,

    /// Target columns not supplied by `source`, as `(index into columns_to_insert(), expr)`.
    /// The expression is the catalog default for that index, or a typed `NULL` literal.
    pub default_columns: Vec<(usize, ExprImpl)>,

    /// The bound `VALUES` list or query producing the rows to insert.
    pub source: BoundQuery,

    /// Per-column expressions over `source`'s output (input refs wrapped in assignment casts
    /// and/or `CheckNotNull` calls). Empty when no projection is needed, which is always the
    /// case for a simple `VALUES` source.
    pub cast_exprs: Vec<ExprImpl>,

    /// Bound `RETURNING` expressions; empty when there is no `RETURNING` clause.
    pub returning_list: Vec<ExprImpl>,

    /// Output schema of the `RETURNING` clause; `None` when `returning_list` is empty.
    pub returning_schema: Option<Schema>,
}
83
84impl RewriteExprsRecursive for BoundInsert {
85 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
86 self.source.rewrite_exprs_recursive(rewriter);
87
88 let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
89 .into_iter()
90 .map(|expr| rewriter.rewrite_expr(expr))
91 .collect::<Vec<_>>();
92 self.cast_exprs = new_cast_exprs;
93
94 let new_returning_list = std::mem::take(&mut self.returning_list)
95 .into_iter()
96 .map(|expr| rewriter.rewrite_expr(expr))
97 .collect::<Vec<_>>();
98 self.returning_list = new_returning_list;
99 }
100}
101
102impl Binder {
103 pub(super) fn bind_insert(
104 &mut self,
105 name: ObjectName,
106 cols_to_insert_by_user: Vec<Ident>,
107 source: Query,
108 returning_items: Vec<SelectItem>,
109 ) -> Result<BoundInsert> {
110 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
111 self.context.clause = Some(Clause::Insert);
113 let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
114 let table_catalog = &bound_table.table_catalog;
115 Self::check_for_dml(table_catalog, true)?;
116 self.check_privilege(
117 ObjectCheckItem::new(
118 table_catalog.owner,
119 AclMode::Insert,
120 table_name.clone(),
121 PbObject::TableId(table_catalog.id.table_id),
122 ),
123 table_catalog.database_id,
124 )?;
125
126 let default_columns_from_catalog =
127 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
128 let table_id = table_catalog.id;
129 let owner = table_catalog.owner;
130 let table_version_id = table_catalog.version_id().expect("table must be versioned");
131 let table_visible_columns = table_catalog
132 .columns()
133 .iter()
134 .filter(|c| !c.is_hidden())
135 .cloned()
136 .collect_vec();
137 let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
138
139 let generated_column_names = table_catalog
140 .generated_column_names()
141 .collect::<HashSet<_>>();
142 for col in &cols_to_insert_by_user {
143 let query_col_name = col.real_value();
144 if generated_column_names.contains(query_col_name.as_str()) {
145 return Err(RwError::from(ErrorCode::BindError(format!(
146 "cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.",
147 &query_col_name
148 ))));
149 }
150 }
151 if !generated_column_names.is_empty() && !returning_items.is_empty() {
152 return Err(RwError::from(ErrorCode::BindError(
153 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
154 )));
155 }
156
157 let row_id_index = {
161 if let Some(row_id_index) = table_catalog.row_id_index {
162 let mut cnt = 0;
163 for col in table_catalog.columns().iter().take(row_id_index + 1) {
164 if col.is_generated() {
165 cnt += 1;
166 }
167 }
168 Some(row_id_index - cnt)
169 } else {
170 None
171 }
172 };
173
174 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
175 let is_returning = !returning_list.is_empty();
176
177 let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
178 &cols_to_insert_in_table,
179 &cols_to_insert_by_user,
180 &table_name,
181 )?;
182 let expected_types: Vec<DataType> = col_indices_to_insert
183 .iter()
184 .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
185 .collect();
186
187 let nullables: Vec<(bool, &str)> = col_indices_to_insert
188 .iter()
189 .map(|idx| {
190 (
191 cols_to_insert_in_table[*idx].nullable(),
192 cols_to_insert_in_table[*idx].name(),
193 )
194 })
195 .collect();
196
197 let bound_query;
224 let cast_exprs;
225 let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
226
227 let bound_column_nums = match source.as_simple_values() {
228 None => {
229 bound_query = self.bind_query(source)?;
230 let actual_types = bound_query.data_types();
231 let type_match = expected_types == actual_types;
232 cast_exprs = if all_nullable && type_match {
233 vec![]
234 } else {
235 let mut cast_exprs = actual_types
236 .into_iter()
237 .enumerate()
238 .map(|(i, t)| InputRef::new(i, t).into())
239 .collect();
240 if !type_match {
241 cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
242 }
243 if !all_nullable {
244 cast_exprs =
245 Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
246 }
247 cast_exprs
248 };
249 bound_query.schema().len()
250 }
251 Some(values) => {
252 let values_len = values
253 .0
254 .first()
255 .expect("values list should not be empty")
256 .len();
257 let mut values = self.bind_values(values.clone(), Some(expected_types))?;
258 if !all_nullable {
261 values.rows = values
262 .rows
263 .into_iter()
264 .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
265 .try_collect()?;
266 }
267
268 bound_query = BoundQuery::with_values(values);
269 cast_exprs = vec![];
270 values_len
271 }
272 };
273
274 let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
275 let num_target_cols = if has_user_specified_columns {
276 cols_to_insert_by_user.len()
277 } else {
278 cols_to_insert_in_table.len()
279 };
280
281 let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
282 std::cmp::Ordering::Equal => (None, default_column_indices),
283 std::cmp::Ordering::Greater => {
284 if has_user_specified_columns {
285 (
287 Some("INSERT has more target columns than expressions"),
288 vec![],
289 )
290 } else {
291 (None, col_indices_to_insert.split_off(bound_column_nums))
296 }
297 }
298 std::cmp::Ordering::Less => {
299 (
303 Some("INSERT has more expressions than target columns"),
304 vec![],
305 )
306 }
307 };
308 if let Some(msg) = err_msg {
309 return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
310 }
311
312 let default_columns = default_column_indices
313 .into_iter()
314 .map(|i| {
315 (
316 i,
317 default_columns_from_catalog
318 .get(&i)
319 .cloned()
320 .unwrap_or_else(|| {
321 ExprImpl::literal_null(cols_to_insert_in_table[i].data_type().clone())
322 }),
323 )
324 })
325 .collect_vec();
326
327 let insert = BoundInsert {
328 table_id,
329 table_version_id,
330 table_name,
331 table_visible_columns,
332 owner,
333 row_id_index,
334 column_indices: col_indices_to_insert,
335 default_columns,
336 source: bound_query,
337 cast_exprs,
338 returning_list,
339 returning_schema: if is_returning {
340 Some(Schema { fields })
341 } else {
342 None
343 },
344 };
345 Ok(insert)
346 }
347
348 pub(super) fn cast_on_insert(
351 expected_types: &Vec<DataType>,
352 exprs: Vec<ExprImpl>,
353 ) -> Result<Vec<ExprImpl>> {
354 let msg = match expected_types.len().cmp(&exprs.len()) {
355 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
356 _ => {
357 let expr_len = exprs.len();
358 return exprs
359 .into_iter()
360 .zip_eq_fast(expected_types.iter().take(expr_len))
361 .enumerate()
362 .map(|(i, (e, t))| {
363 let res = e.cast_assign(t.clone());
364 if expr_len > 1 {
365 res.with_context(|| {
366 format!("failed to cast the {} column", ordinal(i + 1))
367 })
368 .map_err(Into::into)
369 } else {
370 res.map_err(Into::into)
371 }
372 })
373 .try_collect();
374 }
375 };
376 Err(ErrorCode::BindError(msg.into()).into())
377 }
378
379 pub(super) fn check_not_null(
381 nullables: &Vec<(bool, &str)>,
382 exprs: Vec<ExprImpl>,
383 table_name: &str,
384 ) -> Result<Vec<ExprImpl>> {
385 let msg = match nullables.len().cmp(&exprs.len()) {
386 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
387 _ => {
388 let expr_len = exprs.len();
389 return exprs
390 .into_iter()
391 .zip_eq_fast(nullables.iter().take(expr_len))
392 .map(|(expr, (nullable, col_name))| {
393 if !nullable {
394 let return_type = expr.return_type();
395 let check_not_null = FunctionCall::new_unchecked(
396 ExprType::CheckNotNull,
397 vec![
398 expr,
399 ExprImpl::literal_varchar((*col_name).to_owned()),
400 ExprImpl::literal_varchar(table_name.to_owned()),
401 ],
402 return_type,
403 );
404 Ok(check_not_null.into())
406 } else {
407 Ok(expr)
408 }
409 })
410 .try_collect();
411 }
412 };
413 Err(ErrorCode::BindError(msg.into()).into())
414 }
415}
416
417fn get_col_indices_to_insert(
423 cols_to_insert_in_table: &[ColumnCatalog],
424 cols_to_insert_by_user: &[Ident],
425 table_name: &str,
426) -> Result<(Vec<usize>, Vec<usize>)> {
427 if cols_to_insert_by_user.is_empty() {
428 return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
429 }
430
431 let mut col_indices_to_insert: Vec<usize> = Vec::new();
432
433 let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
434 for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
435 col_name_to_idx.insert(col.name().to_owned(), col_idx);
436 }
437
438 for col_name in cols_to_insert_by_user {
439 let col_name = &col_name.real_value();
440 match col_name_to_idx.get_mut(col_name) {
441 Some(value_ref) => {
442 if *value_ref == usize::MAX {
443 return Err(RwError::from(ErrorCode::BindError(
444 "Column specified more than once".to_owned(),
445 )));
446 }
447 col_indices_to_insert.push(*value_ref);
448 *value_ref = usize::MAX; }
451 None => {
452 return Err(RwError::from(ErrorCode::BindError(format!(
454 "Column {} not found in table {}",
455 col_name, table_name
456 ))));
457 }
458 }
459 }
460
461 let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
463 let mut cols = vec![];
464 for col in cols_to_insert_in_table {
465 if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
466 if *col_to_insert_idx != usize::MAX {
467 cols.push(*col_to_insert_idx);
468 }
469 } else {
470 unreachable!();
471 }
472 }
473 cols
474 } else {
475 vec![]
476 };
477
478 Ok((col_indices_to_insert, default_column_indices))
479}