Skip to main content

risingwave_frontend/binder/
bind_context.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
/// `Result` alias whose error is a bare `ErrorCode`. The lookup helpers in this file return it,
/// whereas the `Result` imported from `crate::error` is used where an `ErrorCode` is converted
/// with `.into()` before being returned.
type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
33#[derive(Debug, Clone)]
34pub struct ColumnBinding {
35    pub table_name: String,
36    pub schema_name: Option<String>,
37    /// if the table has table alias, `table_alias` store the original table name
38    pub table_alias: Option<String>,
39    pub index: usize,
40    pub is_hidden: bool,
41    pub field: Field,
42}
43
44impl ColumnBinding {
45    pub fn new(
46        table_name: String,
47        schema_name: Option<String>,
48        table_alias: Option<String>,
49        index: usize,
50        is_hidden: bool,
51        field: Field,
52    ) -> Self {
53        ColumnBinding {
54            table_name,
55            schema_name,
56            table_alias,
57            index,
58            is_hidden,
59            field,
60        }
61    }
62}
63
64#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
65#[display(style = "TITLE CASE")]
66pub enum Clause {
67    Where,
68    Values,
69    GroupBy,
70    JoinOn,
71    Having,
72    Filter,
73    From,
74    GeneratedColumn,
75    Insert,
76}
77
/// A `BindContext` that is only visible if the `LATERAL` keyword
/// is provided.
pub struct LateralBindContext {
    /// Whether `context` is visible.
    pub is_visible: bool,
    /// The wrapped bind context.
    pub context: BindContext,
}
84
85#[derive(Debug, Clone)]
86pub enum BindingCteState {
87    /// We get the whole bound result of the CTE.
88    Bound {
89        query: BoundQuery,
90    },
91
92    ChangeLog {
93        table: Relation,
94    },
95}
96
97#[derive(Clone, Debug)]
98pub struct BindingCte {
99    pub share_id: ShareId,
100    pub state: BindingCteState,
101    pub alias: TableAlias,
102}
103
104#[derive(Default, Debug, Clone)]
105pub struct BindContext {
106    // Columns of all tables.
107    pub columns: Vec<ColumnBinding>,
108    // Mapping column name to indices in `columns`.
109    pub indices_of: HashMap<String, Vec<usize>>,
110    // Mapping (schema name, table name) to [begin, end) of its columns.
111    pub range_of: HashMap<(Option<String>, String), (usize, usize)>,
112    // `clause` identifies in what clause we are binding.
113    pub clause: Option<Clause>,
114    // The `BindContext`'s data on its column groups
115    pub column_group_context: ColumnGroupContext,
116    /// Map the cte's name to its binding state.
117    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
118    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
119    /// Exposed relation names of CTE references in the current FROM scope.
120    pub cte_relation_names: HashSet<String>,
121    /// Current lambda functions's arguments
122    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
123    /// Whether the security invoker is set, currently only used for views.
124    pub disable_security_invoker: bool,
125    /// Named window definitions from the `WINDOW` clause
126    pub named_windows: HashMap<String, WindowSpec>,
127    /// Bound arguments for the current SQL UDF, if any.
128    // TODO: use enum for named or positional arguments
129    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
130}
131
132/// Holds the context for the `BindContext`'s `ColumnGroup`s.
133#[derive(Default, Debug, Clone)]
134pub struct ColumnGroupContext {
135    // Maps naturally-joined/USING columns to their column group id
136    pub mapping: HashMap<usize, u32>,
137    // Maps column group ids to their column group data
138    // We use a BTreeMap to ensure that iteration over the groups is ordered.
139    pub groups: BTreeMap<u32, ColumnGroup>,
140
141    next_group_id: u32,
142}
143
/// The set of same-named columns produced when binding a natural join or a join with `USING`.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Indices of the columns in the group.
    pub indices: HashSet<usize>,
    /// A non-nullable column is never NULL.
    /// If `None`, ambiguous references to the column name will be resolved to a
    /// `COALESCE(col1, col2, ..., coln)` over each column in the group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the grouped columns; set from the left column's field name when the group
    /// is created by `BindContext::add_natural_columns`.
    pub column_name: Option<String>,
}
157
158impl BindContext {
159    pub fn get_column_binding_index(
160        &self,
161        schema_name: &Option<String>,
162        table_name: &Option<String>,
163        column_name: &String,
164    ) -> LiteResult<usize> {
165        match &self.get_column_binding_indices(schema_name, table_name, column_name)?[..] {
166            [] => unreachable!(),
167            [idx] => Ok(*idx),
168            _ => Err(ErrorCode::InternalError(format!(
169                "Ambiguous column name: {}",
170                column_name
171            ))),
172        }
173    }
174
175    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
176    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
177    /// handled in downstream as a `COALESCE` expression
178    pub fn get_column_binding_indices(
179        &self,
180        schema_name: &Option<String>,
181        table_name: &Option<String>,
182        column_name: &String,
183    ) -> LiteResult<Vec<usize>> {
184        match table_name {
185            Some(table_name) => {
186                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
187                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
188                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
189                    self.get_indices_with_group_id(group_id, column_name)
190                } else {
191                    Ok(vec![self.get_index_with_table_name(
192                        column_name,
193                        table_name,
194                        schema_name,
195                    )?])
196                }
197            }
198            None => self.get_unqualified_indices(column_name),
199        }
200    }
201
202    pub fn get_table_alias(
203        &self,
204        schema_name: &String,
205        table_name: &String,
206        column_name: &String,
207    ) -> LiteResult<Option<usize>> {
208        let column_indexes = self
209            .indices_of
210            .get(column_name)
211            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
212        for index in column_indexes {
213            let column = &self.columns[*index];
214            if let (Some(schema), Some(table_alias)) = (&column.schema_name, &column.table_alias)
215                && schema == schema_name
216                && table_alias == table_name
217            {
218                return Ok(Some(*index));
219            }
220        }
221        Ok(None)
222    }
223
224    fn get_indices_with_group_id(
225        &self,
226        group_id: u32,
227        column_name: &String,
228    ) -> LiteResult<Vec<usize>> {
229        let group = self.column_group_context.groups.get(&group_id).unwrap();
230        if let Some(name) = &group.column_name {
231            debug_assert_eq!(name, column_name);
232        }
233        if let Some(non_nullable) = &group.non_nullable_column {
234            Ok(vec![*non_nullable])
235        } else {
236            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
237            let mut indices: Vec<_> = group.indices.iter().copied().collect();
238            indices.sort(); // ensure a deterministic result
239            Ok(indices)
240        }
241    }
242
243    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
244        let columns = self
245            .indices_of
246            .get(column_name)
247            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
248        if columns.len() > 1 {
249            // If there is some group containing the columns and the ambiguous columns are all in
250            // the group
251            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
252                let group = self.column_group_context.groups.get(group_id).unwrap();
253                if columns.iter().all(|idx| group.indices.contains(idx)) {
254                    if let Some(non_nullable) = &group.non_nullable_column {
255                        return Ok(vec![*non_nullable]);
256                    } else {
257                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
258                        return Ok(columns.clone());
259                    }
260                }
261            }
262            Err(ErrorCode::InternalError(format!(
263                "Ambiguous column name: {}",
264                column_name
265            )))
266        } else {
267            Ok(columns.clone())
268        }
269    }
270
271    pub fn check_catalog_name(&self, exposed_relation_name: &str) -> Result<()> {
272        if self.cte_relation_names.contains(exposed_relation_name) {
273            return Err(ErrorCode::DuplicateRelationName(format!(
274                "table name \"{}\" specified more than once",
275                exposed_relation_name
276            ))
277            .into());
278        }
279        Ok(())
280    }
281
282    pub fn add_cte_name(&mut self, exposed_relation_name: String) {
283        self.cte_relation_names.insert(exposed_relation_name);
284    }
285
286    fn has_relation_name(&self, relation_name: &str) -> bool {
287        self.range_of
288            .keys()
289            .any(|(_, existing_name)| existing_name == relation_name)
290    }
291
292    pub fn check_relation_name_conflict(&self, relation_name: &str) -> Result<()> {
293        if self.has_relation_name(relation_name) {
294            return Err(ErrorCode::DuplicateRelationName(format!(
295                "table name \"{}\" specified more than once",
296                relation_name
297            ))
298            .into());
299        }
300        Ok(())
301    }
302
303    fn check_cte_relation_name_conflict(&self, other: &Self) -> Result<()> {
304        // `self` may bind a CTE in a separate context, while `other` already has a catalog
305        // relation with the same exposed name.
306        for cte_relation_name in &self.cte_relation_names {
307            other.check_relation_name_conflict(cte_relation_name)?;
308        }
309        // `other` may bind a CTE in a separate context, while `self` already has a catalog
310        // relation with the same exposed name.
311        for cte_relation_name in &other.cte_relation_names {
312            self.check_relation_name_conflict(cte_relation_name)?;
313        }
314        Ok(())
315    }
316
317    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
318    /// the columns as being `non_nullable`
319    pub fn add_natural_columns(
320        &mut self,
321        left: usize,
322        right: usize,
323        non_nullable_column: Option<usize>,
324    ) {
325        match (
326            self.column_group_context.mapping.get(&left).copied(),
327            self.column_group_context.mapping.get(&right).copied(),
328        ) {
329            (None, None) => {
330                let group_id = self.column_group_context.next_group_id;
331                self.column_group_context.next_group_id += 1;
332
333                let group = ColumnGroup {
334                    indices: HashSet::from([left, right]),
335                    non_nullable_column,
336                    column_name: Some(self.columns[left].field.name.clone()),
337                };
338                self.column_group_context.groups.insert(group_id, group);
339                self.column_group_context.mapping.insert(left, group_id);
340                self.column_group_context.mapping.insert(right, group_id);
341            }
342            (Some(group_id), None) => {
343                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
344                group.indices.insert(right);
345                if group.non_nullable_column.is_none() {
346                    group.non_nullable_column = non_nullable_column;
347                }
348                self.column_group_context.mapping.insert(right, group_id);
349            }
350            (None, Some(group_id)) => {
351                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
352                group.indices.insert(left);
353                if group.non_nullable_column.is_none() {
354                    group.non_nullable_column = non_nullable_column;
355                }
356                self.column_group_context.mapping.insert(left, group_id);
357            }
358            (Some(l_group_id), Some(r_group_id)) => {
359                if r_group_id == l_group_id {
360                    return;
361                }
362
363                let r_group = self
364                    .column_group_context
365                    .groups
366                    .remove(&r_group_id)
367                    .unwrap();
368                let l_group = self
369                    .column_group_context
370                    .groups
371                    .get_mut(&l_group_id)
372                    .unwrap();
373
374                for idx in &r_group.indices {
375                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
376                    l_group.indices.insert(*idx);
377                }
378                if l_group.non_nullable_column.is_none() {
379                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
380                        non_nullable_column
381                    } else {
382                        r_group.non_nullable_column
383                    };
384                }
385            }
386        }
387    }
388
389    /// Resolve a qualified column reference (`table.column` or `schema.table.column`)
390    /// to one concrete column index in the current bind scope.
391    ///
392    /// This function intentionally uses a two-step strategy:
393    /// 1. Resolve relation first (`resolve_relation_range`), including ambiguity checks.
394    /// 2. Resolve column inside that unique relation range (`resolve_column_in_range`).
395    fn get_index_with_table_name(
396        &self,
397        column_name: &String,
398        table_name: &String,
399        schema_name: &Option<String>,
400    ) -> LiteResult<usize> {
401        let chosen_range = self.resolve_relation_range(table_name, schema_name)?;
402        self.resolve_column_in_range(column_name, table_name, chosen_range)
403    }
404
405    /// Resolve the relation range for a `table_name` in the current bind scope.
406    ///
407    /// Alias matches are prioritized over base relation name matches.
408    fn resolve_relation_range(
409        &self,
410        table_name: &String,
411        schema_name: &Option<String>,
412    ) -> LiteResult<(usize, usize)> {
413        let mut alias_ranges = Vec::new();
414        let mut base_ranges = Vec::new();
415
416        for ((_, _), range) in self.range_of.iter().filter(|((rel_schema, rel_name), _)| {
417            rel_name == table_name
418                && schema_name
419                    .as_deref()
420                    .is_none_or(|s| rel_schema.as_deref() == Some(s))
421        }) {
422            match self.columns[range.0].table_alias {
423                Some(_) => alias_ranges.push(*range),
424                None => base_ranges.push(*range),
425            }
426        }
427
428        let chosen_ranges = if alias_ranges.is_empty() {
429            base_ranges
430        } else {
431            alias_ranges
432        };
433
434        match chosen_ranges.as_slice() {
435            [] => Err(ErrorCode::ItemNotFound(format!(
436                "missing FROM-clause entry for table \"{}\"",
437                table_name
438            ))),
439            [range] => Ok(*range),
440            _ => Err(ErrorCode::InvalidReference(format!(
441                "table reference \"{}\" is ambiguous",
442                table_name
443            ))),
444        }
445    }
446
447    /// Resolve `column_name` inside a pre-resolved relation range.
448    fn resolve_column_in_range(
449        &self,
450        column_name: &String,
451        table_name: &String,
452        range: (usize, usize),
453    ) -> LiteResult<usize> {
454        let idxs = self
455            .indices_of
456            .get(column_name)
457            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
458
459        let matched: Vec<_> = idxs
460            .iter()
461            .copied()
462            .filter(|index| (range.0..range.1).contains(index))
463            .collect();
464
465        match matched.as_slice() {
466            [] => Err(ErrorCode::ItemNotFound(format!(
467                "missing FROM-clause entry for table \"{}\"",
468                table_name
469            ))),
470            [column_index] => Ok(*column_index),
471            _ => Err(ErrorCode::InvalidReference(format!(
472                "column reference \"{}\" is ambiguous",
473                column_name
474            ))),
475        }
476    }
477
478    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
479    /// adjacent cross-joined tables.
480    pub fn merge_context(&mut self, other: Self) -> Result<()> {
481        self.check_cte_relation_name_conflict(&other)?;
482
483        let begin = self.columns.len();
484        self.columns.extend(other.columns.into_iter().map(|mut c| {
485            c.index += begin;
486            c
487        }));
488        for (k, v) in other.indices_of {
489            let entry = self.indices_of.entry(k).or_default();
490            entry.extend(v.into_iter().map(|x| x + begin));
491        }
492        for (k, (x, y)) in other.range_of {
493            match self.range_of.entry(k) {
494                Entry::Occupied(e) => {
495                    return Err(ErrorCode::InternalError(format!(
496                        "Duplicated table name while merging adjacent contexts: {}",
497                        e.key().1
498                    ))
499                    .into());
500                }
501                Entry::Vacant(entry) => {
502                    entry.insert((begin + x, begin + y));
503                }
504            }
505        }
506        self.cte_relation_names.extend(other.cte_relation_names);
507        // To merge the column_group_contexts, we just need to offset RHS
508        // with the next_group_id of LHS.
509        let ColumnGroupContext {
510            mapping,
511            groups,
512            next_group_id,
513        } = other.column_group_context;
514
515        let offset = self.column_group_context.next_group_id;
516        for (idx, group_id) in mapping {
517            self.column_group_context
518                .mapping
519                .insert(begin + idx, offset + group_id);
520        }
521        for (group_id, mut group) in groups {
522            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
523            if let Some(col) = &mut group.non_nullable_column {
524                *col += begin;
525            }
526            self.column_group_context
527                .groups
528                .insert(offset + group_id, group);
529        }
530        self.column_group_context.next_group_id += next_group_id;
531
532        // we assume that the clause is contained in the outer-level context
533        Ok(())
534    }
535}
536
537impl BindContext {
538    pub fn new() -> Self {
539        Self::default()
540    }
541}