risingwave_frontend/binder/
insert.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::acl::AclMode;
20use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_pb::expr::expr_node::Type as ExprType;
24use risingwave_pb::user::grant_privilege::PbObject;
25use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};
26
27use super::BoundQuery;
28use super::statement::RewriteExprsRecursive;
29use crate::binder::{Binder, Clause};
30use crate::catalog::TableId;
31use crate::error::{ErrorCode, Result, RwError};
32use crate::expr::{Expr, ExprImpl, FunctionCall, InputRef};
33use crate::handler::privilege::ObjectCheckItem;
34use crate::user::UserId;
35use crate::utils::ordinal;
36
/// The bound form of an `INSERT` statement, as produced by `Binder::bind_insert`.
#[derive(Debug, Clone)]
pub struct BoundInsert {
    /// Id of the table to perform inserting.
    pub table_id: TableId,

    /// Version id of the table.
    pub table_version_id: TableVersionId,

    /// Name of the table to perform inserting.
    pub table_name: String,

    /// All visible columns of the table, used as the output schema of `Insert` plan node if
    /// `RETURNING` is specified.
    pub table_visible_columns: Vec<ColumnCatalog>,

    /// Owner of the table to perform inserting.
    pub owner: UserId,

    /// An optional column index of row ID. If the primary key is specified by the user,
    /// this will be `None`.
    ///
    /// Generated columns located at or before the row ID column are not counted in this index.
    pub row_id_index: Option<usize>,

    /// Indices of the insertable table columns that receive the source expressions, in the
    /// order of the source expressions.
    /// Is equal to [0, 2, 1] for insert statement
    /// create table t1 (v1 int, v2 int, v3 int); insert into t1 (v1, v3, v2) values (5, 6, 7);
    /// If the user does not define insert columns, `bind_insert` fills in the leading
    /// insertable columns in table order, one per source expression.
    pub column_indices: Vec<usize>,

    /// Columns that the user does not provide a value for, as `(column index, expression)`.
    /// The expression is the column's default expression from the catalog if there is one,
    /// otherwise a `NULL` literal of the column's type.
    pub default_columns: Vec<(usize, ExprImpl)>,

    /// The bound source query (or `VALUES` list) that produces the rows to insert.
    pub source: BoundQuery,

    /// Used as part of an extra `Project` when the column types of the query do not match
    /// those of the table, or when a not-null check has to be applied to a column. This is
    /// empty for a simple `VALUES`, which is cast and checked in place. See comments in code
    /// for details.
    pub cast_exprs: Vec<ExprImpl>,

    /// Used for the `RETURNING` keyword to indicate the returning items.
    /// If this list is empty and `returning_schema` is `None`, the output schema will be an
    /// INT64 as the affected row cnt.
    pub returning_list: Vec<ExprImpl>,

