risingwave_frontend/binder/
bind_context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28
/// Result alias whose error type is the bare [`ErrorCode`], used by the column-lookup helpers of
/// `BindContext` in this file.
type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use super::BoundSetExpr;
32use super::statement::RewriteExprsRecursive;
33use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
34
/// A column visible in a [`BindContext`], recorded together with the table it is bound under.
#[derive(Debug, Clone)]
pub struct ColumnBinding {
    /// Name of the table the column is bound under; compared against the qualifier when a
    /// `table.column` reference is resolved.
    pub table_name: String,
    /// Position of the column; shifted by the number of preceding columns when two contexts
    /// are merged.
    pub index: usize,
    /// NOTE(review): presumably marks columns that are excluded from `*` expansion — the flag is
    /// not read in this file; confirm against the callers.
    pub is_hidden: bool,
    /// The field of the column; its `name` is used to label column groups.
    pub field: Field,
}
42
43impl ColumnBinding {
44    pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
45        ColumnBinding {
46            table_name,
47            index,
48            is_hidden,
49            field,
50        }
51    }
52}
53
/// The kind of SQL clause an expression is being bound in; stored in `BindContext::clause`.
///
/// `Display` is derived through `parse_display` using the `TITLE CASE` style.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
#[display(style = "TITLE CASE")]
pub enum Clause {
    Where,
    Values,
    GroupBy,
    JoinOn,
    Having,
    Filter,
    From,
    GeneratedColumn,
    Insert,
}
67
/// A `BindContext` that is only visible if the `LATERAL` keyword
/// is provided.
pub struct LateralBindContext {
    /// Whether `context` may currently be seen.
    /// NOTE(review): presumably set by the binder when `LATERAL` is present — confirm there.
    pub is_visible: bool,
    /// The wrapped bind context.
    pub context: BindContext,
}
74
/// For a recursive CTE, we may need to store it in `cte_to_relation` first,
/// and then bind it *step by step*.
///
/// Note: the SQL example below illustrates which binding state we are in at each point while
/// handling a recursive CTE like this.
///
/// ```sql
/// WITH RECURSIVE t(n) AS (
/// # -------------^ => Init
///     VALUES (1)
///   UNION ALL
///     SELECT n + 1 FROM t WHERE n < 100
/// # --------------------^ => BaseResolved (after binding the base term, this relation will be bound to `Relation::BackCteRef`)
/// )
/// SELECT sum(n) FROM t;
/// # -----------------^ => Bound (we know exactly what the entire `RecursiveUnion` looks like, and this relation will be bound to `Relation::Share`)
/// ```
#[derive(Default, Debug, Clone)]
pub enum BindingCteState {
    /// We know nothing about the CTE before resolving the body.
    #[default]
    Init,
    /// The base term has been resolved, so the schema of the CTE is known.
    BaseResolved {
        /// The bound base term.
        base: BoundSetExpr,
    },
    /// We have the whole bound result of the (recursive) CTE.
    Bound {
        /// Either an ordinary bound query or a bound recursive union.
        query: Either<BoundQuery, RecursiveUnion>,
    },

