risingwave_frontend/binder/
bind_context.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
33#[derive(Debug, Clone)]
34pub struct ColumnBinding {
35    pub table_name: String,
36    pub schema_name: Option<String>,
37    /// if the table has table alias, `table_alias` store the original table name
38    pub table_alias: Option<String>,
39    pub index: usize,
40    pub is_hidden: bool,
41    pub field: Field,
42}
43
44impl ColumnBinding {
45    pub fn new(
46        table_name: String,
47        schema_name: Option<String>,
48        table_alias: Option<String>,
49        index: usize,
50        is_hidden: bool,
51        field: Field,
52    ) -> Self {
53        ColumnBinding {
54            table_name,
55            schema_name,
56            table_alias,
57            index,
58            is_hidden,
59            field,
60        }
61    }
62}
63
64#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
65#[display(style = "TITLE CASE")]
66pub enum Clause {
67    Where,
68    Values,
69    GroupBy,
70    JoinOn,
71    Having,
72    Filter,
73    From,
74    GeneratedColumn,
75    Insert,
76}
77
78/// A `BindContext` that is only visible if the `LATERAL` keyword
79/// is provided.
80pub struct LateralBindContext {
81    pub is_visible: bool,
82    pub context: BindContext,
83}
84
85#[derive(Debug, Clone)]
86pub enum BindingCteState {
87    /// We get the whole bound result of the CTE.
88    Bound {
89        query: BoundQuery,
90    },
91
92    ChangeLog {
93        table: Relation,
94    },
95}
96
97#[derive(Clone, Debug)]
98pub struct BindingCte {
99    pub share_id: ShareId,
100    pub state: BindingCteState,
101    pub alias: TableAlias,
102}
103
104#[derive(Default, Debug, Clone)]
105pub struct BindContext {
106    // Columns of all tables.
107    pub columns: Vec<ColumnBinding>,
108    // Mapping column name to indices in `columns`.
109    pub indices_of: HashMap<String, Vec<usize>>,
110    // Mapping (schema name, table name) to [begin, end) of its columns.
111    pub range_of: HashMap<(Option<String>, String), (usize, usize)>,
112    // `clause` identifies in what clause we are binding.
113    pub clause: Option<Clause>,
114    // The `BindContext`'s data on its column groups
115    pub column_group_context: ColumnGroupContext,
116    /// Map the cte's name to its binding state.
117    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
118    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
119    /// Current lambda functions's arguments
120    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
121    /// Whether the security invoker is set, currently only used for views.
122    pub disable_security_invoker: bool,
123    /// Named window definitions from the `WINDOW` clause
124    pub named_windows: HashMap<String, WindowSpec>,
125    /// Bound arguments for the current SQL UDF, if any.
126    // TODO: use enum for named or positional arguments
127    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
128}
129
130/// Holds the context for the `BindContext`'s `ColumnGroup`s.
131#[derive(Default, Debug, Clone)]
132pub struct ColumnGroupContext {
133    // Maps naturally-joined/USING columns to their column group id
134    pub mapping: HashMap<usize, u32>,
135    // Maps column group ids to their column group data
136    // We use a BTreeMap to ensure that iteration over the groups is ordered.
137    pub groups: BTreeMap<u32, ColumnGroup>,
138
139    next_group_id: u32,
140}
141
/// The set of same-named columns produced by binding a natural join or a join with `USING`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices of the columns that belong to the group.
    pub indices: HashSet<usize>,
    /// A column of the group that is never NULL, if there is one.
    /// When this is `None`, an ambiguous reference to the column name is resolved to
    /// `COALESCE(col1, col2, ..., coln)` over every column in the group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the grouped columns; set from the left column when the group is created.
    pub column_name: Option<String>,
}
155
156impl BindContext {
157    pub fn get_column_binding_index(
158        &self,
159        schema_name: &Option<String>,
160        table_name: &Option<String>,
161        column_name: &String,
162    ) -> LiteResult<usize> {
163        match &self.get_column_binding_indices(schema_name, table_name, column_name)?[..] {
164            [] => unreachable!(),
165            [idx] => Ok(*idx),
166            _ => Err(ErrorCode::InternalError(format!(
167                "Ambiguous column name: {}",
168                column_name
169            ))),
170        }
171    }
172
173    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
174    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
175    /// handled in downstream as a `COALESCE` expression
176    pub fn get_column_binding_indices(
177        &self,
178        schema_name: &Option<String>,
179        table_name: &Option<String>,
180        column_name: &String,
181    ) -> LiteResult<Vec<usize>> {
182        match table_name {
183            Some(table_name) => {
184                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
185                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
186                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
187                    self.get_indices_with_group_id(group_id, column_name)
188                } else {
189                    Ok(vec![self.get_index_with_table_name(
190                        column_name,
191                        table_name,
192                        schema_name,
193                    )?])
194                }
195            }
196            None => self.get_unqualified_indices(column_name),
197        }
198    }
199
200    pub fn get_table_alias(
201        &self,
202        schema_name: &String,
203        table_name: &String,
204        column_name: &String,
205    ) -> LiteResult<Option<usize>> {
206        let column_indexes = self
207            .indices_of
208            .get(column_name)
209            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
210        for index in column_indexes {
211            let column = &self.columns[*index];
212            if let (Some(schema), Some(table_alias)) = (&column.schema_name, &column.table_alias)
213                && schema == schema_name
214                && table_alias == table_name
215            {
216                return Ok(Some(*index));
217            }
218        }
219        Ok(None)
220    }
221
222    fn get_indices_with_group_id(
223        &self,
224        group_id: u32,
225        column_name: &String,
226    ) -> LiteResult<Vec<usize>> {
227        let group = self.column_group_context.groups.get(&group_id).unwrap();
228        if let Some(name) = &group.column_name {
229            debug_assert_eq!(name, column_name);
230        }
231        if let Some(non_nullable) = &group.non_nullable_column {
232            Ok(vec![*non_nullable])
233        } else {
234            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
235            let mut indices: Vec<_> = group.indices.iter().copied().collect();
236            indices.sort(); // ensure a deterministic result
237            Ok(indices)
238        }
239    }
240
241    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
242        let columns = self
243            .indices_of
244            .get(column_name)
245            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
246        if columns.len() > 1 {
247            // If there is some group containing the columns and the ambiguous columns are all in
248            // the group
249            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
250                let group = self.column_group_context.groups.get(group_id).unwrap();
251                if columns.iter().all(|idx| group.indices.contains(idx)) {
252                    if let Some(non_nullable) = &group.non_nullable_column {
253                        return Ok(vec![*non_nullable]);
254                    } else {
255                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
256                        return Ok(columns.clone());
257                    }
258                }
259            }
260            Err(ErrorCode::InternalError(format!(
261                "Ambiguous column name: {}",
262                column_name
263            )))
264        } else {
265            Ok(columns.clone())
266        }
267    }
268
269    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
270    /// the columns as being `non_nullable`
271    pub fn add_natural_columns(
272        &mut self,
273        left: usize,
274        right: usize,
275        non_nullable_column: Option<usize>,
276    ) {
277        match (
278            self.column_group_context.mapping.get(&left).copied(),
279            self.column_group_context.mapping.get(&right).copied(),
280        ) {
281            (None, None) => {
282                let group_id = self.column_group_context.next_group_id;
283                self.column_group_context.next_group_id += 1;
284
285                let group = ColumnGroup {
286                    indices: HashSet::from([left, right]),
287                    non_nullable_column,
288                    column_name: Some(self.columns[left].field.name.clone()),
289                };
290                self.column_group_context.groups.insert(group_id, group);
291                self.column_group_context.mapping.insert(left, group_id);
292                self.column_group_context.mapping.insert(right, group_id);
293            }
294            (Some(group_id), None) => {
295                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
296                group.indices.insert(right);
297                if group.non_nullable_column.is_none() {
298                    group.non_nullable_column = non_nullable_column;
299                }
300                self.column_group_context.mapping.insert(right, group_id);
301            }
302            (None, Some(group_id)) => {
303                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
304                group.indices.insert(left);
305                if group.non_nullable_column.is_none() {
306                    group.non_nullable_column = non_nullable_column;
307                }
308                self.column_group_context.mapping.insert(left, group_id);
309            }
310            (Some(l_group_id), Some(r_group_id)) => {
311                if r_group_id == l_group_id {
312                    return;
313                }
314
315                let r_group = self
316                    .column_group_context
317                    .groups
318                    .remove(&r_group_id)
319                    .unwrap();
320                let l_group = self
321                    .column_group_context
322                    .groups
323                    .get_mut(&l_group_id)
324                    .unwrap();
325
326                for idx in &r_group.indices {
327                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
328                    l_group.indices.insert(*idx);
329                }
330                if l_group.non_nullable_column.is_none() {
331                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
332                        non_nullable_column
333                    } else {
334                        r_group.non_nullable_column
335                    };
336                }
337            }
338        }
339    }
340
341    /// Resolve a qualified column reference (`table.column` or `schema.table.column`)
342    /// to one concrete column index in the current bind scope.
343    ///
344    /// This function intentionally uses a two-step strategy:
345    /// 1. Resolve relation first (`resolve_relation_range`), including ambiguity checks.
346    /// 2. Resolve column inside that unique relation range (`resolve_column_in_range`).
347    fn get_index_with_table_name(
348        &self,
349        column_name: &String,
350        table_name: &String,
351        schema_name: &Option<String>,
352    ) -> LiteResult<usize> {
353        let chosen_range = self.resolve_relation_range(table_name, schema_name)?;
354        self.resolve_column_in_range(column_name, table_name, chosen_range)
355    }
356
357    /// Resolve the relation range for a `table_name` in the current bind scope.
358    ///
359    /// Alias matches are prioritized over base relation name matches.
360    fn resolve_relation_range(
361        &self,
362        table_name: &String,
363        schema_name: &Option<String>,
364    ) -> LiteResult<(usize, usize)> {
365        let mut alias_ranges = Vec::new();
366        let mut base_ranges = Vec::new();
367
368        for ((_, _), range) in self.range_of.iter().filter(|((rel_schema, rel_name), _)| {
369            rel_name == table_name
370                && schema_name
371                    .as_deref()
372                    .is_none_or(|s| rel_schema.as_deref() == Some(s))
373        }) {
374            match self.columns[range.0].table_alias {
375                Some(_) => alias_ranges.push(*range),
376                None => base_ranges.push(*range),
377            }
378        }
379
380        let chosen_ranges = if alias_ranges.is_empty() {
381            base_ranges
382        } else {
383            alias_ranges
384        };
385
386        match chosen_ranges.as_slice() {
387            [] => Err(ErrorCode::ItemNotFound(format!(
388                "missing FROM-clause entry for table \"{}\"",
389                table_name
390            ))),
391            [range] => Ok(*range),
392            _ => Err(ErrorCode::InvalidReference(format!(
393                "table reference \"{}\" is ambiguous",
394                table_name
395            ))),
396        }
397    }
398
399    /// Resolve `column_name` inside a pre-resolved relation range.
400    fn resolve_column_in_range(
401        &self,
402        column_name: &String,
403        table_name: &String,
404        range: (usize, usize),
405    ) -> LiteResult<usize> {
406        let idxs = self
407            .indices_of
408            .get(column_name)
409            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
410
411        let matched: Vec<_> = idxs
412            .iter()
413            .copied()
414            .filter(|index| (range.0..range.1).contains(index))
415            .collect();
416
417        match matched.as_slice() {
418            [] => Err(ErrorCode::ItemNotFound(format!(
419                "missing FROM-clause entry for table \"{}\"",
420                table_name
421            ))),
422            [column_index] => Ok(*column_index),
423            _ => Err(ErrorCode::InvalidReference(format!(
424                "table reference \"{}\" is ambiguous",
425                table_name
426            ))),
427        }
428    }
429
430    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
431    /// adjacent cross-joined tables.
432    pub fn merge_context(&mut self, other: Self) -> Result<()> {
433        let begin = self.columns.len();
434        self.columns.extend(other.columns.into_iter().map(|mut c| {
435            c.index += begin;
436            c
437        }));
438        for (k, v) in other.indices_of {
439            let entry = self.indices_of.entry(k).or_default();
440            entry.extend(v.into_iter().map(|x| x + begin));
441        }
442        for (k, (x, y)) in other.range_of {
443            match self.range_of.entry(k) {
444                Entry::Occupied(e) => {
445                    return Err(ErrorCode::InternalError(format!(
446                        "Duplicated table name while merging adjacent contexts: {}",
447                        e.key().1
448                    ))
449                    .into());
450                }
451                Entry::Vacant(entry) => {
452                    entry.insert((begin + x, begin + y));
453                }
454            }
455        }
456        // To merge the column_group_contexts, we just need to offset RHS
457        // with the next_group_id of LHS.
458        let ColumnGroupContext {
459            mapping,
460            groups,
461            next_group_id,
462        } = other.column_group_context;
463
464        let offset = self.column_group_context.next_group_id;
465        for (idx, group_id) in mapping {
466            self.column_group_context
467                .mapping
468                .insert(begin + idx, offset + group_id);
469        }
470        for (group_id, mut group) in groups {
471            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
472            if let Some(col) = &mut group.non_nullable_column {
473                *col += begin;
474            }
475            self.column_group_context
476                .groups
477                .insert(offset + group_id, group);
478        }
479        self.column_group_context.next_group_id += next_group_id;
480
481        // we assume that the clause is contained in the outer-level context
482        Ok(())
483    }
484}
485
486impl BindContext {
487    pub fn new() -> Self {
488        Self::default()
489    }
490}