risingwave_frontend/binder/
bind_context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::TableAlias;
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use super::BoundSetExpr;
32use super::statement::RewriteExprsRecursive;
33use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
34
35#[derive(Debug, Clone)]
36pub struct ColumnBinding {
37    pub table_name: String,
38    pub index: usize,
39    pub is_hidden: bool,
40    pub field: Field,
41}
42
43impl ColumnBinding {
44    pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
45        ColumnBinding {
46            table_name,
47            index,
48            is_hidden,
49            field,
50        }
51    }
52}
53
54#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
55#[display(style = "TITLE CASE")]
56pub enum Clause {
57    Where,
58    Values,
59    GroupBy,
60    JoinOn,
61    Having,
62    Filter,
63    From,
64    GeneratedColumn,
65    Insert,
66}
67
68/// A `BindContext` that is only visible if the `LATERAL` keyword
69/// is provided.
70pub struct LateralBindContext {
71    pub is_visible: bool,
72    pub context: BindContext,
73}
74
75/// For recursive CTE, we may need to store it in `cte_to_relation` first,
76/// and then bind it *step by step*.
77///
78/// note: the below sql example is to illustrate when we get the
79/// corresponding binding state when handling a recursive CTE like this.
80///
81/// ```sql
82/// WITH RECURSIVE t(n) AS (
83/// # -------------^ => Init
84///     VALUES (1)
85///   UNION ALL
86///     SELECT n + 1 FROM t WHERE n < 100
87/// # --------------------^ => BaseResolved (after binding the base term, this relation will be bound to `Relation::BackCteRef`)
88/// )
89/// SELECT sum(n) FROM t;
90/// # -----------------^ => Bound (we know exactly what the entire `RecursiveUnion` looks like, and this relation will be bound to `Relation::Share`)
91/// ```
92#[derive(Default, Debug, Clone)]
93pub enum BindingCteState {
94    /// We know nothing about the CTE before resolving the body.
95    #[default]
96    Init,
97    /// We know the schema form after the base term resolved.
98    BaseResolved {
99        base: BoundSetExpr,
100    },
101    /// We get the whole bound result of the (recursive) CTE.
102    Bound {
103        query: Either<BoundQuery, RecursiveUnion>,
104    },
105
106    ChangeLog {
107        table: Relation,
108    },
109}
110
111/// the entire `RecursiveUnion` represents a *bound* recursive cte.
112/// reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
113#[derive(Debug, Clone)]
114pub struct RecursiveUnion {
115    /// currently this *must* be true,
116    /// otherwise binding will fail.
117    #[allow(dead_code)]
118    pub all: bool,
119    /// lhs part of the `UNION ALL` operator
120    pub base: Box<BoundSetExpr>,
121    /// rhs part of the `UNION ALL` operator
122    pub recursive: Box<BoundSetExpr>,
123    /// the aligned schema for this union
124    /// will be the *same* schema as recursive's
125    /// this is just for a better readability
126    pub schema: Schema,
127}
128
129impl RewriteExprsRecursive for RecursiveUnion {
130    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
131        // rewrite `base` and `recursive` separately
132        self.base.rewrite_exprs_recursive(rewriter);
133        self.recursive.rewrite_exprs_recursive(rewriter);
134    }
135}
136
137#[derive(Clone, Debug)]
138pub struct BindingCte {
139    pub share_id: ShareId,
140    pub state: BindingCteState,
141    pub alias: TableAlias,
142}
143
144#[derive(Default, Debug, Clone)]
145pub struct BindContext {
146    // Columns of all tables.
147    pub columns: Vec<ColumnBinding>,
148    // Mapping column name to indices in `columns`.
149    pub indices_of: HashMap<String, Vec<usize>>,
150    // Mapping table name to [begin, end) of its columns.
151    pub range_of: HashMap<String, (usize, usize)>,
152    // `clause` identifies in what clause we are binding.
153    pub clause: Option<Clause>,
154    // The `BindContext`'s data on its column groups
155    pub column_group_context: ColumnGroupContext,
156    /// Map the cte's name to its binding state.
157    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
158    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
159    /// Current lambda functions's arguments
160    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
161    /// Whether the security invoker is set, currently only used for views.
162    pub disable_security_invoker: bool,
163}
164
165/// Holds the context for the `BindContext`'s `ColumnGroup`s.
166#[derive(Default, Debug, Clone)]
167pub struct ColumnGroupContext {
168    // Maps naturally-joined/USING columns to their column group id
169    pub mapping: HashMap<usize, u32>,
170    // Maps column group ids to their column group data
171    // We use a BTreeMap to ensure that iteration over the groups is ordered.
172    pub groups: BTreeMap<u32, ColumnGroup>,
173
174    next_group_id: u32,
175}
176
/// The set of same-named columns that a natural join, or a join with `USING`,
/// has tied together.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Indices of the columns that belong to the group.
    pub indices: HashSet<usize>,
    /// A column of the group that is never NULL, if one is known.
    /// When this is `None`, an ambiguous reference to the column name is
    /// resolved to `COALESCE(col1, col2, ..., coln)` over every column of the
    /// group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the grouped columns, when recorded.
    pub column_name: Option<String>,
}
190
191impl BindContext {
192    pub fn get_column_binding_index(
193        &self,
194        table_name: &Option<String>,
195        column_name: &String,
196    ) -> LiteResult<usize> {
197        match &self.get_column_binding_indices(table_name, column_name)?[..] {
198            [] => unreachable!(),
199            [idx] => Ok(*idx),
200            _ => Err(ErrorCode::InternalError(format!(
201                "Ambiguous column name: {}",
202                column_name
203            ))),
204        }
205    }
206
207    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
208    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
209    /// handled in downstream as a `COALESCE` expression
210    pub fn get_column_binding_indices(
211        &self,
212        table_name: &Option<String>,
213        column_name: &String,
214    ) -> LiteResult<Vec<usize>> {
215        match table_name {
216            Some(table_name) => {
217                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
218                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
219                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
220                    self.get_indices_with_group_id(group_id, column_name)
221                } else {
222                    Ok(vec![
223                        self.get_index_with_table_name(column_name, table_name)?,
224                    ])
225                }
226            }
227            None => self.get_unqualified_indices(column_name),
228        }
229    }
230
231    fn get_indices_with_group_id(
232        &self,
233        group_id: u32,
234        column_name: &String,
235    ) -> LiteResult<Vec<usize>> {
236        let group = self.column_group_context.groups.get(&group_id).unwrap();
237        if let Some(name) = &group.column_name {
238            debug_assert_eq!(name, column_name);
239        }
240        if let Some(non_nullable) = &group.non_nullable_column {
241            Ok(vec![*non_nullable])
242        } else {
243            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
244            let mut indices: Vec<_> = group.indices.iter().copied().collect();
245            indices.sort(); // ensure a deterministic result
246            Ok(indices)
247        }
248    }
249
250    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
251        let columns = self
252            .indices_of
253            .get(column_name)
254            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
255        if columns.len() > 1 {
256            // If there is some group containing the columns and the ambiguous columns are all in
257            // the group
258            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
259                let group = self.column_group_context.groups.get(group_id).unwrap();
260                if columns.iter().all(|idx| group.indices.contains(idx)) {
261                    if let Some(non_nullable) = &group.non_nullable_column {
262                        return Ok(vec![*non_nullable]);
263                    } else {
264                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
265                        return Ok(columns.to_vec());
266                    }
267                }
268            }
269            Err(ErrorCode::InternalError(format!(
270                "Ambiguous column name: {}",
271                column_name
272            )))
273        } else {
274            Ok(columns.to_vec())
275        }
276    }
277
278    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
279    /// the columns as being `non_nullable`
280    pub fn add_natural_columns(
281        &mut self,
282        left: usize,
283        right: usize,
284        non_nullable_column: Option<usize>,
285    ) {
286        match (
287            self.column_group_context.mapping.get(&left).copied(),
288            self.column_group_context.mapping.get(&right).copied(),
289        ) {
290            (None, None) => {
291                let group_id = self.column_group_context.next_group_id;
292                self.column_group_context.next_group_id += 1;
293
294                let group = ColumnGroup {
295                    indices: HashSet::from([left, right]),
296                    non_nullable_column,
297                    column_name: Some(self.columns[left].field.name.clone()),
298                };
299                self.column_group_context.groups.insert(group_id, group);
300                self.column_group_context.mapping.insert(left, group_id);
301                self.column_group_context.mapping.insert(right, group_id);
302            }
303            (Some(group_id), None) => {
304                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
305                group.indices.insert(right);
306                if group.non_nullable_column.is_none() {
307                    group.non_nullable_column = non_nullable_column;
308                }
309                self.column_group_context.mapping.insert(right, group_id);
310            }
311            (None, Some(group_id)) => {
312                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
313                group.indices.insert(left);
314                if group.non_nullable_column.is_none() {
315                    group.non_nullable_column = non_nullable_column;
316                }
317                self.column_group_context.mapping.insert(left, group_id);
318            }
319            (Some(l_group_id), Some(r_group_id)) => {
320                if r_group_id == l_group_id {
321                    return;
322                }
323
324                let r_group = self
325                    .column_group_context
326                    .groups
327                    .remove(&r_group_id)
328                    .unwrap();
329                let l_group = self
330                    .column_group_context
331                    .groups
332                    .get_mut(&l_group_id)
333                    .unwrap();
334
335                for idx in &r_group.indices {
336                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
337                    l_group.indices.insert(*idx);
338                }
339                if l_group.non_nullable_column.is_none() {
340                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
341                        non_nullable_column
342                    } else {
343                        r_group.non_nullable_column
344                    };
345                }
346            }
347        }
348    }
349
350    fn get_index_with_table_name(
351        &self,
352        column_name: &String,
353        table_name: &String,
354    ) -> LiteResult<usize> {
355        let column_indexes = self
356            .indices_of
357            .get(column_name)
358            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
359        match column_indexes
360            .iter()
361            .find(|column_index| self.columns[**column_index].table_name == *table_name)
362        {
363            Some(column_index) => Ok(*column_index),
364            None => Err(ErrorCode::ItemNotFound(format!(
365                "missing FROM-clause entry for table \"{}\"",
366                table_name
367            ))),
368        }
369    }
370
371    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
372    /// adjacent cross-joined tables.
373    pub fn merge_context(&mut self, other: Self) -> Result<()> {
374        let begin = self.columns.len();
375        self.columns.extend(other.columns.into_iter().map(|mut c| {
376            c.index += begin;
377            c
378        }));
379        for (k, v) in other.indices_of {
380            let entry = self.indices_of.entry(k).or_default();
381            entry.extend(v.into_iter().map(|x| x + begin));
382        }
383        for (k, (x, y)) in other.range_of {
384            match self.range_of.entry(k) {
385                Entry::Occupied(e) => {
386                    return Err(ErrorCode::InternalError(format!(
387                        "Duplicated table name while merging adjacent contexts: {}",
388                        e.key()
389                    ))
390                    .into());
391                }
392                Entry::Vacant(entry) => {
393                    entry.insert((begin + x, begin + y));
394                }
395            }
396        }
397        // To merge the column_group_contexts, we just need to offset RHS
398        // with the next_group_id of LHS.
399        let ColumnGroupContext {
400            mapping,
401            groups,
402            next_group_id,
403        } = other.column_group_context;
404
405        let offset = self.column_group_context.next_group_id;
406        for (idx, group_id) in mapping {
407            self.column_group_context
408                .mapping
409                .insert(begin + idx, offset + group_id);
410        }
411        for (group_id, mut group) in groups {
412            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
413            if let Some(col) = &mut group.non_nullable_column {
414                *col += begin;
415            }
416            self.column_group_context
417                .groups
418                .insert(offset + group_id, group);
419        }
420        self.column_group_context.next_group_id += next_group_id;
421
422        // we assume that the clause is contained in the outer-level context
423        Ok(())
424    }
425}
426
427impl BindContext {
428    pub fn new() -> Self {
429        Self::default()
430    }
431}