risingwave_frontend/binder/
set_expr.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use risingwave_common::bail_not_implemented;
19use risingwave_common::catalog::Schema;
20use risingwave_common::util::column_index_mapping::ColIndexMapping;
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_sqlparser::ast::{Corresponding, SetExpr, SetOperator};
23
24use super::UNNAMED_COLUMN;
25use super::statement::RewriteExprsRecursive;
26use crate::binder::{BindContext, Binder, BoundQuery, BoundSelect, BoundValues};
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CorrelatedId, Depth, align_types};
29
/// Part of a validated query, without order or limit clause. It may be composed of smaller
/// `BoundSetExpr`(s) via set operators (e.g., union).
#[derive(Debug, Clone)]
pub enum BoundSetExpr {
    /// A bound `SELECT` body.
    Select(Box<BoundSelect>),
    /// A bound nested query.
    Query(Box<BoundQuery>),
    /// A bound `VALUES` list.
    Values(Box<BoundValues>),
    /// UNION/EXCEPT/INTERSECT of two queries
    SetOperation {
        /// The set operator combining `left` and `right`.
        op: BoundSetOperation,
        /// The `all` flag carried over from the parsed set operation.
        all: bool,
        /// Corresponding columns of the left and right side, as a pair of
        /// `(left mapping, right mapping)`. It is `Some` only when the set operation was bound
        /// with a `CORRESPONDING` specification and `None` otherwise.
        corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>,
        /// Left operand of the set operation.
        left: Box<BoundSetExpr>,
        /// Right operand of the set operation.
        right: Box<BoundSetExpr>,
    },
}
47
48impl RewriteExprsRecursive for BoundSetExpr {
49    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
50        match self {
51            BoundSetExpr::Select(inner) => inner.rewrite_exprs_recursive(rewriter),
52            BoundSetExpr::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
53            BoundSetExpr::Values(inner) => inner.rewrite_exprs_recursive(rewriter),
54            BoundSetExpr::SetOperation { left, right, .. } => {
55                left.rewrite_exprs_recursive(rewriter);
56                right.rewrite_exprs_recursive(rewriter);
57            }
58        }
59    }
60}
61
/// The operator of a bound set operation. It is produced from the parser's [`SetOperator`]
/// through the `From` implementation on this type.
#[derive(Debug, Clone)]
pub enum BoundSetOperation {
    /// `UNION`
    Union,
    /// `EXCEPT`
    Except,
    /// `INTERSECT`
    Intersect,
}
68
69impl From<SetOperator> for BoundSetOperation {
70    fn from(value: SetOperator) -> Self {
71        match value {
72            SetOperator::Union => BoundSetOperation::Union,
73            SetOperator::Intersect => BoundSetOperation::Intersect,
74            SetOperator::Except => BoundSetOperation::Except,
75        }
76    }
77}
78
79impl BoundSetExpr {
80    /// The schema returned by this [`BoundSetExpr`].
81    pub fn schema(&self) -> Cow<'_, Schema> {
82        match self {
83            BoundSetExpr::Select(s) => Cow::Borrowed(s.schema()),
84            BoundSetExpr::Values(v) => Cow::Borrowed(v.schema()),
85            BoundSetExpr::Query(q) => q.schema(),
86            BoundSetExpr::SetOperation {
87                left,
88                corresponding_col_indices,
89                ..
90            } => {
91                if let Some((mapping_l, _)) = corresponding_col_indices {
92                    let mut schema = vec![None; mapping_l.target_size()];
93                    for (src, tar) in mapping_l.mapping_pairs() {
94                        assert_eq!(schema[tar], None);
95                        schema[tar] = Some(left.schema().fields[src].clone());
96                    }
97                    Cow::Owned(Schema::new(
98                        schema.into_iter().map(|x| x.unwrap()).collect(),
99                    ))
100                } else {
101                    left.schema()
102                }
103            }
104        }
105    }
106
107    pub fn is_correlated(&self, depth: Depth) -> bool {
108        match self {
109            BoundSetExpr::Select(s) => s.is_correlated(depth),
110            BoundSetExpr::Values(v) => v.is_correlated(depth),
111            BoundSetExpr::Query(q) => q.is_correlated(depth),
112            BoundSetExpr::SetOperation { left, right, .. } => {
113                left.is_correlated(depth) || right.is_correlated(depth)
114            }
115        }
116    }
117
118    pub fn collect_correlated_indices_by_depth_and_assign_id(
119        &mut self,
120        depth: Depth,
121        correlated_id: CorrelatedId,
122    ) -> Vec<usize> {
123        match self {
124            BoundSetExpr::Select(s) => {
125                s.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
126            }
127            BoundSetExpr::Values(v) => {
128                v.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
129            }
130            BoundSetExpr::Query(q) => {
131                q.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
132            }
133            BoundSetExpr::SetOperation { left, right, .. } => {
134                let mut correlated_indices = vec![];
135                correlated_indices.extend(
136                    left.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
137                );
138                correlated_indices.extend(
139                    right.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
140                );
141                correlated_indices
142            }
143        }
144    }
145}
146
147impl Binder {
148    /// note: `align_schema` only works when the `left` and `right`
149    /// are both select expression(s).
150    pub(crate) fn align_schema(
151        mut left: &mut BoundSetExpr,
152        mut right: &mut BoundSetExpr,
153        op: SetOperator,
154    ) -> Result<()> {
155        if left.schema().fields.len() != right.schema().fields.len() {
156            return Err(ErrorCode::InvalidInputSyntax(format!(
157                "each {} query must have the same number of columns",
158                op
159            ))
160            .into());
161        }
162
163        // handle type alignment for select union select
164        // e.g., select 1 UNION ALL select NULL
165        if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) =
166            (&mut left, &mut right)
167        {
168            for (i, (l, r)) in l_select
169                .select_items
170                .iter_mut()
171                .zip_eq_fast(r_select.select_items.iter_mut())
172                .enumerate()
173            {
174                let Ok(column_type) = align_types(vec![l, r].into_iter()) else {
175                    return Err(ErrorCode::InvalidInputSyntax(format!(
176                        "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.",
177                        op,
178                        l_select.schema.fields[i].data_type,
179                        r_select.schema.fields[i].data_type,
180                        l_select.schema.fields[i].name,
181                        r_select.schema.fields[i].name,
182                    ))
183                    .into());
184                };
185                l_select.schema.fields[i].data_type = column_type.clone();
186                r_select.schema.fields[i].data_type = column_type;
187            }
188        }
189
190        Self::validate(left, right, op)
191    }
192
193    /// validate the schema, should be called after aligning.
194    pub(crate) fn validate(
195        left: &BoundSetExpr,
196        right: &BoundSetExpr,
197        op: SetOperator,
198    ) -> Result<()> {
199        for (a, b) in left
200            .schema()
201            .fields
202            .iter()
203            .zip_eq_fast(right.schema().fields.iter())
204        {
205            if a.data_type != b.data_type {
206                return Err(ErrorCode::InvalidInputSyntax(format!(
207                    "{} types {} and {} cannot be matched. Columns' name are {} and {}.",
208                    op,
209                    a.data_type.prost_type_name().as_str_name(),
210                    b.data_type.prost_type_name().as_str_name(),
211                    a.name,
212                    b.name,
213                ))
214                .into());
215            }
216        }
217        Ok(())
218    }
219
220    /// Check the corresponding specification of the set operation.
221    /// Returns the corresponding column index of the left and right side.
222    fn corresponding(
223        &self,
224        left: &BoundSetExpr,
225        right: &BoundSetExpr,
226        corresponding: Corresponding,
227        op: &SetOperator,
228    ) -> Result<(ColIndexMapping, ColIndexMapping)> {
229        let check_duplicate_name = |set_expr: &BoundSetExpr| {
230            let mut name2idx = HashMap::new();
231            for (idx, field) in set_expr.schema().fields.iter().enumerate() {
232                if name2idx.insert(field.name.clone(), idx).is_some() {
233                    return Err(ErrorCode::InvalidInputSyntax(format!(
234                        "Duplicated column name `{}` in a column list of the query in a {} operation. Column list of the query: ({}).",
235                        field.name,
236                        op,
237                        set_expr.schema().formatted_col_names(),
238                    )));
239                }
240            }
241            Ok(name2idx)
242        };
243
244        // Within the columns of both side, the same <column name> shall not
245        // be specified more than once.
246        let name2idx_l = check_duplicate_name(left)?;
247        let name2idx_r = check_duplicate_name(right)?;
248
249        let mut corresponding_col_idx_l = vec![];
250        let mut corresponding_col_idx_r = vec![];
251
252        if let Some(column_list) = corresponding.column_list() {
253            // The select list of the corresponding set operation should be in the order of <corresponding column list>
254            for column in column_list {
255                let col_name = column.real_value();
256                if let Some(idx_l) = name2idx_l.get(&col_name)
257                    && let Some(idx_r) = name2idx_l.get(&col_name)
258                {
259                    corresponding_col_idx_l.push(*idx_l);
260                    corresponding_col_idx_r.push(*idx_r);
261                } else {
262                    return Err(ErrorCode::InvalidInputSyntax(format!(
263                        "Column name `{}` in CORRESPONDING BY is not found in a side of the {} operation. \
264                        It shall be included in both sides.",
265                        col_name,
266                        op,
267                    )).into());
268                }
269            }
270        } else {
271            // The select list of the corresponding set operation should be
272            // in the order that appears in the <column name>s of the left side.
273            for field in &left.schema().fields {
274                let col_name = &field.name;
275                if col_name != UNNAMED_COLUMN
276                    && let Some(idx_l) = name2idx_l.get(col_name)
277                    && let Some(idx_r) = name2idx_r.get(col_name)
278                {
279                    corresponding_col_idx_l.push(*idx_l);
280                    corresponding_col_idx_r.push(*idx_r);
281                }
282            }
283
284            if corresponding_col_idx_l.is_empty() {
285                return Err(ErrorCode::InvalidInputSyntax(
286                    format!(
287                        "When CORRESPONDING is specified, at least one column of the left side \
288                        shall have a column name that is the column name of some column of the right side in a {} operation. \
289                        Left side query column list: ({}). \
290                        Right side query column list: ({}).",
291                        op,
292                        left.schema().formatted_col_names(),
293                        right.schema().formatted_col_names(),
294                    )
295                )
296                .into());
297            }
298        }
299
300        let corresponding_mapping_l =
301            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_l, left.schema().len());
302        let corresponding_mapping_r =
303            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_r, right.schema().len());
304
305        Ok((corresponding_mapping_l, corresponding_mapping_r))
306    }
307
308    pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result<BoundSetExpr> {
309        match set_expr {
310            SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))),
311            SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?))),
312            SetExpr::Query(q) => Ok(BoundSetExpr::Query(Box::new(self.bind_query(*q)?))),
313            SetExpr::SetOperation {
314                op,
315                all,
316                corresponding,
317                left,
318                right,
319            } => {
320                match op.clone() {
321                    SetOperator::Union | SetOperator::Intersect | SetOperator::Except => {
322                        let mut left = self.bind_set_expr(*left)?;
323                        // Reset context for right side, but keep `cte_to_relation`.
324                        let new_context = std::mem::take(&mut self.context);
325                        self.context
326                            .cte_to_relation
327                            .clone_from(&new_context.cte_to_relation);
328                        self.context.disable_security_invoker =
329                            new_context.disable_security_invoker;
330                        let mut right = self.bind_set_expr(*right)?;
331
332                        let corresponding_col_indices = if corresponding.is_corresponding() {
333                            Some(Self::corresponding(
334                                self,
335                                &left,
336                                &right,
337                                corresponding,
338                                &op,
339                            )?)
340                            // TODO: Align schema
341                        } else {
342                            Self::align_schema(&mut left, &mut right, op.clone())?;
343                            None
344                        };
345
346                        if all {
347                            match op {
348                                SetOperator::Union => {}
349                                SetOperator::Intersect | SetOperator::Except => {
350                                    bail_not_implemented!("{} all", op);
351                                }
352                            }
353                        }
354
355                        // Reset context for the set operation.
356                        // Consider this case:
357                        // select a from t2 union all select b from t2 order by a+1; should throw an
358                        // error.
359                        self.context = BindContext::default();
360                        self.context.cte_to_relation = new_context.cte_to_relation;
361                        Ok(BoundSetExpr::SetOperation {
362                            op: op.into(),
363                            all,
364                            corresponding_col_indices,
365                            left: Box::new(left),
366                            right: Box::new(right),
367                        })
368                    }
369                }
370            }
371        }
372    }
373}