risingwave_frontend/binder/
set_expr.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use risingwave_common::bail_not_implemented;
19use risingwave_common::catalog::Schema;
20use risingwave_common::util::column_index_mapping::ColIndexMapping;
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_sqlparser::ast::{Corresponding, SetExpr, SetOperator};
23
24use super::UNNAMED_COLUMN;
25use super::statement::RewriteExprsRecursive;
26use crate::binder::{BindContext, Binder, BoundQuery, BoundSelect, BoundValues};
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CorrelatedId, Depth, align_types};
29
/// Part of a validated query, without order or limit clause. It may be composed of smaller
/// `BoundSetExpr`(s) via set operators (e.g., union).
///
/// The nested payloads are boxed, so the enum itself stays small and the recursive
/// `SetOperation` variant has a finite size.
#[derive(Debug, Clone)]
pub enum BoundSetExpr {
    /// A bound `SELECT`.
    Select(Box<BoundSelect>),
    /// A bound nested query.
    Query(Box<BoundQuery>),
    /// A bound `VALUES` list.
    Values(Box<BoundValues>),
    /// UNION/EXCEPT/INTERSECT of two queries
    SetOperation {
        /// The set operator combining `left` and `right`.
        op: BoundSetOperation,
        /// The `ALL` flag carried over from the parsed set operation.
        /// The binder currently only accepts it together with `UNION`.
        all: bool,
        // Corresponding columns of the left and right side.
        // `Some((left_mapping, right_mapping))` when the operation was written with
        // `CORRESPONDING`, `None` otherwise.
        corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>,
        /// Left operand.
        left: Box<BoundSetExpr>,
        /// Right operand.
        right: Box<BoundSetExpr>,
    },
}
47
48impl RewriteExprsRecursive for BoundSetExpr {
49    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
50        match self {
51            BoundSetExpr::Select(inner) => inner.rewrite_exprs_recursive(rewriter),
52            BoundSetExpr::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
53            BoundSetExpr::Values(inner) => inner.rewrite_exprs_recursive(rewriter),
54            BoundSetExpr::SetOperation { left, right, .. } => {
55                left.rewrite_exprs_recursive(rewriter);
56                right.rewrite_exprs_recursive(rewriter);
57            }
58        }
59    }
60}
61
/// The set operator of a bound set operation.
///
/// Mirrors the parser's `SetOperator` one-to-one (see the `From<SetOperator>` impl).
#[derive(Debug, Clone)]
pub enum BoundSetOperation {
    /// `UNION`
    Union,
    /// `EXCEPT`
    Except,
    /// `INTERSECT`
    Intersect,
}
68
69impl From<SetOperator> for BoundSetOperation {
70    fn from(value: SetOperator) -> Self {
71        match value {
72            SetOperator::Union => BoundSetOperation::Union,
73            SetOperator::Intersect => BoundSetOperation::Intersect,
74            SetOperator::Except => BoundSetOperation::Except,
75        }
76    }
77}
78
79impl BoundSetExpr {
80    /// The schema returned by this [`BoundSetExpr`].
81    pub fn schema(&self) -> Cow<'_, Schema> {
82        match self {
83            BoundSetExpr::Select(s) => Cow::Borrowed(s.schema()),
84            BoundSetExpr::Values(v) => Cow::Borrowed(v.schema()),
85            BoundSetExpr::Query(q) => q.schema(),
86            BoundSetExpr::SetOperation {
87                left,
88                corresponding_col_indices,
89                ..
90            } => {
91                if let Some((mapping_l, _)) = corresponding_col_indices {
92                    let mut schema = vec![None; mapping_l.target_size()];
93                    for (src, tar) in mapping_l.mapping_pairs() {
94                        assert_eq!(schema[tar], None);
95                        schema[tar] = Some(left.schema().fields[src].clone());
96                    }
97                    Cow::Owned(Schema::new(
98                        schema.into_iter().map(|x| x.unwrap()).collect(),
99                    ))
100                } else {
101                    left.schema()
102                }
103            }
104        }
105    }
106
107    pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
108        match self {
109            BoundSetExpr::Select(s) => s.is_correlated_by_depth(depth),
110            BoundSetExpr::Values(v) => v.is_correlated_by_depth(depth),
111            BoundSetExpr::Query(q) => q.is_correlated_by_depth(depth),
112            BoundSetExpr::SetOperation { left, right, .. } => {
113                left.is_correlated_by_depth(depth) || right.is_correlated_by_depth(depth)
114            }
115        }
116    }
117
118    pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
119        match self {
120            BoundSetExpr::Select(s) => s.is_correlated_by_correlated_id(correlated_id),
121            BoundSetExpr::Values(v) => v.is_correlated_by_correlated_id(correlated_id),
122            BoundSetExpr::Query(q) => q.is_correlated_by_correlated_id(correlated_id),
123            BoundSetExpr::SetOperation { left, right, .. } => {
124                left.is_correlated_by_correlated_id(correlated_id)
125                    || right.is_correlated_by_correlated_id(correlated_id)
126            }
127        }
128    }
129
130    pub fn collect_correlated_indices_by_depth_and_assign_id(
131        &mut self,
132        depth: Depth,
133        correlated_id: CorrelatedId,
134    ) -> Vec<usize> {
135        match self {
136            BoundSetExpr::Select(s) => {
137                s.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
138            }
139            BoundSetExpr::Values(v) => {
140                v.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
141            }
142            BoundSetExpr::Query(q) => {
143                q.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
144            }
145            BoundSetExpr::SetOperation { left, right, .. } => {
146                let mut correlated_indices = vec![];
147                correlated_indices.extend(
148                    left.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
149                );
150                correlated_indices.extend(
151                    right.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
152                );
153                correlated_indices
154            }
155        }
156    }
157}
158
159impl Binder {
160    /// note: `align_schema` only works when the `left` and `right`
161    /// are both select expression(s).
162    pub(crate) fn align_schema(
163        mut left: &mut BoundSetExpr,
164        mut right: &mut BoundSetExpr,
165        op: SetOperator,
166    ) -> Result<()> {
167        if left.schema().fields.len() != right.schema().fields.len() {
168            return Err(ErrorCode::InvalidInputSyntax(format!(
169                "each {} query must have the same number of columns",
170                op
171            ))
172            .into());
173        }
174
175        // handle type alignment for select union select
176        // e.g., select 1 UNION ALL select NULL
177        if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) =
178            (&mut left, &mut right)
179        {
180            for (i, (l, r)) in l_select
181                .select_items
182                .iter_mut()
183                .zip_eq_fast(r_select.select_items.iter_mut())
184                .enumerate()
185            {
186                let Ok(column_type) = align_types(vec![l, r].into_iter()) else {
187                    return Err(ErrorCode::InvalidInputSyntax(format!(
188                        "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.",
189                        op,
190                        l_select.schema.fields[i].data_type,
191                        r_select.schema.fields[i].data_type,
192                        l_select.schema.fields[i].name,
193                        r_select.schema.fields[i].name,
194                    ))
195                    .into());
196                };
197                l_select.schema.fields[i].data_type = column_type.clone();
198                r_select.schema.fields[i].data_type = column_type;
199            }
200        }
201
202        Self::validate(left, right, op)
203    }
204
205    /// validate the schema, should be called after aligning.
206    pub(crate) fn validate(
207        left: &BoundSetExpr,
208        right: &BoundSetExpr,
209        op: SetOperator,
210    ) -> Result<()> {
211        for (a, b) in left
212            .schema()
213            .fields
214            .iter()
215            .zip_eq_fast(right.schema().fields.iter())
216        {
217            if a.data_type != b.data_type {
218                return Err(ErrorCode::InvalidInputSyntax(format!(
219                    "{} types {} and {} cannot be matched. Columns' name are {} and {}.",
220                    op,
221                    a.data_type.prost_type_name().as_str_name(),
222                    b.data_type.prost_type_name().as_str_name(),
223                    a.name,
224                    b.name,
225                ))
226                .into());
227            }
228        }
229        Ok(())
230    }
231
232    /// Check the corresponding specification of the set operation.
233    /// Returns the corresponding column index of the left and right side.
234    fn corresponding(
235        &self,
236        left: &BoundSetExpr,
237        right: &BoundSetExpr,
238        corresponding: Corresponding,
239        op: &SetOperator,
240    ) -> Result<(ColIndexMapping, ColIndexMapping)> {
241        let check_duplicate_name = |set_expr: &BoundSetExpr| {
242            let mut name2idx = HashMap::new();
243            for (idx, field) in set_expr.schema().fields.iter().enumerate() {
244                if name2idx.insert(field.name.clone(), idx).is_some() {
245                    return Err(ErrorCode::InvalidInputSyntax(format!(
246                        "Duplicated column name `{}` in a column list of the query in a {} operation. Column list of the query: ({}).",
247                        field.name,
248                        op,
249                        set_expr.schema().formatted_col_names(),
250                    )));
251                }
252            }
253            Ok(name2idx)
254        };
255
256        // Within the columns of both side, the same <column name> shall not
257        // be specified more than once.
258        let name2idx_l = check_duplicate_name(left)?;
259        let name2idx_r = check_duplicate_name(right)?;
260
261        let mut corresponding_col_idx_l = vec![];
262        let mut corresponding_col_idx_r = vec![];
263
264        if let Some(column_list) = corresponding.column_list() {
265            // The select list of the corresponding set operation should be in the order of <corresponding column list>
266            for column in column_list {
267                let col_name = column.real_value();
268                if let Some(idx_l) = name2idx_l.get(&col_name)
269                    && let Some(idx_r) = name2idx_l.get(&col_name)
270                {
271                    corresponding_col_idx_l.push(*idx_l);
272                    corresponding_col_idx_r.push(*idx_r);
273                } else {
274                    return Err(ErrorCode::InvalidInputSyntax(format!(
275                        "Column name `{}` in CORRESPONDING BY is not found in a side of the {} operation. \
276                        It shall be included in both sides.",
277                        col_name,
278                        op,
279                    )).into());
280                }
281            }
282        } else {
283            // The select list of the corresponding set operation should be
284            // in the order that appears in the <column name>s of the left side.
285            for field in &left.schema().fields {
286                let col_name = &field.name;
287                if col_name != UNNAMED_COLUMN
288                    && let Some(idx_l) = name2idx_l.get(col_name)
289                    && let Some(idx_r) = name2idx_r.get(col_name)
290                {
291                    corresponding_col_idx_l.push(*idx_l);
292                    corresponding_col_idx_r.push(*idx_r);
293                }
294            }
295
296            if corresponding_col_idx_l.is_empty() {
297                return Err(ErrorCode::InvalidInputSyntax(
298                    format!(
299                        "When CORRESPONDING is specified, at least one column of the left side \
300                        shall have a column name that is the column name of some column of the right side in a {} operation. \
301                        Left side query column list: ({}). \
302                        Right side query column list: ({}).",
303                        op,
304                        left.schema().formatted_col_names(),
305                        right.schema().formatted_col_names(),
306                    )
307                )
308                .into());
309            }
310        }
311
312        let corresponding_mapping_l =
313            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_l, left.schema().len());
314        let corresponding_mapping_r =
315            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_r, right.schema().len());
316
317        Ok((corresponding_mapping_l, corresponding_mapping_r))
318    }
319
320    pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result<BoundSetExpr> {
321        match set_expr {
322            SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))),
323            SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?))),
324            SetExpr::Query(q) => Ok(BoundSetExpr::Query(Box::new(self.bind_query(*q)?))),
325            SetExpr::SetOperation {
326                op,
327                all,
328                corresponding,
329                left,
330                right,
331            } => {
332                match op.clone() {
333                    SetOperator::Union | SetOperator::Intersect | SetOperator::Except => {
334                        let mut left = self.bind_set_expr(*left)?;
335                        // Reset context for right side, but keep `cte_to_relation`.
336                        let new_context = std::mem::take(&mut self.context);
337                        self.context
338                            .cte_to_relation
339                            .clone_from(&new_context.cte_to_relation);
340                        self.context.disable_security_invoker =
341                            new_context.disable_security_invoker;
342                        let mut right = self.bind_set_expr(*right)?;
343
344                        let corresponding_col_indices = if corresponding.is_corresponding() {
345                            Some(Self::corresponding(
346                                self,
347                                &left,
348                                &right,
349                                corresponding,
350                                &op,
351                            )?)
352                            // TODO: Align schema
353                        } else {
354                            Self::align_schema(&mut left, &mut right, op.clone())?;
355                            None
356                        };
357
358                        if all {
359                            match op {
360                                SetOperator::Union => {}
361                                SetOperator::Intersect | SetOperator::Except => {
362                                    bail_not_implemented!("{} all", op);
363                                }
364                            }
365                        }
366
367                        // Reset context for the set operation.
368                        // Consider this case:
369                        // select a from t2 union all select b from t2 order by a+1; should throw an
370                        // error.
371                        self.context = BindContext::default();
372                        self.context.cte_to_relation = new_context.cte_to_relation;
373                        Ok(BoundSetExpr::SetOperation {
374                            op: op.into(),
375                            all,
376                            corresponding_col_indices,
377                            left: Box::new(left),
378                            right: Box::new(right),
379                        })
380                    }
381                }
382            }
383        }
384    }
385}