risingwave_frontend/binder/
insert.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_pb::plan_common::DefaultColumnDesc;
25use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
26use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
27
28use super::BoundQuery;
29use super::statement::RewriteExprsRecursive;
30use crate::binder::{Binder, Clause};
31use crate::catalog::TableId;
32use crate::error::{ErrorCode, Result, RwError};
33use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
34use crate::handler::privilege::ObjectCheckItem;
35use crate::user::UserId;
36use crate::utils::ordinal;
37
38#[derive(Debug, Clone)]
39pub struct BoundInsert {
40    /// Id of the table to perform inserting.
41    pub table_id: TableId,
42
43    /// Version id of the table.
44    pub table_version_id: TableVersionId,
45
46    /// Name of the table to perform inserting.
47    pub table_name: String,
48
49    /// All visible columns of the table, used as the output schema of `Insert` plan node if
50    /// `RETURNING` is specified.
51    pub table_visible_columns: Vec<ColumnCatalog>,
52
53    /// Owner of the table to perform inserting.
54    pub owner: UserId,
55
56    // An optional column index of row ID. If the primary key is specified by the user,
57    // this will be `None`.
58    pub row_id_index: Option<usize>,
59
60    /// User defined columns in which to insert
61    /// Is equal to [0, 2, 1] for insert statement
62    /// create table t1 (v1 int, v2 int, v3 int); insert into t1 (v1, v3, v2) values (5, 6, 7);
63    /// Empty if user does not define insert columns
64    pub column_indices: Vec<usize>,
65
66    /// Columns that user fails to specify
67    /// Will set to default value (current null)
68    pub default_columns: Vec<(usize, ExprImpl)>,
69
70    pub source: BoundQuery,
71
72    /// Used as part of an extra `Project` when the column types of the query does not match
73    /// those of the table. This does not include a simple `VALUE`. See comments in code for
74    /// details.
75    pub cast_exprs: Vec<ExprImpl>,
76
77    // used for the 'RETURNING" keyword to indicate the returning items and schema
78    // if the list is empty and the schema is None, the output schema will be a INT64 as the
79    // affected row cnt
80    pub returning_list: Vec<ExprImpl>,
81
82    pub returning_schema: Option<Schema>,
83}
84
85impl RewriteExprsRecursive for BoundInsert {
86    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
87        self.source.rewrite_exprs_recursive(rewriter);
88
89        let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
90            .into_iter()
91            .map(|expr| rewriter.rewrite_expr(expr))
92            .collect::<Vec<_>>();
93        self.cast_exprs = new_cast_exprs;
94
95        let new_returning_list = std::mem::take(&mut self.returning_list)
96            .into_iter()
97            .map(|expr| rewriter.rewrite_expr(expr))
98            .collect::<Vec<_>>();
99        self.returning_list = new_returning_list;
100    }
101}
102
103impl Binder {
104    pub(super) fn bind_insert(
105        &mut self,
106        name: ObjectName,
107        cols_to_insert_by_user: Vec<Ident>,
108        source: Query,
109        returning_items: Vec<SelectItem>,
110    ) -> Result<BoundInsert> {
111        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
112        // bind insert table
113        self.context.clause = Some(Clause::Insert);
114        let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
115        let table_catalog = &bound_table.table_catalog;
116        Self::check_for_dml(table_catalog, true)?;
117        self.check_privilege(
118            ObjectCheckItem::new(
119                table_catalog.owner,
120                AclMode::Insert,
121                table_name.clone(),
122                table_catalog.id,
123            ),
124            table_catalog.database_id,
125        )?;
126
127        let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
128        let table_id = table_catalog.id;
129        let owner = table_catalog.owner;
130        let table_version_id = table_catalog.version_id().expect("table must be versioned");
131        let table_visible_columns = table_catalog
132            .columns()
133            .iter()
134            .filter(|c| !c.is_hidden())
135            .cloned()
136            .collect_vec();
137        let (cols_to_insert_in_table, row_id_index) = table_catalog.columns_to_insert();
138        let cols_to_insert_in_table = cols_to_insert_in_table
139            .map(|(column, _)| column.clone())
140            .collect_vec();
141        // Reorder default columns based on `cols_to_insert_in_table`.
142        let default_columns_from_catalog = cols_to_insert_in_table
143            .iter()
144            .enumerate()
145            .filter_map(|(idx, col)| {
146                if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
147                    expr,
148                    ..
149                })) = col.column_desc.generated_or_default_column.as_ref()
150                {
151                    Some((
152                        idx,
153                        ExprImpl::from_expr_proto(expr.as_ref().unwrap())
154                            .expect("expr in default columns corrupted"),
155                    ))
156                } else {
157                    None
158                }
159            })
160            .collect::<HashMap<_, _>>();
161
162        let generated_column_names = table_catalog
163            .generated_column_names()
164            .collect::<HashSet<_>>();
165
166        let check_generated_insert_violation = |bound_column_nums: Option<usize>| -> Result<()> {
167            let generated_column_name = if let Some(column_num) = bound_column_nums {
168                table_catalog
169                    .first_generated_column()
170                    .and_then(|(index, column_name)| {
171                        (column_num > index).then_some(column_name.to_owned())
172                    })
173            } else {
174                cols_to_insert_by_user
175                    .iter()
176                    .map(|col| col.real_value())
177                    .find(|col_name| generated_column_names.contains(col_name.as_str()))
178            };
179
180            if let Some(column_name) = generated_column_name {
181                return Err(ErrorCode::InsertViolation(format!(
182                    "cannot insert a non-DEFAULT value into column \"{0}\"\n  DETAIL: Column \"{0}\" is a generated column.",
183                    column_name
184                ))
185                .into());
186            }
187
188            Ok(())
189        };
190
191        if has_user_specified_columns {
192            check_generated_insert_violation(None)?;
193        }
194
195        if !generated_column_names.is_empty() && !returning_items.is_empty() {
196            return Err(RwError::from(ErrorCode::BindError(
197                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
198            )));
199        }
200
201        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
202        let is_returning = !returning_list.is_empty();
203
204        let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
205            &cols_to_insert_in_table,
206            &cols_to_insert_by_user,
207            &table_name,
208        )?;
209        // Collect the types of columns explicitly specified by the user.
210        let expected_types: Vec<DataType> = col_indices_to_insert
211            .iter()
212            .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
213            .collect();
214
215        // Collect the nullable of columns explicitly specified by the user.
216        let nullables: Vec<(bool, &str)> = col_indices_to_insert
217            .iter()
218            .map(|idx| {
219                (
220                    cols_to_insert_in_table[*idx].nullable(),
221                    cols_to_insert_in_table[*idx].name(),
222                )
223            })
224            .collect();
225
226        // When the column types of `source` query do not match `expected_types`,
227        // casting is needed.
228        //
229        // In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment
230        // is given and it is NOT equivalent to assignment cast over potential implicit cast inside.
231        // For example, the following is valid:
232        //
233        // ```sql
234        //   create table t (v1 time);
235        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
236        // ```
237        //
238        // But the followings are not:
239        //
240        // ```sql
241        //   values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
242        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1;
243        // ```
244        //
245        // Because `timestamp` can cast to `time` in assignment context, but no casting between them
246        // is allowed implicitly.
247        //
248        // In this case, assignment cast should be used directly in `VALUES`, suppressing its
249        // internal implicit cast.
250        // In other cases, the `source` query is handled on its own and assignment cast is done
251        // afterwards.
252        let bound_query;
253        let cast_exprs;
254        let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
255
256        let bound_column_nums = match source.as_simple_values() {
257            None => {
258                bound_query = self.bind_query(&source)?;
259                if !has_user_specified_columns {
260                    check_generated_insert_violation(Some(bound_query.schema().len()))?;
261                }
262                let actual_types = bound_query.data_types();
263                let type_match = expected_types == actual_types;
264                cast_exprs = if all_nullable && type_match {
265                    vec![]
266                } else {
267                    let mut cast_exprs = actual_types
268                        .into_iter()
269                        .enumerate()
270                        .map(|(i, t)| InputRef::new(i, t).into())
271                        .collect();
272                    if !type_match {
273                        cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
274                    }
275                    if !all_nullable {
276                        cast_exprs =
277                            Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
278                    }
279                    cast_exprs
280                };
281                bound_query.schema().len()
282            }
283            Some(values) => {
284                let values_len = values
285                    .0
286                    .first()
287                    .expect("values list should not be empty")
288                    .len();
289                if !has_user_specified_columns {
290                    check_generated_insert_violation(Some(values_len))?;
291                }
292                let mut values = self.bind_values(values, Some(&expected_types))?;
293                // let mut bound_values = values.clone();
294
295                if !all_nullable {
296                    values.rows = values
297                        .rows
298                        .into_iter()
299                        .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
300                        .try_collect()?;
301                }
302
303                bound_query = BoundQuery::with_values(values);
304                cast_exprs = vec![];
305                values_len
306            }
307        };
308
309        let num_target_cols = if has_user_specified_columns {
310            cols_to_insert_by_user.len()
311        } else {
312            cols_to_insert_in_table.len()
313        };
314
315        let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
316            std::cmp::Ordering::Equal => (None, default_column_indices),
317            std::cmp::Ordering::Greater => {
318                if has_user_specified_columns {
319                    // e.g. insert into t (v1, v2) values (7)
320                    (
321                        Some("INSERT has more target columns than expressions"),
322                        vec![],
323                    )
324                } else {
325                    // e.g. create table t (a int, b real)
326                    //      insert into t values (7)
327                    // this kind of usage is fine, null values will be provided
328                    // implicitly.
329                    (None, col_indices_to_insert.split_off(bound_column_nums))
330                }
331            }
332            std::cmp::Ordering::Less => {
333                // e.g. create table t (a int, b real)
334                //      insert into t (v1) values (7, 13)
335                // or   insert into t values (7, 13, 17)
336                (
337                    Some("INSERT has more expressions than target columns"),
338                    vec![],
339                )
340            }
341        };
342        if let Some(msg) = err_msg {
343            return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
344        }
345
346        let default_columns = default_column_indices
347            .into_iter()
348            .map(|i| {
349                let column = &cols_to_insert_in_table[i];
350                let expr = default_columns_from_catalog
351                    .get(&i)
352                    .cloned()
353                    .unwrap_or_else(|| ExprImpl::literal_null(column.data_type().clone()));
354
355                let expr = if column.nullable() {
356                    expr
357                } else {
358                    FunctionCall::new_unchecked(
359                        ExprType::CheckNotNull,
360                        vec![
361                            expr,
362                            ExprImpl::literal_varchar(column.name().to_owned()),
363                            ExprImpl::literal_varchar(table_name.clone()),
364                        ],
365                        column.data_type().clone(),
366                    )
367                    .into()
368                };
369
370                (i, expr)
371            })
372            .collect_vec();
373
374        let insert = BoundInsert {
375            table_id,
376            table_version_id,
377            table_name,
378            table_visible_columns,
379            owner,
380            row_id_index,
381            column_indices: col_indices_to_insert,
382            default_columns,
383            source: bound_query,
384            cast_exprs,
385            returning_list,
386            returning_schema: if is_returning {
387                Some(Schema { fields })
388            } else {
389                None
390            },
391        };
392        Ok(insert)
393    }
394
395    /// Cast a list of `exprs` to corresponding `expected_types` IN ASSIGNMENT CONTEXT. Make sure
396    /// you understand the difference of implicit, assignment and explicit cast before reusing it.
397    pub(super) fn cast_on_insert(
398        expected_types: &[DataType],
399        exprs: Vec<ExprImpl>,
400    ) -> Result<Vec<ExprImpl>> {
401        let msg = match expected_types.len().cmp(&exprs.len()) {
402            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
403            _ => {
404                let expr_len = exprs.len();
405                return exprs
406                    .into_iter()
407                    .zip_eq_fast(expected_types.iter().take(expr_len))
408                    .enumerate()
409                    .map(|(i, (e, t))| {
410                        let res = e.cast_assign(t);
411                        if expr_len > 1 {
412                            res.with_context(|| {
413                                format!("failed to cast the {} column", ordinal(i + 1))
414                            })
415                            .map_err(Into::into)
416                        } else {
417                            res.map_err(Into::into)
418                        }
419                    })
420                    .try_collect();
421            }
422        };
423        Err(ErrorCode::BindError(msg.into()).into())
424    }
425
426    /// Add not null check for the columns that are not nullable.
427    pub(super) fn check_not_null(
428        nullables: &Vec<(bool, &str)>,
429        exprs: Vec<ExprImpl>,
430        table_name: &str,
431    ) -> Result<Vec<ExprImpl>> {
432        let msg = match nullables.len().cmp(&exprs.len()) {
433            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
434            _ => {
435                let expr_len = exprs.len();
436                return exprs
437                    .into_iter()
438                    .zip_eq_fast(nullables.iter().take(expr_len))
439                    .map(|(expr, (nullable, col_name))| {
440                        if !nullable {
441                            let return_type = expr.return_type();
442                            let check_not_null = FunctionCall::new_unchecked(
443                                ExprType::CheckNotNull,
444                                vec![
445                                    expr,
446                                    ExprImpl::literal_varchar((*col_name).to_owned()),
447                                    ExprImpl::literal_varchar(table_name.to_owned()),
448                                ],
449                                return_type,
450                            );
451                            // let res = expr.cast_assign(t.clone());
452                            Ok(check_not_null.into())
453                        } else {
454                            Ok(expr)
455                        }
456                    })
457                    .try_collect();
458            }
459        };
460        Err(ErrorCode::BindError(msg.into()).into())
461    }
462}
463
464/// # Parameters
465/// - `cols_to_insert_in_table`: the list of columns that are visible and non-generated
466///
467/// - `cols_to_insert_by_user`:
468///   The list of column identifiers explicitly specified by the user
469///   in the INSERT statement.
470///
471/// - `table_name`: The name of the target table.
472///
473/// # Return
474/// - `(col_indices_to_insert, default_column_indices)`
475///
476///   - `col_indices_to_insert`:
477///     Column indices corresponding to user-specified columns in the INSERT statement,
478///     preserving the user-defined order.
479///
480///   - `default_column_indices`:
481///     Column indices of remaining columns in the target table that are not explicitly
482///     provided by the user and should be filled with DEFAULT values.
483fn get_col_indices_to_insert(
484    cols_to_insert_in_table: &[ColumnCatalog],
485    cols_to_insert_by_user: &[Ident],
486    table_name: &str,
487) -> Result<(Vec<usize>, Vec<usize>)> {
488    if cols_to_insert_by_user.is_empty() {
489        return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
490    }
491
492    let mut col_indices_to_insert: Vec<usize> = Vec::new();
493
494    // Build a map from column name to column index in the table catalog
495    let col_name_to_idx: HashMap<String, usize> = cols_to_insert_in_table
496        .iter()
497        .enumerate()
498        .map(|(idx, col)| (col.name().to_owned(), idx))
499        .collect();
500
501    let mut seen = HashSet::with_capacity(cols_to_insert_by_user.len());
502    for col_name in cols_to_insert_by_user {
503        let col_name = col_name.real_value();
504        if !seen.insert(col_name.clone()) {
505            return Err(RwError::from(ErrorCode::BindError(
506                "Column specified more than once".to_owned(),
507            )));
508        }
509
510        let idx = col_name_to_idx.get(&col_name).ok_or_else(|| {
511            RwError::from(ErrorCode::BindError(format!(
512                "Column {} not found in table {}",
513                col_name, table_name
514            )))
515        })?;
516
517        col_indices_to_insert.push(*idx);
518    }
519
520    // columns that are in the target table but not in the provided target columns
521    let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
522        cols_to_insert_in_table
523            .iter()
524            .enumerate()
525            .filter_map(|(idx, col)| {
526                let column_name = col.name();
527                (!seen.contains(column_name)).then_some(idx)
528            })
529            .collect()
530    } else {
531        vec![]
532    };
533
534    Ok((col_indices_to_insert, default_column_indices))
535}