risingwave_frontend/binder/
set_expr.rs1use std::borrow::Cow;
16use std::collections::HashMap;
17
18use risingwave_common::bail_not_implemented;
19use risingwave_common::catalog::Schema;
20use risingwave_common::util::column_index_mapping::ColIndexMapping;
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_sqlparser::ast::{Corresponding, SetExpr, SetOperator};
23
24use super::UNNAMED_COLUMN;
25use super::statement::RewriteExprsRecursive;
26use crate::binder::{BindContext, Binder, BoundQuery, BoundSelect, BoundValues};
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CorrelatedId, Depth, align_types};
29
/// A bound (name-resolved) set expression.
///
/// It is one of: a bound `SELECT`, a nested bound query, a bound `VALUES` list, or a set
/// operation that combines two bound set expressions.
#[derive(Debug, Clone)]
pub enum BoundSetExpr {
    Select(Box<BoundSelect>),
    Query(Box<BoundQuery>),
    Values(Box<BoundValues>),
    /// `left <op> [ALL] [CORRESPONDING [BY (...)]] right`.
    SetOperation {
        /// Which set operation to apply (`UNION` / `EXCEPT` / `INTERSECT`).
        op: BoundSetOperation,
        /// Whether `ALL` was specified. The binder currently only accepts it for `UNION`.
        all: bool,
        /// `Some((left_mapping, right_mapping))` only when `CORRESPONDING` is used.
        ///
        /// Each mapping selects the matched columns of one side and places them in output
        /// order. It is `None` for ordinary positional set operations.
        corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>,
        left: Box<BoundSetExpr>,
        right: Box<BoundSetExpr>,
    },
}
47
48impl RewriteExprsRecursive for BoundSetExpr {
49 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
50 match self {
51 BoundSetExpr::Select(inner) => inner.rewrite_exprs_recursive(rewriter),
52 BoundSetExpr::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
53 BoundSetExpr::Values(inner) => inner.rewrite_exprs_recursive(rewriter),
54 BoundSetExpr::SetOperation { left, right, .. } => {
55 left.rewrite_exprs_recursive(rewriter);
56 right.rewrite_exprs_recursive(rewriter);
57 }
58 }
59 }
60}
61
/// The kind of a bound set operation, converted one-to-one from the parser's [`SetOperator`].
#[derive(Debug, Clone)]
pub enum BoundSetOperation {
    Union,
    Except,
    Intersect,
}
68
69impl From<SetOperator> for BoundSetOperation {
70 fn from(value: SetOperator) -> Self {
71 match value {
72 SetOperator::Union => BoundSetOperation::Union,
73 SetOperator::Intersect => BoundSetOperation::Intersect,
74 SetOperator::Except => BoundSetOperation::Except,
75 }
76 }
77}
78
79impl BoundSetExpr {
80 pub fn schema(&self) -> Cow<'_, Schema> {
82 match self {
83 BoundSetExpr::Select(s) => Cow::Borrowed(s.schema()),
84 BoundSetExpr::Values(v) => Cow::Borrowed(v.schema()),
85 BoundSetExpr::Query(q) => q.schema(),
86 BoundSetExpr::SetOperation {
87 left,
88 corresponding_col_indices,
89 ..
90 } => {
91 if let Some((mapping_l, _)) = corresponding_col_indices {
92 let mut schema = vec![None; mapping_l.target_size()];
93 for (src, tar) in mapping_l.mapping_pairs() {
94 assert_eq!(schema[tar], None);
95 schema[tar] = Some(left.schema().fields[src].clone());
96 }
97 Cow::Owned(Schema::new(
98 schema.into_iter().map(|x| x.unwrap()).collect(),
99 ))
100 } else {
101 left.schema()
102 }
103 }
104 }
105 }
106
107 pub fn is_correlated(&self, depth: Depth) -> bool {
108 match self {
109 BoundSetExpr::Select(s) => s.is_correlated(depth),
110 BoundSetExpr::Values(v) => v.is_correlated(depth),
111 BoundSetExpr::Query(q) => q.is_correlated(depth),
112 BoundSetExpr::SetOperation { left, right, .. } => {
113 left.is_correlated(depth) || right.is_correlated(depth)
114 }
115 }
116 }
117
118 pub fn collect_correlated_indices_by_depth_and_assign_id(
119 &mut self,
120 depth: Depth,
121 correlated_id: CorrelatedId,
122 ) -> Vec<usize> {
123 match self {
124 BoundSetExpr::Select(s) => {
125 s.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
126 }
127 BoundSetExpr::Values(v) => {
128 v.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
129 }
130 BoundSetExpr::Query(q) => {
131 q.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
132 }
133 BoundSetExpr::SetOperation { left, right, .. } => {
134 let mut correlated_indices = vec![];
135 correlated_indices.extend(
136 left.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
137 );
138 correlated_indices.extend(
139 right.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
140 );
141 correlated_indices
142 }
143 }
144 }
145}
146
147impl Binder {
148 pub(crate) fn align_schema(
151 mut left: &mut BoundSetExpr,
152 mut right: &mut BoundSetExpr,
153 op: SetOperator,
154 ) -> Result<()> {
155 if left.schema().fields.len() != right.schema().fields.len() {
156 return Err(ErrorCode::InvalidInputSyntax(format!(
157 "each {} query must have the same number of columns",
158 op
159 ))
160 .into());
161 }
162
163 if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) =
166 (&mut left, &mut right)
167 {
168 for (i, (l, r)) in l_select
169 .select_items
170 .iter_mut()
171 .zip_eq_fast(r_select.select_items.iter_mut())
172 .enumerate()
173 {
174 let Ok(column_type) = align_types(vec![l, r].into_iter()) else {
175 return Err(ErrorCode::InvalidInputSyntax(format!(
176 "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.",
177 op,
178 l_select.schema.fields[i].data_type,
179 r_select.schema.fields[i].data_type,
180 l_select.schema.fields[i].name,
181 r_select.schema.fields[i].name,
182 ))
183 .into());
184 };
185 l_select.schema.fields[i].data_type = column_type.clone();
186 r_select.schema.fields[i].data_type = column_type;
187 }
188 }
189
190 Self::validate(left, right, op)
191 }
192
193 pub(crate) fn validate(
195 left: &BoundSetExpr,
196 right: &BoundSetExpr,
197 op: SetOperator,
198 ) -> Result<()> {
199 for (a, b) in left
200 .schema()
201 .fields
202 .iter()
203 .zip_eq_fast(right.schema().fields.iter())
204 {
205 if a.data_type != b.data_type {
206 return Err(ErrorCode::InvalidInputSyntax(format!(
207 "{} types {} and {} cannot be matched. Columns' name are {} and {}.",
208 op,
209 a.data_type.prost_type_name().as_str_name(),
210 b.data_type.prost_type_name().as_str_name(),
211 a.name,
212 b.name,
213 ))
214 .into());
215 }
216 }
217 Ok(())
218 }
219
220 fn corresponding(
223 &self,
224 left: &BoundSetExpr,
225 right: &BoundSetExpr,
226 corresponding: Corresponding,
227 op: &SetOperator,
228 ) -> Result<(ColIndexMapping, ColIndexMapping)> {
229 let check_duplicate_name = |set_expr: &BoundSetExpr| {
230 let mut name2idx = HashMap::new();
231 for (idx, field) in set_expr.schema().fields.iter().enumerate() {
232 if name2idx.insert(field.name.clone(), idx).is_some() {
233 return Err(ErrorCode::InvalidInputSyntax(format!(
234 "Duplicated column name `{}` in a column list of the query in a {} operation. Column list of the query: ({}).",
235 field.name,
236 op,
237 set_expr.schema().formatted_col_names(),
238 )));
239 }
240 }
241 Ok(name2idx)
242 };
243
244 let name2idx_l = check_duplicate_name(left)?;
247 let name2idx_r = check_duplicate_name(right)?;
248
249 let mut corresponding_col_idx_l = vec![];
250 let mut corresponding_col_idx_r = vec![];
251
252 if let Some(column_list) = corresponding.column_list() {
253 for column in column_list {
255 let col_name = column.real_value();
256 if let Some(idx_l) = name2idx_l.get(&col_name)
257 && let Some(idx_r) = name2idx_l.get(&col_name)
258 {
259 corresponding_col_idx_l.push(*idx_l);
260 corresponding_col_idx_r.push(*idx_r);
261 } else {
262 return Err(ErrorCode::InvalidInputSyntax(format!(
263 "Column name `{}` in CORRESPONDING BY is not found in a side of the {} operation. \
264 It shall be included in both sides.",
265 col_name,
266 op,
267 )).into());
268 }
269 }
270 } else {
271 for field in &left.schema().fields {
274 let col_name = &field.name;
275 if col_name != UNNAMED_COLUMN
276 && let Some(idx_l) = name2idx_l.get(col_name)
277 && let Some(idx_r) = name2idx_r.get(col_name)
278 {
279 corresponding_col_idx_l.push(*idx_l);
280 corresponding_col_idx_r.push(*idx_r);
281 }
282 }
283
284 if corresponding_col_idx_l.is_empty() {
285 return Err(ErrorCode::InvalidInputSyntax(
286 format!(
287 "When CORRESPONDING is specified, at least one column of the left side \
288 shall have a column name that is the column name of some column of the right side in a {} operation. \
289 Left side query column list: ({}). \
290 Right side query column list: ({}).",
291 op,
292 left.schema().formatted_col_names(),
293 right.schema().formatted_col_names(),
294 )
295 )
296 .into());
297 }
298 }
299
300 let corresponding_mapping_l =
301 ColIndexMapping::with_remaining_columns(&corresponding_col_idx_l, left.schema().len());
302 let corresponding_mapping_r =
303 ColIndexMapping::with_remaining_columns(&corresponding_col_idx_r, right.schema().len());
304
305 Ok((corresponding_mapping_l, corresponding_mapping_r))
306 }
307
308 pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result<BoundSetExpr> {
309 match set_expr {
310 SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))),
311 SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?))),
312 SetExpr::Query(q) => Ok(BoundSetExpr::Query(Box::new(self.bind_query(*q)?))),
313 SetExpr::SetOperation {
314 op,
315 all,
316 corresponding,
317 left,
318 right,
319 } => {
320 match op.clone() {
321 SetOperator::Union | SetOperator::Intersect | SetOperator::Except => {
322 let mut left = self.bind_set_expr(*left)?;
323 let new_context = std::mem::take(&mut self.context);
325 self.context
326 .cte_to_relation
327 .clone_from(&new_context.cte_to_relation);
328 self.context.disable_security_invoker =
329 new_context.disable_security_invoker;
330 let mut right = self.bind_set_expr(*right)?;
331
332 let corresponding_col_indices = if corresponding.is_corresponding() {
333 Some(Self::corresponding(
334 self,
335 &left,
336 &right,
337 corresponding,
338 &op,
339 )?)
340 } else {
342 Self::align_schema(&mut left, &mut right, op.clone())?;
343 None
344 };
345
346 if all {
347 match op {
348 SetOperator::Union => {}
349 SetOperator::Intersect | SetOperator::Except => {
350 bail_not_implemented!("{} all", op);
351 }
352 }
353 }
354
355 self.context = BindContext::default();
360 self.context.cte_to_relation = new_context.cte_to_relation;
361 Ok(BoundSetExpr::SetOperation {
362 op: op.into(),
363 all,
364 corresponding_col_indices,
365 left: Box::new(left),
366 right: Box::new(right),
367 })
368 }
369 }
370 }
371 }
372 }
373}