risingwave_frontend/binder/
values.rs1use itertools::Itertools;
16use risingwave_common::bail_not_implemented;
17use risingwave_common::catalog::{Field, Schema};
18use risingwave_common::types::DataType;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_sqlparser::ast::Values;
21
22use super::bind_context::Clause;
23use super::statement::RewriteExprsRecursive;
24use crate::binder::Binder;
25use crate::error::{ErrorCode, Result};
26use crate::expr::{CorrelatedId, Depth, ExprImpl, align_types};
27
28#[derive(Debug, Clone)]
29pub struct BoundValues {
30 pub rows: Vec<Vec<ExprImpl>>,
31 pub schema: Schema,
32}
33
34impl RewriteExprsRecursive for BoundValues {
35 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
36 let new_rows = std::mem::take(&mut self.rows)
37 .into_iter()
38 .map(|exprs| {
39 exprs
40 .into_iter()
41 .map(|expr| rewriter.rewrite_expr(expr))
42 .collect::<Vec<_>>()
43 })
44 .collect::<Vec<_>>();
45 self.rows = new_rows;
46 }
47}
48
49impl BoundValues {
50 pub fn schema(&self) -> &Schema {
52 &self.schema
53 }
54
55 pub fn exprs(&self) -> impl Iterator<Item = &ExprImpl> {
56 self.rows.iter().flatten()
57 }
58
59 pub fn exprs_mut(&mut self) -> impl Iterator<Item = &mut ExprImpl> {
60 self.rows.iter_mut().flatten()
61 }
62
63 pub fn is_correlated(&self, depth: Depth) -> bool {
64 self.exprs()
65 .any(|expr| expr.has_correlated_input_ref_by_depth(depth))
66 }
67
68 pub fn collect_correlated_indices_by_depth_and_assign_id(
69 &mut self,
70 depth: Depth,
71 correlated_id: CorrelatedId,
72 ) -> Vec<usize> {
73 self.exprs_mut()
74 .flat_map(|expr| {
75 expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
76 })
77 .collect()
78 }
79}
80
/// Builds the name of column `col_id` of the `values_id`-th bound `VALUES` list,
/// e.g. `*VALUES*_0.column_1`.
fn values_column_name(values_id: usize, col_id: usize) -> String {
    format!("*VALUES*_{values_id}.column_{col_id}")
}
84
85impl Binder {
86 pub(super) fn bind_values(
90 &mut self,
91 values: Values,
92 expected_types: Option<Vec<DataType>>,
93 ) -> Result<BoundValues> {
94 assert!(!values.0.is_empty());
95
96 self.context.clause = Some(Clause::Values);
97 let vec2d = values.0;
98 let mut bound = vec2d
99 .into_iter()
100 .map(|vec| vec.into_iter().map(|expr| self.bind_expr(expr)).collect())
101 .collect::<Result<Vec<Vec<_>>>>()?;
102 self.context.clause = None;
103
104 let num_columns = bound[0].len();
105 if bound.iter().any(|row| row.len() != num_columns) {
106 return Err(
107 ErrorCode::BindError("VALUES lists must all be the same length".into()).into(),
108 );
109 }
110
111 let types = match expected_types {
113 Some(types) => {
114 bound = bound
115 .into_iter()
116 .map(|vec| Self::cast_on_insert(&types.clone(), vec))
117 .try_collect()?;
118
119 types
120 }
121 None => (0..num_columns)
122 .map(|col_index| align_types(bound.iter_mut().map(|row| &mut row[col_index])))
123 .try_collect()?,
124 };
125
126 let values_id = self.next_values_id();
127 let schema = Schema::new(
128 types
129 .into_iter()
130 .take(num_columns)
131 .zip_eq_fast(0..num_columns)
132 .map(|(ty, col_id)| Field::with_name(ty, values_column_name(values_id, col_id)))
133 .collect(),
134 );
135
136 let bound_values = BoundValues {
137 rows: bound,
138 schema,
139 };
140 if bound_values
141 .rows
142 .iter()
143 .flatten()
144 .any(|expr| expr.has_subquery())
145 {
146 bail_not_implemented!("Subquery in VALUES");
147 }
148 if bound_values.is_correlated(1) {
149 bail_not_implemented!("CorrelatedInputRef in VALUES");
150 }
151 Ok(bound_values)
152 }
153}
154
#[cfg(test)]
mod tests {
    use risingwave_common::util::iter_util::zip_eq_fast;
    use risingwave_sqlparser::ast::{Expr, Value};

    use super::*;
    use crate::binder::test_utils::mock_binder;
    use crate::expr::Expr as _;

    /// Binds `VALUES (1), (1.1)` with no expected types. It checks that the single output
    /// column is `DECIMAL`, is named after values id 0, and that every bound expression
    /// has that type.
    #[tokio::test]
    async fn test_bind_values() {
        let mut binder = mock_binder();

        let rows = vec![
            vec![Expr::Value(Value::Number("1".to_owned()))],
            vec![Expr::Value(Value::Number("1.1".to_owned()))],
        ];
        let bound = binder.bind_values(Values(rows), None).unwrap();

        let expected_types = vec![DataType::Decimal];
        let column_count = expected_types.len();
        let expected_schema = Schema::new(
            expected_types
                .into_iter()
                .zip_eq_fast(0..column_count)
                .map(|(ty, col_id)| Field::with_name(ty, values_column_name(0, col_id)))
                .collect(),
        );

        assert_eq!(bound.schema, expected_schema);
        for row in bound.rows {
            for (expr, ty) in zip_eq_fast(row, expected_schema.data_types()) {
                assert_eq!(expr.return_type(), ty);
            }
        }
    }
}
191}