risingwave_frontend/binder/
values.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::bail_not_implemented;
17use risingwave_common::catalog::{Field, Schema};
18use risingwave_common::types::DataType;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_sqlparser::ast::Values;
21
22use super::bind_context::Clause;
23use super::statement::RewriteExprsRecursive;
24use crate::binder::Binder;
25use crate::error::{ErrorCode, Result};
26use crate::expr::{CorrelatedId, Depth, ExprImpl, align_types};
27
28#[derive(Debug, Clone)]
29pub struct BoundValues {
30    pub rows: Vec<Vec<ExprImpl>>,
31    pub schema: Schema,
32}
33
34impl RewriteExprsRecursive for BoundValues {
35    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
36        let new_rows = std::mem::take(&mut self.rows)
37            .into_iter()
38            .map(|exprs| {
39                exprs
40                    .into_iter()
41                    .map(|expr| rewriter.rewrite_expr(expr))
42                    .collect::<Vec<_>>()
43            })
44            .collect::<Vec<_>>();
45        self.rows = new_rows;
46    }
47}
48
49impl BoundValues {
50    /// The schema returned of this [`BoundValues`].
51    pub fn schema(&self) -> &Schema {
52        &self.schema
53    }
54
55    pub fn exprs(&self) -> impl Iterator<Item = &ExprImpl> {
56        self.rows.iter().flatten()
57    }
58
59    pub fn exprs_mut(&mut self) -> impl Iterator<Item = &mut ExprImpl> {
60        self.rows.iter_mut().flatten()
61    }
62
63    pub fn is_correlated(&self, depth: Depth) -> bool {
64        self.exprs()
65            .any(|expr| expr.has_correlated_input_ref_by_depth(depth))
66    }
67
68    pub fn collect_correlated_indices_by_depth_and_assign_id(
69        &mut self,
70        depth: Depth,
71        correlated_id: CorrelatedId,
72    ) -> Vec<usize> {
73        self.exprs_mut()
74            .flat_map(|expr| {
75                expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
76            })
77            .collect()
78    }
79}
80
/// Builds the generated name of column `col_id` for the `VALUES` list identified by `values_id`,
/// in the form `*VALUES*_<values_id>.column_<col_id>`.
fn values_column_name(values_id: usize, col_id: usize) -> String {
    format!("*VALUES*_{values_id}.column_{col_id}")
}
84
85impl Binder {
86    /// Bind [`Values`] with given `expected_types`. If no types are expected, a compatible type for
87    /// all rows will be used.
88    /// If values are shorter than expected, `NULL`s will be filled.
89    pub(super) fn bind_values(
90        &mut self,
91        values: Values,
92        expected_types: Option<Vec<DataType>>,
93    ) -> Result<BoundValues> {
94        assert!(!values.0.is_empty());
95
96        self.context.clause = Some(Clause::Values);
97        let vec2d = values.0;
98        let mut bound = vec2d
99            .into_iter()
100            .map(|vec| vec.into_iter().map(|expr| self.bind_expr(expr)).collect())
101            .collect::<Result<Vec<Vec<_>>>>()?;
102        self.context.clause = None;
103
104        let num_columns = bound[0].len();
105        if bound.iter().any(|row| row.len() != num_columns) {
106            return Err(
107                ErrorCode::BindError("VALUES lists must all be the same length".into()).into(),
108            );
109        }
110
111        // Calculate column types.
112        let types = match expected_types {
113            Some(types) => {
114                bound = bound
115                    .into_iter()
116                    .map(|vec| Self::cast_on_insert(&types.clone(), vec))
117                    .try_collect()?;
118
119                types
120            }
121            None => (0..num_columns)
122                .map(|col_index| align_types(bound.iter_mut().map(|row| &mut row[col_index])))
123                .try_collect()?,
124        };
125
126        let values_id = self.next_values_id();
127        let schema = Schema::new(
128            types
129                .into_iter()
130                .take(num_columns)
131                .zip_eq_fast(0..num_columns)
132                .map(|(ty, col_id)| Field::with_name(ty, values_column_name(values_id, col_id)))
133                .collect(),
134        );
135
136        let bound_values = BoundValues {
137            rows: bound,
138            schema,
139        };
140        if bound_values
141            .rows
142            .iter()
143            .flatten()
144            .any(|expr| expr.has_subquery())
145        {
146            bail_not_implemented!("Subquery in VALUES");
147        }
148        if bound_values.is_correlated(1) {
149            bail_not_implemented!("CorrelatedInputRef in VALUES");
150        }
151        Ok(bound_values)
152    }
153}
154
#[cfg(test)]
mod tests {
    use risingwave_common::util::iter_util::zip_eq_fast;
    use risingwave_sqlparser::ast::{Expr, Value};

    use super::*;
    use crate::binder::test_utils::mock_binder;
    use crate::expr::Expr as _;

    /// An integer literal and a decimal literal in the same column must bind to `Decimal`.
    #[tokio::test]
    async fn test_bind_values() {
        let mut binder = mock_binder();

        // Two single-column rows: `1` and `1.1`.
        let rows = ["1", "1.1"]
            .into_iter()
            .map(|literal| vec![Expr::Value(Value::Number(literal.to_owned()))])
            .collect();
        let bound = binder.bind_values(Values(rows), None).unwrap();

        // Expected schema: a single `Decimal` column named after values id 0.
        let expected_schema = Schema::new(
            [DataType::Decimal]
                .into_iter()
                .enumerate()
                .map(|(col_id, ty)| Field::with_name(ty, values_column_name(0, col_id)))
                .collect(),
        );
        assert_eq!(bound.schema, expected_schema);

        // Every bound expression must report the type of its column.
        for row in bound.rows {
            for (expr, ty) in zip_eq_fast(row, expected_schema.data_types()) {
                assert_eq!(expr.return_type(), ty);
            }
        }
    }
}
191}