risingwave_frontend/binder/
insert.rs1use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_pb::user::grant_privilege::PbObject;
25use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
26
27use super::BoundQuery;
28use super::statement::RewriteExprsRecursive;
29use crate::binder::{Binder, Clause};
30use crate::catalog::TableId;
31use crate::error::{ErrorCode, Result, RwError};
32use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
33use crate::user::UserId;
34use crate::utils::ordinal;
35
/// The result of binding an `INSERT` statement.
#[derive(Debug, Clone)]
pub struct BoundInsert {
    /// Id of the target table.
    pub table_id: TableId,

    /// Version id of the target table's catalog at bind time.
    pub table_version_id: TableVersionId,

    /// Table-name part of the resolved (schema, table) name of the target.
    pub table_name: String,

    /// All non-hidden columns of the target table, in catalog order.
    pub table_visible_columns: Vec<ColumnCatalog>,

    /// Owner of the target table.
    pub owner: UserId,

    /// The catalog's row-id column index, minus the number of generated columns among the
    /// first `row_id_index + 1` catalog columns; `None` if the table has no row-id column.
    pub row_id_index: Option<usize>,

    /// For each output column of `source` (in order), its index into the table's
    /// insertable columns (`columns_to_insert()`).
    pub column_indices: Vec<usize>,

    /// Insertable-column indices not supplied by `source`, each paired with the catalog
    /// default expression for that index, or a typed NULL literal when there is none.
    pub default_columns: Vec<(usize, ExprImpl)>,

    /// The bound `VALUES` list or query that produces the rows to insert.
    pub source: BoundQuery,

    /// Projection over `source`'s output that applies assignment casts and NOT NULL checks.
    /// Empty when no cast or check is needed, and always empty for a `VALUES` source
    /// (there the casts/checks are applied inside the bound values themselves).
    pub cast_exprs: Vec<ExprImpl>,

    /// Bound expressions of the `RETURNING` clause; empty if there is none.
    pub returning_list: Vec<ExprImpl>,

