risingwave_frontend/binder/
set_expr.rs1use std::borrow::Cow;
16use std::collections::HashMap;
17
18use risingwave_common::bail_not_implemented;
19use risingwave_common::catalog::Schema;
20use risingwave_common::util::column_index_mapping::ColIndexMapping;
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_sqlparser::ast::{Corresponding, SetExpr, SetOperator};
23
24use super::UNNAMED_COLUMN;
25use super::statement::RewriteExprsRecursive;
26use crate::binder::{BindContext, Binder, BoundQuery, BoundSelect, BoundValues};
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CorrelatedId, Depth, align_types};
29
/// A bound set expression: the body of a query after binding.
///
/// Each variant is the bound counterpart of a parser-level `SetExpr` variant.
#[derive(Debug, Clone)]
pub enum BoundSetExpr {
    /// A bound `SELECT`.
    Select(Box<BoundSelect>),
    /// A bound nested query.
    Query(Box<BoundQuery>),
    /// A bound `VALUES` list.
    Values(Box<BoundValues>),
    /// A bound `UNION` / `INTERSECT` / `EXCEPT` of two set expressions.
    SetOperation {
        /// The set operator.
        op: BoundSetOperation,
        /// Whether `ALL` was specified.
        all: bool,
        /// `Some((mapping_l, mapping_r))` only when `CORRESPONDING` was specified: the column
        /// mappings of the left and right operands onto the corresponding output columns.
        /// `None` for an ordinary positional set operation.
        corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>,
        /// The left operand.
        left: Box<BoundSetExpr>,
        /// The right operand.
        right: Box<BoundSetExpr>,
    },
}
47
48impl RewriteExprsRecursive for BoundSetExpr {
49 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
50 match self {
51 BoundSetExpr::Select(inner) => inner.rewrite_exprs_recursive(rewriter),
52 BoundSetExpr::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
53 BoundSetExpr::Values(inner) => inner.rewrite_exprs_recursive(rewriter),
54 BoundSetExpr::SetOperation { left, right, .. } => {
55 left.rewrite_exprs_recursive(rewriter);
56 right.rewrite_exprs_recursive(rewriter);
57 }
58 }
59 }
60}
61
/// The operator of a bound set operation (the `ALL` flag is stored separately on
/// `BoundSetExpr::SetOperation`).
#[derive(Debug, Clone)]
pub enum BoundSetOperation {
    /// `UNION`
    Union,
    /// `EXCEPT`
    Except,
    /// `INTERSECT`
    Intersect,
}
68
69impl From<SetOperator> for BoundSetOperation {
70 fn from(value: SetOperator) -> Self {
71 match value {
72 SetOperator::Union => BoundSetOperation::Union,
73 SetOperator::Intersect => BoundSetOperation::Intersect,
74 SetOperator::Except => BoundSetOperation::Except,
75 }
76 }
77}
78
79impl BoundSetExpr {
80 pub fn schema(&self) -> Cow<'_, Schema> {
82 match self {
83 BoundSetExpr::Select(s) => Cow::Borrowed(s.schema()),
84 BoundSetExpr::Values(v) => Cow::Borrowed(v.schema()),
85 BoundSetExpr::Query(q) => q.schema(),
86 BoundSetExpr::SetOperation {
87 left,
88 corresponding_col_indices,
89 ..
90 } => {
91 if let Some((mapping_l, _)) = corresponding_col_indices {
92 let mut schema = vec![None; mapping_l.target_size()];
93 for (src, tar) in mapping_l.mapping_pairs() {
94 assert_eq!(schema[tar], None);
95 schema[tar] = Some(left.schema().fields[src].clone());
96 }
97 Cow::Owned(Schema::new(
98 schema.into_iter().map(|x| x.unwrap()).collect(),
99 ))
100 } else {
101 left.schema()
102 }
103 }
104 }
105 }
106
107 pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
108 match self {
109 BoundSetExpr::Select(s) => s.is_correlated_by_depth(depth),
110 BoundSetExpr::Values(v) => v.is_correlated_by_depth(depth),
111 BoundSetExpr::Query(q) => q.is_correlated_by_depth(depth),
112 BoundSetExpr::SetOperation { left, right, .. } => {
113 left.is_correlated_by_depth(depth) || right.is_correlated_by_depth(depth)
114 }
115 }
116 }
117
118 pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
119 match self {
120 BoundSetExpr::Select(s) => s.is_correlated_by_correlated_id(correlated_id),
121 BoundSetExpr::Values(v) => v.is_correlated_by_correlated_id(correlated_id),
122 BoundSetExpr::Query(q) => q.is_correlated_by_correlated_id(correlated_id),
123 BoundSetExpr::SetOperation { left, right, .. } => {
124 left.is_correlated_by_correlated_id(correlated_id)
125 || right.is_correlated_by_correlated_id(correlated_id)
126 }
127 }
128 }
129
130 pub fn collect_correlated_indices_by_depth_and_assign_id(
131 &mut self,
132 depth: Depth,
133 correlated_id: CorrelatedId,
134 ) -> Vec<usize> {
135 match self {
136 BoundSetExpr::Select(s) => {
137 s.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
138 }
139 BoundSetExpr::Values(v) => {
140 v.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
141 }
142 BoundSetExpr::Query(q) => {
143 q.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
144 }
145 BoundSetExpr::SetOperation { left, right, .. } => {
146 let mut correlated_indices = vec![];
147 correlated_indices.extend(
148 left.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
149 );
150 correlated_indices.extend(
151 right.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
152 );
153 correlated_indices
154 }
155 }
156 }
157}
158
159impl Binder {
160 pub(crate) fn align_schema(
163 mut left: &mut BoundSetExpr,
164 mut right: &mut BoundSetExpr,
165 op: SetOperator,
166 ) -> Result<()> {
167 if left.schema().fields.len() != right.schema().fields.len() {
168 return Err(ErrorCode::InvalidInputSyntax(format!(
169 "each {} query must have the same number of columns",
170 op
171 ))
172 .into());
173 }
174
175 if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) =
178 (&mut left, &mut right)
179 {
180 for (i, (l, r)) in l_select
181 .select_items
182 .iter_mut()
183 .zip_eq_fast(r_select.select_items.iter_mut())
184 .enumerate()
185 {
186 let Ok(column_type) = align_types(vec![l, r].into_iter()) else {
187 return Err(ErrorCode::InvalidInputSyntax(format!(
188 "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.",
189 op,
190 l_select.schema.fields[i].data_type,
191 r_select.schema.fields[i].data_type,
192 l_select.schema.fields[i].name,
193 r_select.schema.fields[i].name,
194 ))
195 .into());
196 };
197 l_select.schema.fields[i].data_type = column_type.clone();
198 r_select.schema.fields[i].data_type = column_type;
199 }
200 }
201
202 Self::validate(left, right, op)
203 }
204
205 pub(crate) fn validate(
207 left: &BoundSetExpr,
208 right: &BoundSetExpr,
209 op: SetOperator,
210 ) -> Result<()> {
211 for (a, b) in left
212 .schema()
213 .fields
214 .iter()
215 .zip_eq_fast(right.schema().fields.iter())
216 {
217 if a.data_type != b.data_type {
218 return Err(ErrorCode::InvalidInputSyntax(format!(
219 "{} types {} and {} cannot be matched. Columns' name are {} and {}.",
220 op,
221 a.data_type.prost_type_name().as_str_name(),
222 b.data_type.prost_type_name().as_str_name(),
223 a.name,
224 b.name,
225 ))
226 .into());
227 }
228 }
229 Ok(())
230 }
231
232 fn corresponding(
235 &self,
236 left: &BoundSetExpr,
237 right: &BoundSetExpr,
238 corresponding: Corresponding,
239 op: &SetOperator,
240 ) -> Result<(ColIndexMapping, ColIndexMapping)> {
241 let check_duplicate_name = |set_expr: &BoundSetExpr| {
242 let mut name2idx = HashMap::new();
243 for (idx, field) in set_expr.schema().fields.iter().enumerate() {
244 if name2idx.insert(field.name.clone(), idx).is_some() {
245 return Err(ErrorCode::InvalidInputSyntax(format!(
246 "Duplicated column name `{}` in a column list of the query in a {} operation. Column list of the query: ({}).",
247 field.name,
248 op,
249 set_expr.schema().formatted_col_names(),
250 )));
251 }
252 }
253 Ok(name2idx)
254 };
255
256 let name2idx_l = check_duplicate_name(left)?;
259 let name2idx_r = check_duplicate_name(right)?;
260
261 let mut corresponding_col_idx_l = vec![];
262 let mut corresponding_col_idx_r = vec![];
263
264 if let Some(column_list) = corresponding.column_list() {
265 for column in column_list {
267 let col_name = column.real_value();
268 if let Some(idx_l) = name2idx_l.get(&col_name)
269 && let Some(idx_r) = name2idx_l.get(&col_name)
270 {
271 corresponding_col_idx_l.push(*idx_l);
272 corresponding_col_idx_r.push(*idx_r);
273 } else {
274 return Err(ErrorCode::InvalidInputSyntax(format!(
275 "Column name `{}` in CORRESPONDING BY is not found in a side of the {} operation. \
276 It shall be included in both sides.",
277 col_name,
278 op,
279 )).into());
280 }
281 }
282 } else {
283 for field in &left.schema().fields {
286 let col_name = &field.name;
287 if col_name != UNNAMED_COLUMN
288 && let Some(idx_l) = name2idx_l.get(col_name)
289 && let Some(idx_r) = name2idx_r.get(col_name)
290 {
291 corresponding_col_idx_l.push(*idx_l);
292 corresponding_col_idx_r.push(*idx_r);
293 }
294 }
295
296 if corresponding_col_idx_l.is_empty() {
297 return Err(ErrorCode::InvalidInputSyntax(
298 format!(
299 "When CORRESPONDING is specified, at least one column of the left side \
300 shall have a column name that is the column name of some column of the right side in a {} operation. \
301 Left side query column list: ({}). \
302 Right side query column list: ({}).",
303 op,
304 left.schema().formatted_col_names(),
305 right.schema().formatted_col_names(),
306 )
307 )
308 .into());
309 }
310 }
311
312 let corresponding_mapping_l =
313 ColIndexMapping::with_remaining_columns(&corresponding_col_idx_l, left.schema().len());
314 let corresponding_mapping_r =
315 ColIndexMapping::with_remaining_columns(&corresponding_col_idx_r, right.schema().len());
316
317 Ok((corresponding_mapping_l, corresponding_mapping_r))
318 }
319
320 pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result<BoundSetExpr> {
321 match set_expr {
322 SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))),
323 SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?))),
324 SetExpr::Query(q) => Ok(BoundSetExpr::Query(Box::new(self.bind_query(*q)?))),
325 SetExpr::SetOperation {
326 op,
327 all,
328 corresponding,
329 left,
330 right,
331 } => {
332 match op.clone() {
333 SetOperator::Union | SetOperator::Intersect | SetOperator::Except => {
334 let mut left = self.bind_set_expr(*left)?;
335 let new_context = std::mem::take(&mut self.context);
337 self.context
338 .cte_to_relation
339 .clone_from(&new_context.cte_to_relation);
340 self.context.disable_security_invoker =
341 new_context.disable_security_invoker;
342 let mut right = self.bind_set_expr(*right)?;
343
344 let corresponding_col_indices = if corresponding.is_corresponding() {
345 Some(Self::corresponding(
346 self,
347 &left,
348 &right,
349 corresponding,
350 &op,
351 )?)
352 } else {
354 Self::align_schema(&mut left, &mut right, op.clone())?;
355 None
356 };
357
358 if all {
359 match op {
360 SetOperator::Union => {}
361 SetOperator::Intersect | SetOperator::Except => {
362 bail_not_implemented!("{} all", op);
363 }
364 }
365 }
366
367 self.context = BindContext::default();
372 self.context.cte_to_relation = new_context.cte_to_relation;
373 Ok(BoundSetExpr::SetOperation {
374 op: op.into(),
375 all,
376 corresponding_col_indices,
377 left: Box::new(left),
378 right: Box::new(right),
379 })
380 }
381 }
382 }
383 }
384 }
385}