risingwave_frontend/binder/
insert.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
25
26use super::BoundQuery;
27use super::statement::RewriteExprsRecursive;
28use crate::binder::{Binder, Clause};
29use crate::catalog::TableId;
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
32use crate::handler::privilege::ObjectCheckItem;
33use crate::user::UserId;
34use crate::utils::ordinal;
35
36#[derive(Debug, Clone)]
37pub struct BoundInsert {
38    /// Id of the table to perform inserting.
39    pub table_id: TableId,
40
41    /// Version id of the table.
42    pub table_version_id: TableVersionId,
43
44    /// Name of the table to perform inserting.
45    pub table_name: String,
46
47    /// All visible columns of the table, used as the output schema of `Insert` plan node if
48    /// `RETURNING` is specified.
49    pub table_visible_columns: Vec<ColumnCatalog>,
50
51    /// Owner of the table to perform inserting.
52    pub owner: UserId,
53
54    // An optional column index of row ID. If the primary key is specified by the user,
55    // this will be `None`.
56    pub row_id_index: Option<usize>,
57
58    /// User defined columns in which to insert
59    /// Is equal to [0, 2, 1] for insert statement
60    /// create table t1 (v1 int, v2 int, v3 int); insert into t1 (v1, v3, v2) values (5, 6, 7);
61    /// Empty if user does not define insert columns
62    pub column_indices: Vec<usize>,
63
64    /// Columns that user fails to specify
65    /// Will set to default value (current null)
66    pub default_columns: Vec<(usize, ExprImpl)>,
67
68    pub source: BoundQuery,
69
70    /// Used as part of an extra `Project` when the column types of the query does not match
71    /// those of the table. This does not include a simple `VALUE`. See comments in code for
72    /// details.
73    pub cast_exprs: Vec<ExprImpl>,
74
75    // used for the 'RETURNING" keyword to indicate the returning items and schema
76    // if the list is empty and the schema is None, the output schema will be a INT64 as the
77    // affected row cnt
78    pub returning_list: Vec<ExprImpl>,
79
80    pub returning_schema: Option<Schema>,
81}
82
83impl RewriteExprsRecursive for BoundInsert {
84    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
85        self.source.rewrite_exprs_recursive(rewriter);
86
87        let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
88            .into_iter()
89            .map(|expr| rewriter.rewrite_expr(expr))
90            .collect::<Vec<_>>();
91        self.cast_exprs = new_cast_exprs;
92
93        let new_returning_list = std::mem::take(&mut self.returning_list)
94            .into_iter()
95            .map(|expr| rewriter.rewrite_expr(expr))
96            .collect::<Vec<_>>();
97        self.returning_list = new_returning_list;
98    }
99}
100
101impl Binder {
102    pub(super) fn bind_insert(
103        &mut self,
104        name: ObjectName,
105        cols_to_insert_by_user: Vec<Ident>,
106        source: Query,
107        returning_items: Vec<SelectItem>,
108    ) -> Result<BoundInsert> {
109        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, &name)?;
110        // bind insert table
111        self.context.clause = Some(Clause::Insert);
112        let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
113        let table_catalog = &bound_table.table_catalog;
114        Self::check_for_dml(table_catalog, true)?;
115        self.check_privilege(
116            ObjectCheckItem::new(
117                table_catalog.owner,
118                AclMode::Insert,
119                table_name.clone(),
120                table_catalog.id,
121            ),
122            table_catalog.database_id,
123        )?;
124
125        let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
126        let default_columns_from_catalog =
127            table_catalog.default_columns().collect::<BTreeMap<_, _>>();
128        let table_id = table_catalog.id;
129        let owner = table_catalog.owner;
130        let table_version_id = table_catalog.version_id().expect("table must be versioned");
131        let table_visible_columns = table_catalog
132            .columns()
133            .iter()
134            .filter(|c| !c.is_hidden())
135            .cloned()
136            .collect_vec();
137        let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
138
139        let generated_column_names = table_catalog
140            .generated_column_names()
141            .collect::<HashSet<_>>();
142
143        let check_generated_insert_violation = |bound_column_nums: Option<usize>| -> Result<()> {
144            let generated_column_name = if let Some(column_num) = bound_column_nums {
145                table_catalog
146                    .first_generated_column()
147                    .and_then(|(index, column_name)| {
148                        (column_num > index).then_some(column_name.to_owned())
149                    })
150            } else {
151                cols_to_insert_by_user
152                    .iter()
153                    .map(|col| col.real_value())
154                    .find(|col_name| generated_column_names.contains(col_name.as_str()))
155            };
156
157            if let Some(column_name) = generated_column_name {
158                return Err(ErrorCode::InsertViolation(format!(
159                    "cannot insert a non-DEFAULT value into column \"{0}\"\n  DETAIL: Column \"{0}\" is a generated column.",
160                    column_name
161                ))
162                .into());
163            }
164
165            Ok(())
166        };
167
168        if has_user_specified_columns {
169            check_generated_insert_violation(None)?;
170        }
171
172        if !generated_column_names.is_empty() && !returning_items.is_empty() {
173            return Err(RwError::from(ErrorCode::BindError(
174                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
175            )));
176        }
177
178        // TODO(yuhao): refine this if row_id is always the last column.
179        //
180        // `row_id_index` in insert operation should rule out generated column
181        let row_id_index = {
182            if let Some(row_id_index) = table_catalog.row_id_index {
183                let mut cnt = 0;
184                for col in table_catalog.columns().iter().take(row_id_index + 1) {
185                    if col.is_generated() {
186                        cnt += 1;
187                    }
188                }
189                Some(row_id_index - cnt)
190            } else {
191                None
192            }
193        };
194
195        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
196        let is_returning = !returning_list.is_empty();
197
198        let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
199            &cols_to_insert_in_table,
200            &cols_to_insert_by_user,
201            &table_name,
202        )?;
203        let expected_types: Vec<DataType> = col_indices_to_insert
204            .iter()
205            .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
206            .collect();
207
208        let nullables: Vec<(bool, &str)> = col_indices_to_insert
209            .iter()
210            .map(|idx| {
211                (
212                    cols_to_insert_in_table[*idx].nullable(),
213                    cols_to_insert_in_table[*idx].name(),
214                )
215            })
216            .collect();
217
218        // When the column types of `source` query do not match `expected_types`,
219        // casting is needed.
220        //
221        // In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment
222        // is given and it is NOT equivalent to assignment cast over potential implicit cast inside.
223        // For example, the following is valid:
224        //
225        // ```sql
226        //   create table t (v1 time);
227        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
228        // ```
229        //
230        // But the followings are not:
231        //
232        // ```sql
233        //   values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
234        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1;
235        // ```
236        //
237        // Because `timestamp` can cast to `time` in assignment context, but no casting between them
238        // is allowed implicitly.
239        //
240        // In this case, assignment cast should be used directly in `VALUES`, suppressing its
241        // internal implicit cast.
242        // In other cases, the `source` query is handled on its own and assignment cast is done
243        // afterwards.
244        let bound_query;
245        let cast_exprs;
246        let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
247
248        let bound_column_nums = match source.as_simple_values() {
249            None => {
250                bound_query = self.bind_query(&source)?;
251                if !has_user_specified_columns {
252                    check_generated_insert_violation(Some(bound_query.schema().len()))?;
253                }
254                let actual_types = bound_query.data_types();
255                let type_match = expected_types == actual_types;
256                cast_exprs = if all_nullable && type_match {
257                    vec![]
258                } else {
259                    let mut cast_exprs = actual_types
260                        .into_iter()
261                        .enumerate()
262                        .map(|(i, t)| InputRef::new(i, t).into())
263                        .collect();
264                    if !type_match {
265                        cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
266                    }
267                    if !all_nullable {
268                        cast_exprs =
269                            Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
270                    }
271                    cast_exprs
272                };
273                bound_query.schema().len()
274            }
275            Some(values) => {
276                let values_len = values
277                    .0
278                    .first()
279                    .expect("values list should not be empty")
280                    .len();
281                if !has_user_specified_columns {
282                    check_generated_insert_violation(Some(values_len))?;
283                }
284                let mut values = self.bind_values(values, Some(&expected_types))?;
285                // let mut bound_values = values.clone();
286
287                if !all_nullable {
288                    values.rows = values
289                        .rows
290                        .into_iter()
291                        .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
292                        .try_collect()?;
293                }
294
295                bound_query = BoundQuery::with_values(values);
296                cast_exprs = vec![];
297                values_len
298            }
299        };
300
301        let num_target_cols = if has_user_specified_columns {
302            cols_to_insert_by_user.len()
303        } else {
304            cols_to_insert_in_table.len()
305        };
306
307        let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
308            std::cmp::Ordering::Equal => (None, default_column_indices),
309            std::cmp::Ordering::Greater => {
310                if has_user_specified_columns {
311                    // e.g. insert into t (v1, v2) values (7)
312                    (
313                        Some("INSERT has more target columns than expressions"),
314                        vec![],
315                    )
316                } else {
317                    // e.g. create table t (a int, b real)
318                    //      insert into t values (7)
319                    // this kind of usage is fine, null values will be provided
320                    // implicitly.
321                    (None, col_indices_to_insert.split_off(bound_column_nums))
322                }
323            }
324            std::cmp::Ordering::Less => {
325                // e.g. create table t (a int, b real)
326                //      insert into t (v1) values (7, 13)
327                // or   insert into t values (7, 13, 17)
328                (
329                    Some("INSERT has more expressions than target columns"),
330                    vec![],
331                )
332            }
333        };
334        if let Some(msg) = err_msg {
335            return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
336        }
337
338        let default_columns = default_column_indices
339            .into_iter()
340            .map(|i| {
341                let column = &cols_to_insert_in_table[i];
342                let expr = default_columns_from_catalog
343                    .get(&i)
344                    .cloned()
345                    .unwrap_or_else(|| ExprImpl::literal_null(column.data_type().clone()));
346
347                let expr = if column.nullable() {
348                    expr
349                } else {
350                    FunctionCall::new_unchecked(
351                        ExprType::CheckNotNull,
352                        vec![
353                            expr,
354                            ExprImpl::literal_varchar(column.name().to_owned()),
355                            ExprImpl::literal_varchar(table_name.clone()),
356                        ],
357                        column.data_type().clone(),
358                    )
359                    .into()
360                };
361
362                (i, expr)
363            })
364            .collect_vec();
365
366        let insert = BoundInsert {
367            table_id,
368            table_version_id,
369            table_name,
370            table_visible_columns,
371            owner,
372            row_id_index,
373            column_indices: col_indices_to_insert,
374            default_columns,
375            source: bound_query,
376            cast_exprs,
377            returning_list,
378            returning_schema: if is_returning {
379                Some(Schema { fields })
380            } else {
381                None
382            },
383        };
384        Ok(insert)
385    }
386
387    /// Cast a list of `exprs` to corresponding `expected_types` IN ASSIGNMENT CONTEXT. Make sure
388    /// you understand the difference of implicit, assignment and explicit cast before reusing it.
389    pub(super) fn cast_on_insert(
390        expected_types: &[DataType],
391        exprs: Vec<ExprImpl>,
392    ) -> Result<Vec<ExprImpl>> {
393        let msg = match expected_types.len().cmp(&exprs.len()) {
394            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
395            _ => {
396                let expr_len = exprs.len();
397                return exprs
398                    .into_iter()
399                    .zip_eq_fast(expected_types.iter().take(expr_len))
400                    .enumerate()
401                    .map(|(i, (e, t))| {
402                        let res = e.cast_assign(t);
403                        if expr_len > 1 {
404                            res.with_context(|| {
405                                format!("failed to cast the {} column", ordinal(i + 1))
406                            })
407                            .map_err(Into::into)
408                        } else {
409                            res.map_err(Into::into)
410                        }
411                    })
412                    .try_collect();
413            }
414        };
415        Err(ErrorCode::BindError(msg.into()).into())
416    }
417
418    /// Add not null check for the columns that are not nullable.
419    pub(super) fn check_not_null(
420        nullables: &Vec<(bool, &str)>,
421        exprs: Vec<ExprImpl>,
422        table_name: &str,
423    ) -> Result<Vec<ExprImpl>> {
424        let msg = match nullables.len().cmp(&exprs.len()) {
425            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
426            _ => {
427                let expr_len = exprs.len();
428                return exprs
429                    .into_iter()
430                    .zip_eq_fast(nullables.iter().take(expr_len))
431                    .map(|(expr, (nullable, col_name))| {
432                        if !nullable {
433                            let return_type = expr.return_type();
434                            let check_not_null = FunctionCall::new_unchecked(
435                                ExprType::CheckNotNull,
436                                vec![
437                                    expr,
438                                    ExprImpl::literal_varchar((*col_name).to_owned()),
439                                    ExprImpl::literal_varchar(table_name.to_owned()),
440                                ],
441                                return_type,
442                            );
443                            // let res = expr.cast_assign(t.clone());
444                            Ok(check_not_null.into())
445                        } else {
446                            Ok(expr)
447                        }
448                    })
449                    .try_collect();
450            }
451        };
452        Err(ErrorCode::BindError(msg.into()).into())
453    }
454}
455
456/// Returned indices have the same length as `cols_to_insert_in_table`.
457/// The first elements have the same order as `cols_to_insert_by_user`.
458/// The rest are what's not specified by the user.
459///
460/// Also checks there are no duplicate nor unknown columns provided by the user.
461fn get_col_indices_to_insert(
462    cols_to_insert_in_table: &[ColumnCatalog],
463    cols_to_insert_by_user: &[Ident],
464    table_name: &str,
465) -> Result<(Vec<usize>, Vec<usize>)> {
466    if cols_to_insert_by_user.is_empty() {
467        return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
468    }
469
470    let mut col_indices_to_insert: Vec<usize> = Vec::new();
471
472    let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
473    for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
474        col_name_to_idx.insert(col.name().to_owned(), col_idx);
475    }
476
477    for col_name in cols_to_insert_by_user {
478        let col_name = &col_name.real_value();
479        match col_name_to_idx.get_mut(col_name) {
480            Some(value_ref) => {
481                if *value_ref == usize::MAX {
482                    return Err(RwError::from(ErrorCode::BindError(
483                        "Column specified more than once".to_owned(),
484                    )));
485                }
486                col_indices_to_insert.push(*value_ref);
487                *value_ref = usize::MAX; // mark this column name, for duplicate
488                // detection
489            }
490            None => {
491                // Invalid column name found
492                return Err(RwError::from(ErrorCode::BindError(format!(
493                    "Column {} not found in table {}",
494                    col_name, table_name
495                ))));
496            }
497        }
498    }
499
500    // columns that are in the target table but not in the provided target columns
501    let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
502        let mut cols = vec![];
503        for col in cols_to_insert_in_table {
504            if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
505                if *col_to_insert_idx != usize::MAX {
506                    cols.push(*col_to_insert_idx);
507                }
508            } else {
509                unreachable!();
510            }
511        }
512        cols
513    } else {
514        vec![]
515    };
516
517    Ok((col_indices_to_insert, default_column_indices))
518}