risingwave_frontend/binder/
bind_context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28use crate::expr::ExprImpl;
29
30type LiteResult<T> = std::result::Result<T, ErrorCode>;
31
32use super::BoundSetExpr;
33use super::statement::RewriteExprsRecursive;
34use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
35
36#[derive(Debug, Clone)]
37pub struct ColumnBinding {
38    pub table_name: String,
39    pub index: usize,
40    pub is_hidden: bool,
41    pub field: Field,
42}
43
44impl ColumnBinding {
45    pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
46        ColumnBinding {
47            table_name,
48            index,
49            is_hidden,
50            field,
51        }
52    }
53}
54
55#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
56#[display(style = "TITLE CASE")]
57pub enum Clause {
58    Where,
59    Values,
60    GroupBy,
61    JoinOn,
62    Having,
63    Filter,
64    From,
65    GeneratedColumn,
66    Insert,
67}
68
69/// A `BindContext` that is only visible if the `LATERAL` keyword
70/// is provided.
71pub struct LateralBindContext {
72    pub is_visible: bool,
73    pub context: BindContext,
74}
75
76/// For recursive CTE, we may need to store it in `cte_to_relation` first,
77/// and then bind it *step by step*.
78///
79/// note: the below sql example is to illustrate when we get the
80/// corresponding binding state when handling a recursive CTE like this.
81///
82/// ```sql
83/// WITH RECURSIVE t(n) AS (
84/// # -------------^ => Init
85///     VALUES (1)
86///   UNION ALL
87///     SELECT n + 1 FROM t WHERE n < 100
88/// # --------------------^ => BaseResolved (after binding the base term, this relation will be bound to `Relation::BackCteRef`)
89/// )
90/// SELECT sum(n) FROM t;
91/// # -----------------^ => Bound (we know exactly what the entire `RecursiveUnion` looks like, and this relation will be bound to `Relation::Share`)
92/// ```
93#[derive(Default, Debug, Clone)]
94pub enum BindingCteState {
95    /// We know nothing about the CTE before resolving the body.
96    #[default]
97    Init,
98    /// We know the schema form after the base term resolved.
99    BaseResolved {
100        base: BoundSetExpr,
101    },
102    /// We get the whole bound result of the (recursive) CTE.
103    Bound {
104        query: Either<BoundQuery, RecursiveUnion>,
105    },
106
107    ChangeLog {
108        table: Relation,
109    },
110}
111
112/// the entire `RecursiveUnion` represents a *bound* recursive cte.
113/// reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
114#[derive(Debug, Clone)]
115pub struct RecursiveUnion {
116    /// currently this *must* be true,
117    /// otherwise binding will fail.
118    #[allow(dead_code)]
119    pub all: bool,
120    /// lhs part of the `UNION ALL` operator
121    pub base: Box<BoundSetExpr>,
122    /// rhs part of the `UNION ALL` operator
123    pub recursive: Box<BoundSetExpr>,
124    /// the aligned schema for this union
125    /// will be the *same* schema as recursive's
126    /// this is just for a better readability
127    pub schema: Schema,
128}
129
130impl RewriteExprsRecursive for RecursiveUnion {
131    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
132        // rewrite `base` and `recursive` separately
133        self.base.rewrite_exprs_recursive(rewriter);
134        self.recursive.rewrite_exprs_recursive(rewriter);
135    }
136}
137
138#[derive(Clone, Debug)]
139pub struct BindingCte {
140    pub share_id: ShareId,
141    pub state: BindingCteState,
142    pub alias: TableAlias,
143}
144
145#[derive(Default, Debug, Clone)]
146pub struct BindContext {
147    // Columns of all tables.
148    pub columns: Vec<ColumnBinding>,
149    // Mapping column name to indices in `columns`.
150    pub indices_of: HashMap<String, Vec<usize>>,
151    // Mapping table name to [begin, end) of its columns.
152    pub range_of: HashMap<String, (usize, usize)>,
153    // `clause` identifies in what clause we are binding.
154    pub clause: Option<Clause>,
155    // The `BindContext`'s data on its column groups
156    pub column_group_context: ColumnGroupContext,
157    /// Map the cte's name to its binding state.
158    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
159    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
160    /// Current lambda functions's arguments
161    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
162    /// Whether the security invoker is set, currently only used for views.
163    pub disable_security_invoker: bool,
164    /// Named window definitions from the `WINDOW` clause
165    pub named_windows: HashMap<String, WindowSpec>,
166    /// Bound arguments for the current SQL UDF, if any.
167    // TODO: use enum for named or positional arguments
168    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
169}
170
171/// Holds the context for the `BindContext`'s `ColumnGroup`s.
172#[derive(Default, Debug, Clone)]
173pub struct ColumnGroupContext {
174    // Maps naturally-joined/USING columns to their column group id
175    pub mapping: HashMap<usize, u32>,
176    // Maps column group ids to their column group data
177    // We use a BTreeMap to ensure that iteration over the groups is ordered.
178    pub groups: BTreeMap<u32, ColumnGroup>,
179
180    next_group_id: u32,
181}
182
/// The columns that share one name because they were joined with a natural join or with
/// `USING`.
#[derive(Debug, Clone, Default)]
pub struct ColumnGroup {
    /// Indices of the columns in the group.
    pub indices: HashSet<usize>,
    /// A member column known never to be NULL, if any. When `None`, ambiguous references to
    /// the column name resolve to `COALESCE(col1, col2, ..., coln)` over every column in the
    /// group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the columns of the group.
    pub column_name: Option<String>,
}
196
197impl BindContext {
198    pub fn get_column_binding_index(
199        &self,
200        table_name: &Option<String>,
201        column_name: &String,
202    ) -> LiteResult<usize> {
203        match &self.get_column_binding_indices(table_name, column_name)?[..] {
204            [] => unreachable!(),
205            [idx] => Ok(*idx),
206            _ => Err(ErrorCode::InternalError(format!(
207                "Ambiguous column name: {}",
208                column_name
209            ))),
210        }
211    }
212
213    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
214    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
215    /// handled in downstream as a `COALESCE` expression
216    pub fn get_column_binding_indices(
217        &self,
218        table_name: &Option<String>,
219        column_name: &String,
220    ) -> LiteResult<Vec<usize>> {
221        match table_name {
222            Some(table_name) => {
223                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
224                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
225                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
226                    self.get_indices_with_group_id(group_id, column_name)
227                } else {
228                    Ok(vec![
229                        self.get_index_with_table_name(column_name, table_name)?,
230                    ])
231                }
232            }
233            None => self.get_unqualified_indices(column_name),
234        }
235    }
236
237    fn get_indices_with_group_id(
238        &self,
239        group_id: u32,
240        column_name: &String,
241    ) -> LiteResult<Vec<usize>> {
242        let group = self.column_group_context.groups.get(&group_id).unwrap();
243        if let Some(name) = &group.column_name {
244            debug_assert_eq!(name, column_name);
245        }
246        if let Some(non_nullable) = &group.non_nullable_column {
247            Ok(vec![*non_nullable])
248        } else {
249            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
250            let mut indices: Vec<_> = group.indices.iter().copied().collect();
251            indices.sort(); // ensure a deterministic result
252            Ok(indices)
253        }
254    }
255
256    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
257        let columns = self
258            .indices_of
259            .get(column_name)
260            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
261        if columns.len() > 1 {
262            // If there is some group containing the columns and the ambiguous columns are all in
263            // the group
264            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
265                let group = self.column_group_context.groups.get(group_id).unwrap();
266                if columns.iter().all(|idx| group.indices.contains(idx)) {
267                    if let Some(non_nullable) = &group.non_nullable_column {
268                        return Ok(vec![*non_nullable]);
269                    } else {
270                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
271                        return Ok(columns.to_vec());
272                    }
273                }
274            }
275            Err(ErrorCode::InternalError(format!(
276                "Ambiguous column name: {}",
277                column_name
278            )))
279        } else {
280            Ok(columns.to_vec())
281        }
282    }
283
284    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
285    /// the columns as being `non_nullable`
286    pub fn add_natural_columns(
287        &mut self,
288        left: usize,
289        right: usize,
290        non_nullable_column: Option<usize>,
291    ) {
292        match (
293            self.column_group_context.mapping.get(&left).copied(),
294            self.column_group_context.mapping.get(&right).copied(),
295        ) {
296            (None, None) => {
297                let group_id = self.column_group_context.next_group_id;
298                self.column_group_context.next_group_id += 1;
299
300                let group = ColumnGroup {
301                    indices: HashSet::from([left, right]),
302                    non_nullable_column,
303                    column_name: Some(self.columns[left].field.name.clone()),
304                };
305                self.column_group_context.groups.insert(group_id, group);
306                self.column_group_context.mapping.insert(left, group_id);
307                self.column_group_context.mapping.insert(right, group_id);
308            }
309            (Some(group_id), None) => {
310                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
311                group.indices.insert(right);
312                if group.non_nullable_column.is_none() {
313                    group.non_nullable_column = non_nullable_column;
314                }
315                self.column_group_context.mapping.insert(right, group_id);
316            }
317            (None, Some(group_id)) => {
318                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
319                group.indices.insert(left);
320                if group.non_nullable_column.is_none() {
321                    group.non_nullable_column = non_nullable_column;
322                }
323                self.column_group_context.mapping.insert(left, group_id);
324            }
325            (Some(l_group_id), Some(r_group_id)) => {
326                if r_group_id == l_group_id {
327                    return;
328                }
329
330                let r_group = self
331                    .column_group_context
332                    .groups
333                    .remove(&r_group_id)
334                    .unwrap();
335                let l_group = self
336                    .column_group_context
337                    .groups
338                    .get_mut(&l_group_id)
339                    .unwrap();
340
341                for idx in &r_group.indices {
342                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
343                    l_group.indices.insert(*idx);
344                }
345                if l_group.non_nullable_column.is_none() {
346                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
347                        non_nullable_column
348                    } else {
349                        r_group.non_nullable_column
350                    };
351                }
352            }
353        }
354    }
355
356    fn get_index_with_table_name(
357        &self,
358        column_name: &String,
359        table_name: &String,
360    ) -> LiteResult<usize> {
361        let column_indexes = self
362            .indices_of
363            .get(column_name)
364            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
365        match column_indexes
366            .iter()
367            .find(|column_index| self.columns[**column_index].table_name == *table_name)
368        {
369            Some(column_index) => Ok(*column_index),
370            None => Err(ErrorCode::ItemNotFound(format!(
371                "missing FROM-clause entry for table \"{}\"",
372                table_name
373            ))),
374        }
375    }
376
377    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
378    /// adjacent cross-joined tables.
379    pub fn merge_context(&mut self, other: Self) -> Result<()> {
380        let begin = self.columns.len();
381        self.columns.extend(other.columns.into_iter().map(|mut c| {
382            c.index += begin;
383            c
384        }));
385        for (k, v) in other.indices_of {
386            let entry = self.indices_of.entry(k).or_default();
387            entry.extend(v.into_iter().map(|x| x + begin));
388        }
389        for (k, (x, y)) in other.range_of {
390            match self.range_of.entry(k) {
391                Entry::Occupied(e) => {
392                    return Err(ErrorCode::InternalError(format!(
393                        "Duplicated table name while merging adjacent contexts: {}",
394                        e.key()
395                    ))
396                    .into());
397                }
398                Entry::Vacant(entry) => {
399                    entry.insert((begin + x, begin + y));
400                }
401            }
402        }
403        // To merge the column_group_contexts, we just need to offset RHS
404        // with the next_group_id of LHS.
405        let ColumnGroupContext {
406            mapping,
407            groups,
408            next_group_id,
409        } = other.column_group_context;
410
411        let offset = self.column_group_context.next_group_id;
412        for (idx, group_id) in mapping {
413            self.column_group_context
414                .mapping
415                .insert(begin + idx, offset + group_id);
416        }
417        for (group_id, mut group) in groups {
418            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
419            if let Some(col) = &mut group.non_nullable_column {
420                *col += begin;
421            }
422            self.column_group_context
423                .groups
424                .insert(offset + group_id, group);
425        }
426        self.column_group_context.next_group_id += next_group_id;
427
428        // we assume that the clause is contained in the outer-level context
429        Ok(())
430    }
431}
432
433impl BindContext {
434    pub fn new() -> Self {
435        Self::default()
436    }
437}