    /// Schema of the `RETURNING` items; `None` when there is no `RETURNING` clause.
    pub returning_schema: Option<Schema>,
}
83
84impl RewriteExprsRecursive for BoundInsert {
85    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
86        self.source.rewrite_exprs_recursive(rewriter);
87
88        let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
89            .into_iter()
90            .map(|expr| rewriter.rewrite_expr(expr))
91            .collect::<Vec<_>>();
92        self.cast_exprs = new_cast_exprs;
93
94        let new_returning_list = std::mem::take(&mut self.returning_list)
95            .into_iter()
96            .map(|expr| rewriter.rewrite_expr(expr))
97            .collect::<Vec<_>>();
98        self.returning_list = new_returning_list;
99    }
100}
101
102impl Binder {
103    pub(super) fn bind_insert(
104        &mut self,
105        name: ObjectName,
106        cols_to_insert_by_user: Vec<Ident>,
107        source: Query,
108        returning_items: Vec<SelectItem>,
109    ) -> Result<BoundInsert> {
110        let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
111        // bind insert table
112        self.context.clause = Some(Clause::Insert);
113        let bound_table = self.bind_table(schema_name.as_deref(), &table_name)?;
114        let table_catalog = &bound_table.table_catalog;
115        Self::check_for_dml(table_catalog, true)?;
116        self.check_privilege(
117            ObjectCheckItem::new(
118                table_catalog.owner,
119                AclMode::Insert,
120                table_name.clone(),
121                PbObject::TableId(table_catalog.id.table_id),
122            ),
123            table_catalog.database_id,
124        )?;
125
126        let default_columns_from_catalog =
127            table_catalog.default_columns().collect::<BTreeMap<_, _>>();
128        let table_id = table_catalog.id;
129        let owner = table_catalog.owner;
130        let table_version_id = table_catalog.version_id().expect("table must be versioned");
131        let table_visible_columns = table_catalog
132            .columns()
133            .iter()
134            .filter(|c| !c.is_hidden())
135            .cloned()
136            .collect_vec();
137        let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();
138
139        let generated_column_names = table_catalog
140            .generated_column_names()
141            .collect::<HashSet<_>>();
142        for col in &cols_to_insert_by_user {
143            let query_col_name = col.real_value();
144            if generated_column_names.contains(query_col_name.as_str()) {
145                return Err(RwError::from(ErrorCode::BindError(format!(
146                    "cannot insert a non-DEFAULT value into column \"{0}\".  Column \"{0}\" is a generated column.",
147                    &query_col_name
148                ))));
149            }
150        }
151        if !generated_column_names.is_empty() && !returning_items.is_empty() {
152            return Err(RwError::from(ErrorCode::BindError(
153                "`RETURNING` clause is not supported for tables with generated columns".to_owned(),
154            )));
155        }
156
157        // TODO(yuhao): refine this if row_id is always the last column.
158        //
159        // `row_id_index` in insert operation should rule out generated column
160        let row_id_index = {
161            if let Some(row_id_index) = table_catalog.row_id_index {
162                let mut cnt = 0;
163                for col in table_catalog.columns().iter().take(row_id_index + 1) {
164                    if col.is_generated() {
165                        cnt += 1;
166                    }
167                }
168                Some(row_id_index - cnt)
169            } else {
170                None
171            }
172        };
173
174        let (returning_list, fields) = self.bind_returning_list(returning_items)?;
175        let is_returning = !returning_list.is_empty();
176
177        let (mut col_indices_to_insert, default_column_indices) = get_col_indices_to_insert(
178            &cols_to_insert_in_table,
179            &cols_to_insert_by_user,
180            &table_name,
181        )?;
182        let expected_types: Vec<DataType> = col_indices_to_insert
183            .iter()
184            .map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
185            .collect();
186
187        let nullables: Vec<(bool, &str)> = col_indices_to_insert
188            .iter()
189            .map(|idx| {
190                (
191                    cols_to_insert_in_table[*idx].nullable(),
192                    cols_to_insert_in_table[*idx].name(),
193                )
194            })
195            .collect();
196
197        // When the column types of `source` query do not match `expected_types`,
198        // casting is needed.
199        //
200        // In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment
201        // is given and it is NOT equivalent to assignment cast over potential implicit cast inside.
202        // For example, the following is valid:
203        //
204        // ```sql
205        //   create table t (v1 time);
206        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
207        // ```
208        //
209        // But the followings are not:
210        //
211        // ```sql
212        //   values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
213        //   insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1;
214        // ```
215        //
216        // Because `timestamp` can cast to `time` in assignment context, but no casting between them
217        // is allowed implicitly.
218        //
219        // In this case, assignment cast should be used directly in `VALUES`, suppressing its
220        // internal implicit cast.
221        // In other cases, the `source` query is handled on its own and assignment cast is done
222        // afterwards.
223        let bound_query;
224        let cast_exprs;
225        let all_nullable = nullables.iter().all(|(nullable, _)| *nullable);
226
227        let bound_column_nums = match source.as_simple_values() {
228            None => {
229                bound_query = self.bind_query(source)?;
230                let actual_types = bound_query.data_types();
231                let type_match = expected_types == actual_types;
232                cast_exprs = if all_nullable && type_match {
233                    vec![]
234                } else {
235                    let mut cast_exprs = actual_types
236                        .into_iter()
237                        .enumerate()
238                        .map(|(i, t)| InputRef::new(i, t).into())
239                        .collect();
240                    if !type_match {
241                        cast_exprs = Self::cast_on_insert(&expected_types, cast_exprs)?
242                    }
243                    if !all_nullable {
244                        cast_exprs =
245                            Self::check_not_null(&nullables, cast_exprs, table_name.as_str())?
246                    }
247                    cast_exprs
248                };
249                bound_query.schema().len()
250            }
251            Some(values) => {
252                let values_len = values
253                    .0
254                    .first()
255                    .expect("values list should not be empty")
256                    .len();
257                let mut values = self.bind_values(values.clone(), Some(expected_types))?;
258                // let mut bound_values = values.clone();
259
260                if !all_nullable {
261                    values.rows = values
262                        .rows
263                        .into_iter()
264                        .map(|vec| Self::check_not_null(&nullables, vec, table_name.as_str()))
265                        .try_collect()?;
266                }
267
268                bound_query = BoundQuery::with_values(values);
269                cast_exprs = vec![];
270                values_len
271            }
272        };
273
274        let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
275        let num_target_cols = if has_user_specified_columns {
276            cols_to_insert_by_user.len()
277        } else {
278            cols_to_insert_in_table.len()
279        };
280
281        let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
282            std::cmp::Ordering::Equal => (None, default_column_indices),
283            std::cmp::Ordering::Greater => {
284                if has_user_specified_columns {
285                    // e.g. insert into t (v1, v2) values (7)
286                    (
287                        Some("INSERT has more target columns than expressions"),
288                        vec![],
289                    )
290                } else {
291                    // e.g. create table t (a int, b real)
292                    //      insert into t values (7)
293                    // this kind of usage is fine, null values will be provided
294                    // implicitly.
295                    (None, col_indices_to_insert.split_off(bound_column_nums))
296                }
297            }
298            std::cmp::Ordering::Less => {
299                // e.g. create table t (a int, b real)
300                //      insert into t (v1) values (7, 13)
301                // or   insert into t values (7, 13, 17)
302                (
303                    Some("INSERT has more expressions than target columns"),
304                    vec![],
305                )
306            }
307        };
308        if let Some(msg) = err_msg {
309            return Err(RwError::from(ErrorCode::BindError(msg.to_owned())));
310        }
311
312        let default_columns = default_column_indices
313            .into_iter()
314            .map(|i| {
315                (
316                    i,
317                    default_columns_from_catalog
318                        .get(&i)
319                        .cloned()
320                        .unwrap_or_else(|| {
321                            ExprImpl::literal_null(cols_to_insert_in_table[i].data_type().clone())
322                        }),
323                )
324            })
325            .collect_vec();
326
327        let insert = BoundInsert {
328            table_id,
329            table_version_id,
330            table_name,
331            table_visible_columns,
332            owner,
333            row_id_index,
334            column_indices: col_indices_to_insert,
335            default_columns,
336            source: bound_query,
337            cast_exprs,
338            returning_list,
339            returning_schema: if is_returning {
340                Some(Schema { fields })
341            } else {
342                None
343            },
344        };
345        Ok(insert)
346    }
347
348    /// Cast a list of `exprs` to corresponding `expected_types` IN ASSIGNMENT CONTEXT. Make sure
349    /// you understand the difference of implicit, assignment and explicit cast before reusing it.
350    pub(super) fn cast_on_insert(
351        expected_types: &Vec<DataType>,
352        exprs: Vec<ExprImpl>,
353    ) -> Result<Vec<ExprImpl>> {
354        let msg = match expected_types.len().cmp(&exprs.len()) {
355            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
356            _ => {
357                let expr_len = exprs.len();
358                return exprs
359                    .into_iter()
360                    .zip_eq_fast(expected_types.iter().take(expr_len))
361                    .enumerate()
362                    .map(|(i, (e, t))| {
363                        let res = e.cast_assign(t.clone());
364                        if expr_len > 1 {
365                            res.with_context(|| {
366                                format!("failed to cast the {} column", ordinal(i + 1))
367                            })
368                            .map_err(Into::into)
369                        } else {
370                            res.map_err(Into::into)
371                        }
372                    })
373                    .try_collect();
374            }
375        };
376        Err(ErrorCode::BindError(msg.into()).into())
377    }
378
379    /// Add not null check for the columns that are not nullable.
380    pub(super) fn check_not_null(
381        nullables: &Vec<(bool, &str)>,
382        exprs: Vec<ExprImpl>,
383        table_name: &str,
384    ) -> Result<Vec<ExprImpl>> {
385        let msg = match nullables.len().cmp(&exprs.len()) {
386            std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
387            _ => {
388                let expr_len = exprs.len();
389                return exprs
390                    .into_iter()
391                    .zip_eq_fast(nullables.iter().take(expr_len))
392                    .map(|(expr, (nullable, col_name))| {
393                        if !nullable {
394                            let return_type = expr.return_type();
395                            let check_not_null = FunctionCall::new_unchecked(
396                                ExprType::CheckNotNull,
397                                vec![
398                                    expr,
399                                    ExprImpl::literal_varchar((*col_name).to_owned()),
400                                    ExprImpl::literal_varchar(table_name.to_owned()),
401                                ],
402                                return_type,
403                            );
404                            // let res = expr.cast_assign(t.clone());
405                            Ok(check_not_null.into())
406                        } else {
407                            Ok(expr)
408                        }
409                    })
410                    .try_collect();
411            }
412        };
413        Err(ErrorCode::BindError(msg.into()).into())
414    }
415}
416
417/// Returned indices have the same length as `cols_to_insert_in_table`.
418/// The first elements have the same order as `cols_to_insert_by_user`.
419/// The rest are what's not specified by the user.
420///
421/// Also checks there are no duplicate nor unknown columns provided by the user.
422fn get_col_indices_to_insert(
423    cols_to_insert_in_table: &[ColumnCatalog],
424    cols_to_insert_by_user: &[Ident],
425    table_name: &str,
426) -> Result<(Vec<usize>, Vec<usize>)> {
427    if cols_to_insert_by_user.is_empty() {
428        return Ok(((0..cols_to_insert_in_table.len()).collect(), vec![]));
429    }
430
431    let mut col_indices_to_insert: Vec<usize> = Vec::new();
432
433    let mut col_name_to_idx: HashMap<String, usize> = HashMap::new();
434    for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() {
435        col_name_to_idx.insert(col.name().to_owned(), col_idx);
436    }
437
438    for col_name in cols_to_insert_by_user {
439        let col_name = &col_name.real_value();
440        match col_name_to_idx.get_mut(col_name) {
441            Some(value_ref) => {
442                if *value_ref == usize::MAX {
443                    return Err(RwError::from(ErrorCode::BindError(
444                        "Column specified more than once".to_owned(),
445                    )));
446                }
447                col_indices_to_insert.push(*value_ref);
448                *value_ref = usize::MAX; // mark this column name, for duplicate
449                // detection
450            }
451            None => {
452                // Invalid column name found
453                return Err(RwError::from(ErrorCode::BindError(format!(
454                    "Column {} not found in table {}",
455                    col_name, table_name
456                ))));
457            }
458        }
459    }
460
461    // columns that are in the target table but not in the provided target columns
462    let default_column_indices = if col_indices_to_insert.len() != cols_to_insert_in_table.len() {
463        let mut cols = vec![];
464        for col in cols_to_insert_in_table {
465            if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) {
466                if *col_to_insert_idx != usize::MAX {
467                    cols.push(*col_to_insert_idx);
468                }
469            } else {
470                unreachable!();
471            }
472        }
473        cols
474    } else {
475        vec![]
476    };
477
478    Ok((col_indices_to_insert, default_column_indices))
479}