    /// NOTE(review): presumably a CTE defined over the changelog of `table` — this variant is
    /// not documented in this file; confirm against the binder.
    ChangeLog {
        /// The relation this state refers to.
        table: Relation,
    },
}
110
/// The entire `RecursiveUnion` represents a *bound* recursive CTE.
/// Reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
#[derive(Debug, Clone)]
pub struct RecursiveUnion {
    /// Currently this *must* be true,
    /// otherwise binding will fail.
    #[allow(dead_code)]
    pub all: bool,
    /// Left-hand side of the `UNION ALL` operator.
    pub base: Box<BoundSetExpr>,
    /// Right-hand side of the `UNION ALL` operator.
    pub recursive: Box<BoundSetExpr>,
    /// The aligned schema for this union.
    /// It will be the *same* schema as `recursive`'s;
    /// it is kept separately only for better readability.
    pub schema: Schema,
}
128
129impl RewriteExprsRecursive for RecursiveUnion {
130    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
131        // rewrite `base` and `recursive` separately
132        self.base.rewrite_exprs_recursive(rewriter);
133        self.recursive.rewrite_exprs_recursive(rewriter);
134    }
135}
136
/// A CTE that is being bound: its share id, how far binding has progressed, and its alias.
#[derive(Clone, Debug)]
pub struct BindingCte {
    /// Id used to help the planner identify the share plan.
    pub share_id: ShareId,
    /// Current binding state of the CTE.
    pub state: BindingCteState,
    /// Alias of the CTE.
    /// NOTE(review): presumably taken from `WITH <alias> AS (...)` — confirm against the binder.
    pub alias: TableAlias,
}
143
/// Name-resolution state used while binding: the visible columns, how they map to column and
/// table names, natural-join column groups, CTEs, lambda arguments and named windows.
#[derive(Default, Debug, Clone)]
pub struct BindContext {
    /// Columns of all tables.
    pub columns: Vec<ColumnBinding>,
    /// Mapping from column name to indices in `columns`.
    pub indices_of: HashMap<String, Vec<usize>>,
    /// Mapping from table name to the `[begin, end)` range of its columns.
    pub range_of: HashMap<String, (usize, usize)>,
    /// `clause` identifies in what clause we are binding.
    pub clause: Option<Clause>,
    /// The `BindContext`'s data on its column groups.
    pub column_group_context: ColumnGroupContext,
    /// Map the cte's name to its binding state.
    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
    /// Arguments of the current lambda function.
    /// NOTE(review): the tuple is presumably (argument position, argument type) — confirm.
    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
    /// Whether the security invoker is set, currently only used for views.
    /// NOTE(review): the field name suggests `true` *disables* the security invoker, which reads
    /// opposite to the sentence above — confirm the intended polarity.
    pub disable_security_invoker: bool,
    /// Named window definitions from the `WINDOW` clause.
    pub named_windows: HashMap<String, WindowSpec>,
}
166
/// Holds the context for the `BindContext`'s `ColumnGroup`s.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroupContext {
    /// Maps naturally-joined/USING columns (by column index) to their column group id.
    pub mapping: HashMap<usize, u32>,
    /// Maps column group ids to their column group data.
    /// We use a `BTreeMap` to ensure that iteration over the groups is ordered.
    pub groups: BTreeMap<u32, ColumnGroup>,

    /// Id given to the next newly created group; incremented each time a group is created and
    /// advanced by the other side's counter when two contexts are merged.
    next_group_id: u32,
}
178
/// When binding a natural join or a join with USING, a `ColumnGroup` contains the columns with the
/// same name.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices of the columns in the group.
    pub indices: HashSet<usize>,
    /// A non-nullable column is never NULL; when present, references to the group resolve to it.
    /// If `None`, ambiguous references to the column name will be resolved to a `COALESCE(col1,
    /// col2, ..., coln)` over each column in the group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the columns of the group; taken from the left column's field name when the
    /// group is created.
    pub column_name: Option<String>,
}
192
193impl BindContext {
194    pub fn get_column_binding_index(
195        &self,
196        table_name: &Option<String>,
197        column_name: &String,
198    ) -> LiteResult<usize> {
199        match &self.get_column_binding_indices(table_name, column_name)?[..] {
200            [] => unreachable!(),
201            [idx] => Ok(*idx),
202            _ => Err(ErrorCode::InternalError(format!(
203                "Ambiguous column name: {}",
204                column_name
205            ))),
206        }
207    }
208
209    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
210    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
211    /// handled in downstream as a `COALESCE` expression
212    pub fn get_column_binding_indices(
213        &self,
214        table_name: &Option<String>,
215        column_name: &String,
216    ) -> LiteResult<Vec<usize>> {
217        match table_name {
218            Some(table_name) => {
219                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
220                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
221                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
222                    self.get_indices_with_group_id(group_id, column_name)
223                } else {
224                    Ok(vec![
225                        self.get_index_with_table_name(column_name, table_name)?,
226                    ])
227                }
228            }
229            None => self.get_unqualified_indices(column_name),
230        }
231    }
232
233    fn get_indices_with_group_id(
234        &self,
235        group_id: u32,
236        column_name: &String,
237    ) -> LiteResult<Vec<usize>> {
238        let group = self.column_group_context.groups.get(&group_id).unwrap();
239        if let Some(name) = &group.column_name {
240            debug_assert_eq!(name, column_name);
241        }
242        if let Some(non_nullable) = &group.non_nullable_column {
243            Ok(vec![*non_nullable])
244        } else {
245            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
246            let mut indices: Vec<_> = group.indices.iter().copied().collect();
247            indices.sort(); // ensure a deterministic result
248            Ok(indices)
249        }
250    }
251
252    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
253        let columns = self
254            .indices_of
255            .get(column_name)
256            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
257        if columns.len() > 1 {
258            // If there is some group containing the columns and the ambiguous columns are all in
259            // the group
260            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
261                let group = self.column_group_context.groups.get(group_id).unwrap();
262                if columns.iter().all(|idx| group.indices.contains(idx)) {
263                    if let Some(non_nullable) = &group.non_nullable_column {
264                        return Ok(vec![*non_nullable]);
265                    } else {
266                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
267                        return Ok(columns.to_vec());
268                    }
269                }
270            }
271            Err(ErrorCode::InternalError(format!(
272                "Ambiguous column name: {}",
273                column_name
274            )))
275        } else {
276            Ok(columns.to_vec())
277        }
278    }
279
280    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
281    /// the columns as being `non_nullable`
282    pub fn add_natural_columns(
283        &mut self,
284        left: usize,
285        right: usize,
286        non_nullable_column: Option<usize>,
287    ) {
288        match (
289            self.column_group_context.mapping.get(&left).copied(),
290            self.column_group_context.mapping.get(&right).copied(),
291        ) {
292            (None, None) => {
293                let group_id = self.column_group_context.next_group_id;
294                self.column_group_context.next_group_id += 1;
295
296                let group = ColumnGroup {
297                    indices: HashSet::from([left, right]),
298                    non_nullable_column,
299                    column_name: Some(self.columns[left].field.name.clone()),
300                };
301                self.column_group_context.groups.insert(group_id, group);
302                self.column_group_context.mapping.insert(left, group_id);
303                self.column_group_context.mapping.insert(right, group_id);
304            }
305            (Some(group_id), None) => {
306                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
307                group.indices.insert(right);
308                if group.non_nullable_column.is_none() {
309                    group.non_nullable_column = non_nullable_column;
310                }
311                self.column_group_context.mapping.insert(right, group_id);
312            }
313            (None, Some(group_id)) => {
314                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
315                group.indices.insert(left);
316                if group.non_nullable_column.is_none() {
317                    group.non_nullable_column = non_nullable_column;
318                }
319                self.column_group_context.mapping.insert(left, group_id);
320            }
321            (Some(l_group_id), Some(r_group_id)) => {
322                if r_group_id == l_group_id {
323                    return;
324                }
325
326                let r_group = self
327                    .column_group_context
328                    .groups
329                    .remove(&r_group_id)
330                    .unwrap();
331                let l_group = self
332                    .column_group_context
333                    .groups
334                    .get_mut(&l_group_id)
335                    .unwrap();
336
337                for idx in &r_group.indices {
338                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
339                    l_group.indices.insert(*idx);
340                }
341                if l_group.non_nullable_column.is_none() {
342                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
343                        non_nullable_column
344                    } else {
345                        r_group.non_nullable_column
346                    };
347                }
348            }
349        }
350    }
351
352    fn get_index_with_table_name(
353        &self,
354        column_name: &String,
355        table_name: &String,
356    ) -> LiteResult<usize> {
357        let column_indexes = self
358            .indices_of
359            .get(column_name)
360            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
361        match column_indexes
362            .iter()
363            .find(|column_index| self.columns[**column_index].table_name == *table_name)
364        {
365            Some(column_index) => Ok(*column_index),
366            None => Err(ErrorCode::ItemNotFound(format!(
367                "missing FROM-clause entry for table \"{}\"",
368                table_name
369            ))),
370        }
371    }
372
373    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
374    /// adjacent cross-joined tables.
375    pub fn merge_context(&mut self, other: Self) -> Result<()> {
376        let begin = self.columns.len();
377        self.columns.extend(other.columns.into_iter().map(|mut c| {
378            c.index += begin;
379            c
380        }));
381        for (k, v) in other.indices_of {
382            let entry = self.indices_of.entry(k).or_default();
383            entry.extend(v.into_iter().map(|x| x + begin));
384        }
385        for (k, (x, y)) in other.range_of {
386            match self.range_of.entry(k) {
387                Entry::Occupied(e) => {
388                    return Err(ErrorCode::InternalError(format!(
389                        "Duplicated table name while merging adjacent contexts: {}",
390                        e.key()
391                    ))
392                    .into());
393                }
394                Entry::Vacant(entry) => {
395                    entry.insert((begin + x, begin + y));
396                }
397            }
398        }
399        // To merge the column_group_contexts, we just need to offset RHS
400        // with the next_group_id of LHS.
401        let ColumnGroupContext {
402            mapping,
403            groups,
404            next_group_id,
405        } = other.column_group_context;
406
407        let offset = self.column_group_context.next_group_id;
408        for (idx, group_id) in mapping {
409            self.column_group_context
410                .mapping
411                .insert(begin + idx, offset + group_id);
412        }
413        for (group_id, mut group) in groups {
414            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
415            if let Some(col) = &mut group.non_nullable_column {
416                *col += begin;
417            }
418            self.column_group_context
419                .groups
420                .insert(offset + group_id, group);
421        }
422        self.column_group_context.next_group_id += next_group_id;
423
424        // we assume that the clause is contained in the outer-level context
425        Ok(())
426    }
427}
428
429impl BindContext {
430    pub fn new() -> Self {
431        Self::default()
432    }
433}