    /// Output schema of the `RETURNING` clause; `Some` only when `returning_list` is non-empty.
    pub returning_schema: Option<Schema>,
}
82
83impl RewriteExprsRecursive for BoundInsert {
84 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
85 self.source.rewrite_exprs_recursive(rewriter);
86
87 let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
88 .into_iter()
89 .map(|expr| rewriter.rewrite_expr(expr))
90 .collect::<Vec<_>>();
91 self.cast_exprs = new_cast_exprs;
92
93 let new_returning_list = std::mem::take(&mut self.returning_list)
94 .into_iter()
95 .map(|expr| rewriter.rewrite_expr(expr))
96 .collect::<Vec<_>>();
97 self.returning_list = new_returning_list;
98 }
99}
100
101impl Binder {
102 pub(super) fn bind_insert(
103 &mut self,
104 name: ObjectName,
105 cols_to_insert_by_user: Vec<Ident>,
106 source: Query,
107 returning_items: Vec<SelectItem>,
108 ) -> Result<BoundInsert> {
109 let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
110 self.context.clause = Some(Clause::Insert);
112 let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
113 let table_catalog = &bound_table.table_catalog;
114 Self::check_for_dml(table_catalog, true)?;
115 self.check_privilege(
116 PbObject::TableId(table_catalog.id.table_id),
117 table_catalog.database_id,
118 AclMode::Insert,
119 table_catalog.owner,
120 )?;
121
122 let default_columns_from_catalog =
123 table_catalog.default_columns().collect::<BTreeMap<_, _>>();
124 let table_id = table_catalog.id;
125 let owner = table_catalog.owner;
126 let table_version_id = table_catalog.version_id().expect("table must be versioned");
127 let table_visible_columns = table_catalog
128 .columns()
129 .iter()
130 .filter(|c| !c.is_hidden())
131 .cloned()
132 .collect_vec();
133 let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
134
135 let generated_column_names = table_catalog
136 .generated_column_names()
137 .collect::<HashSet<_>>();
138 for col in &cols_to_insert_by_user {
139 let query_col_name = col.real_value();
140 if generated_column_names.contains(query_col_name.as_str()) {
141 return Err(RwError::from(ErrorCode::BindError(format!(
142 "cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.",
143 &query_col_name
144 ))));
145 }
146 }
147 if !generated_column_names.is_empty() && !returning_items.is_empty() {
148 return Err(RwError::from(ErrorCode::BindError(
149 "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
150 )));
151 }
152
153 let row_id_index = {
157 if let Some(row_id_index) = table_catalog.row_id_index {
158 let mut cnt = 0;
159 for col in table_catalog.columns().iter().take(row_id_index + 1) {
160 if col.is_generated() {
161 cnt += 1;
162 }
163 }
164 Some(row_id_index - cnt)
165 } else {
166 None
167 }
168 };
169
170 let (returning_list, fields) = self.bind_returning_list(returning_items)?;
171 let is_returning = !returning_list.is_empty();
172
173 let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
174 &cols_to_insert_in_table,
175 &cols_to_insert_by_user,
176 &table_name,
177 )?;
178 let expected_types: Vec<DataType> = col_indices_to_insert
179 .iter()
180 .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
181 .collect();
182
183 let nullables: Vec<(bool, &str)> = col_indices_to_insert
184 .iter()
185 .map(|idx| {
186 (
187 cols_to_insert_in_table[*idx].nullable(),
188 cols_to_insert_in_table[*idx].name(),
189 )
190 })
191 .collect();
192
193 let bound_query;
220 let cast_exprs;
221 let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
222
223 let bound_column_nums = match source.as_simple_values() {
224 None => {
225 bound_query = self.bind_query(source)?;
226 let actual_types = bound_query.data_types();
227 let type_match = expected_types == actual_types;
228 cast_exprs = if all_nullable && type_match {
229 vec![]
230 } else {
231 let mut cast_exprs = actual_types
232 .into_iter()
233 .enumerate()
234 .map(|(i, t)| InputRef::new(i, t).into())
235 .collect();
236 if !type_match {
237 cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
238 }
239 if !all_nullable {
240 cast_exprs =
241 Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
242 }
243 cast_exprs
244 };
245 bound_query.schema().len()
246 }
247 Some(values) => {
248 let values_len = values
249 .0
250 .first()
251 .expect("values list should not be empty")
252 .len();
253 let mut values = self.bind_values(values.clone(), Some(expected_types))?;
254 if !all_nullable {
257 values.rows = values
258 .rows
259 .into_iter()
260 .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
261 .try_collect()?;
262 }
263
264 bound_query = BoundQuery::with_values(values);
265 cast_exprs = vec![];
266 values_len
267 }
268 };
269
270 let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
271 let num_target_cols = if has_user_specified_columns {
272 cols_to_insert_by_user.len()
273 } else {
274 cols_to_insert_in_table.len()
275 };
276
277 let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
278 std::cmp::Ordering::Equal => (None, default_column_indices),
279 std::cmp::Ordering::Greater => {
280 if has_user_specified_columns {
281 (
283 Some("INSERT has more target columns than expressions"),
284 vec![],
285 )
286 } else {
287 (None, col_indices_to_insert.split_off(bound_column_nums))
292 }
293 }
294 std::cmp::Ordering::Less => {
295 (
299 Some("INSERT has more expressions than target columns"),
300 vec![],
301 )
302 }
303 };
304 if let Some(msg) = err_msg {
305 return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
306 }
307
308 let default_columns = default_column_indices
309 .into_iter()
310 .map(|i| {
311 (
312 i,
313 default_columns_from_catalog
314 .get(&i)
315 .cloned()
316 .unwrap_or_else(|| {
317 ExprImpl::literal_null(cols_to_insert_in_table[i].data_type().clone())
318 }),
319 )
320 })
321 .collect_vec();
322
323 let insert = BoundInsert {
324 table_id,
325 table_version_id,
326 table_name,
327 table_visible_columns,
328 owner,
329 row_id_index,
330 column_indices: col_indices_to_insert,
331 default_columns,
332 source: bound_query,
333 cast_exprs,
334 returning_list,
335 returning_schema: if is_returning {
336 Some(Schema { fields })
337 } else {
338 None
339 },
340 };
341 Ok(insert)
342 }
343
344 pub(super) fn cast_on_insert(
347 expected_types: &Vec<DataType>,
348 exprs: Vec<ExprImpl>,
349 ) -> Result<Vec<ExprImpl>> {
350 let msg = match expected_types.len().cmp(&exprs.len()) {
351 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
352 _ => {
353 let expr_len = exprs.len();
354 return exprs
355 .into_iter()
356 .zip_eq_fast(expected_types.iter().take(expr_len))
357 .enumerate()
358 .map(|(i, (e, t))| {
359 let res = e.cast_assign(t.clone());
360 if expr_len > 1 {
361 res.with_context(|| {
362 format!("failed to cast the {} column", ordinal(i + 1))
363 })
364 .map_err(Into::into)
365 } else {
366 res.map_err(Into::into)
367 }
368 })
369 .try_collect();
370 }
371 };
372 Err(ErrorCode::BindError(msg.into()).into())
373 }
374
375 pub(super) fn check_not_null(
377 nullables: &Vec<(bool, &str)>,
378 exprs: Vec<ExprImpl>,
379 table_name: &str,
380 ) -> Result<Vec<ExprImpl>> {
381 let msg = match nullables.len().cmp(&exprs.len()) {
382 std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
383 _ => {
384 let expr_len = exprs.len();
385 return exprs
386 .into_iter()
387 .zip_eq_fast(nullables.iter().take(expr_len))
388 .map(|(expr, (nullable, col_name))| {
389 if !nullable {
390 let return_type = expr.return_type();
391 let check_not_null = FunctionCall::new_unchecked(
392 ExprType::CheckNotNull,
393 vec![
394 expr,
395 ExprImpl::literal_varchar((*col_name).to_owned()),
396 ExprImpl::literal_varchar(table_name.to_owned()),
397 ],
398 return_type,
399 );
400 Ok(check_not_null.into())
402 } else {
403 Ok(expr)
404 }
405 })
406 .try_collect();
407 }
408 };
409 Err(ErrorCode::BindError(msg.into()).into())
410 }
411}
412
413fn get_col_indices_to_insert(
419 cols_to_insert_in_table: &[ColumnCatalog],
420 cols_to_insert_by_user: &[Ident],
421 table_name: &str,
422) -> Result<(Vec<usize>, Vec<usize>)> {
423 if cols_to_insert_by_user.is_empty() {
424 return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
425 }
426
427 let mut col_indices_to_insert: Vec<usize> = Vec::new();
428
429 let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
430 for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
431 col_name_to_idx.insert(col.name().to_owned(), col_idx);
432 }
433
434 for col_name in cols_to_insert_by_user {
435 let col_name = &col_name.real_value();
436 match col_name_to_idx.get_mut(col_name) {
437 Some(value_ref) => {
438 if *value_ref == usize::MAX {
439 return Err(RwError::from(ErrorCode::BindError(
440 "Column specified more than once".to_owned(),
441 )));
442 }
443 col_indices_to_insert.push(*value_ref);
444 *value_ref = usize::MAX; }
447 None => {
448 return Err(RwError::from(ErrorCode::BindError(format!(
450 "Column {} not found in table {}",
451 col_name, table_name
452 ))));
453 }
454 }
455 }
456
457 let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
459 let mut cols = vec![];
460 for col in cols_to_insert_in_table {
461 if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
462 if *col_to_insert_idx != usize::MAX {
463 cols.push(*col_to_insert_idx);
464 }
465 } else {
466 unreachable!();
467 }
468 }
469 cols
470 } else {
471 vec![]
472 };
473
474 Ok((col_indices_to_insert, default_column_indices))
475}