risingwave_frontend/binder/
values.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::bail_not_implemented;
17use risingwave_common::catalog::{Field, Schema};
18use risingwave_common::types::DataType;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_sqlparser::ast::Values;
21
22use super::bind_context::Clause;
23use super::statement::RewriteExprsRecursive;
24use crate::binder::Binder;
25use crate::error::{ErrorCode, Result};
26use crate::expr::{CorrelatedId, Depth, ExprImpl, align_types};
27
28#[derive(Debug, Clone)]
29pub struct BoundValues {
30    pub rows: Vec<Vec<ExprImpl>>,
31    pub schema: Schema,
32}
33
34impl RewriteExprsRecursive for BoundValues {
35    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
36        let new_rows = std::mem::take(&mut self.rows)
37            .into_iter()
38            .map(|exprs| {
39                exprs
40                    .into_iter()
41                    .map(|expr| rewriter.rewrite_expr(expr))
42                    .collect::<Vec<_>>()
43            })
44            .collect::<Vec<_>>();
45        self.rows = new_rows;
46    }
47}
48
49impl BoundValues {
50    /// The schema returned of this [`BoundValues`].
51    pub fn schema(&self) -> &Schema {
52        &self.schema
53    }
54
55    pub fn exprs(&self) -> impl Iterator<Item = &ExprImpl> {
56        self.rows.iter().flatten()
57    }
58
59    pub fn exprs_mut(&mut self) -> impl Iterator<Item = &mut ExprImpl> {
60        self.rows.iter_mut().flatten()
61    }
62
63    pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
64        self.exprs()
65            .any(|expr| expr.has_correlated_input_ref_by_depth(depth))
66    }
67
68    pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
69        self.exprs()
70            .any(|expr| expr.has_correlated_input_ref_by_correlated_id(correlated_id))
71    }
72
73    pub fn collect_correlated_indices_by_depth_and_assign_id(
74        &mut self,
75        depth: Depth,
76        correlated_id: CorrelatedId,
77    ) -> Vec<usize> {
78        self.exprs_mut()
79            .flat_map(|expr| {
80                expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
81            })
82            .collect()
83    }
84}
85
/// Builds the generated field name for column `col_id` of the `VALUES`
/// clause identified by `values_id`, e.g. `*VALUES*_0.column_1`.
fn values_column_name(values_id: usize, col_id: usize) -> String {
    format!("*VALUES*_{values_id}.column_{col_id}")
}
89
90impl Binder {
91    /// Bind [`Values`] with given `expected_types`. If no types are expected, a compatible type for
92    /// all rows will be used.
93    /// If values are shorter than expected, `NULL`s will be filled.
94    pub(super) fn bind_values(
95        &mut self,
96        values: Values,
97        expected_types: Option<Vec<DataType>>,
98    ) -> Result<BoundValues> {
99        assert!(!values.0.is_empty());
100
101        self.context.clause = Some(Clause::Values);
102        let vec2d = values.0;
103        let mut bound = vec2d
104            .into_iter()
105            .map(|vec| vec.into_iter().map(|expr| self.bind_expr(expr)).collect())
106            .collect::<Result<Vec<Vec<_>>>>()?;
107        self.context.clause = None;
108
109        let num_columns = bound[0].len();
110        if bound.iter().any(|row| row.len() != num_columns) {
111            return Err(
112                ErrorCode::BindError("VALUES lists must all be the same length".into()).into(),
113            );
114        }
115
116        // Calculate column types.
117        let types = match expected_types {
118            Some(types) => {
119                bound = bound
120                    .into_iter()
121                    .map(|vec| Self::cast_on_insert(&types.clone(), vec))
122                    .try_collect()?;
123
124                types
125            }
126            None => (0..num_columns)
127                .map(|col_index| align_types(bound.iter_mut().map(|row| &mut row[col_index])))
128                .try_collect()?,
129        };
130
131        let values_id = self.next_values_id();
132        let schema = Schema::new(
133            types
134                .into_iter()
135                .take(num_columns)
136                .zip_eq_fast(0..num_columns)
137                .map(|(ty, col_id)| Field::with_name(ty, values_column_name(values_id, col_id)))
138                .collect(),
139        );
140
141        let bound_values = BoundValues {
142            rows: bound,
143            schema,
144        };
145        if bound_values
146            .rows
147            .iter()
148            .flatten()
149            .any(|expr| expr.has_subquery())
150        {
151            bail_not_implemented!("Subquery in VALUES");
152        }
153        if bound_values.is_correlated_by_depth(1) {
154            bail_not_implemented!("CorrelatedInputRef in VALUES");
155        }
156        Ok(bound_values)
157    }
158}
159
#[cfg(test)]
mod tests {
    use risingwave_common::util::iter_util::zip_eq_fast;
    use risingwave_sqlparser::ast::{Expr, Value};

    use super::*;
    use crate::binder::test_utils::mock_binder;
    use crate::expr::Expr as _;

    /// Binding `VALUES (1), (1.1)` without expected types must yield a single
    /// `Decimal` column, and every bound expression must return that type.
    #[tokio::test]
    async fn test_bind_values() {
        let mut binder = mock_binder();

        // Test i32 -> decimal.
        let int_literal = Expr::Value(Value::Number("1".to_owned()));
        let decimal_literal = Expr::Value(Value::Number("1.1".to_owned()));
        let values = Values(vec![vec![int_literal], vec![decimal_literal]]);
        let res = binder.bind_values(values, None).unwrap();

        // Expected schema: one generated field per column, for VALUES id 0.
        let expected_schema = Schema::new(
            vec![DataType::Decimal]
                .into_iter()
                .enumerate()
                .map(|(col_id, ty)| Field::with_name(ty, values_column_name(0, col_id)))
                .collect(),
        );

        assert_eq!(res.schema(), &expected_schema);
        for row in res.rows {
            for (expr, expected_ty) in zip_eq_fast(row, expected_schema.data_types()) {
                assert_eq!(expr.return_type(), expected_ty);
            }
        }
    }
}