risingwave_frontend/binder/
bind_context.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
33#[derive(Debug, Clone)]
34pub struct ColumnBinding {
35    pub table_name: String,
36    pub index: usize,
37    pub is_hidden: bool,
38    pub field: Field,
39}
40
41impl ColumnBinding {
42    pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
43        ColumnBinding {
44            table_name,
45            index,
46            is_hidden,
47            field,
48        }
49    }
50}
51
52#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
53#[display(style = "TITLE CASE")]
54pub enum Clause {
55    Where,
56    Values,
57    GroupBy,
58    JoinOn,
59    Having,
60    Filter,
61    From,
62    GeneratedColumn,
63    Insert,
64}
65
66/// A `BindContext` that is only visible if the `LATERAL` keyword
67/// is provided.
68pub struct LateralBindContext {
69    pub is_visible: bool,
70    pub context: BindContext,
71}
72
73#[derive(Debug, Clone)]
74pub enum BindingCteState {
75    /// We get the whole bound result of the CTE.
76    Bound {
77        query: BoundQuery,
78    },
79
80    ChangeLog {
81        table: Relation,
82    },
83}
84
85#[derive(Clone, Debug)]
86pub struct BindingCte {
87    pub share_id: ShareId,
88    pub state: BindingCteState,
89    pub alias: TableAlias,
90}
91
92#[derive(Default, Debug, Clone)]
93pub struct BindContext {
94    // Columns of all tables.
95    pub columns: Vec<ColumnBinding>,
96    // Mapping column name to indices in `columns`.
97    pub indices_of: HashMap<String, Vec<usize>>,
98    // Mapping table name to [begin, end) of its columns.
99    pub range_of: HashMap<String, (usize, usize)>,
100    // `clause` identifies in what clause we are binding.
101    pub clause: Option<Clause>,
102    // The `BindContext`'s data on its column groups
103    pub column_group_context: ColumnGroupContext,
104    /// Map the cte's name to its binding state.
105    /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
106    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
107    /// Current lambda functions's arguments
108    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
109    /// Whether the security invoker is set, currently only used for views.
110    pub disable_security_invoker: bool,
111    /// Named window definitions from the `WINDOW` clause
112    pub named_windows: HashMap<String, WindowSpec>,
113    /// Bound arguments for the current SQL UDF, if any.
114    // TODO: use enum for named or positional arguments
115    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
116}
117
118/// Holds the context for the `BindContext`'s `ColumnGroup`s.
119#[derive(Default, Debug, Clone)]
120pub struct ColumnGroupContext {
121    // Maps naturally-joined/USING columns to their column group id
122    pub mapping: HashMap<usize, u32>,
123    // Maps column group ids to their column group data
124    // We use a BTreeMap to ensure that iteration over the groups is ordered.
125    pub groups: BTreeMap<u32, ColumnGroup>,
126
127    next_group_id: u32,
128}
129
/// The set of same-named columns merged while binding a natural join or a join with `USING`.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Indices (into `BindContext::columns`) of the member columns.
    pub indices: HashSet<usize>,
    /// Index of a member column that is never NULL, if one is known.
    ///
    /// When `None`, an ambiguous reference to the column name resolves to
    /// `COALESCE(col1, col2, ..., coln)` over every member of the group.
    pub non_nullable_column: Option<usize>,

    /// The name shared by the grouped columns.
    pub column_name: Option<String>,
}
143
144impl BindContext {
145    pub fn get_column_binding_index(
146        &self,
147        table_name: &Option<String>,
148        column_name: &String,
149    ) -> LiteResult<usize> {
150        match &self.get_column_binding_indices(table_name, column_name)?[..] {
151            [] => unreachable!(),
152            [idx] => Ok(*idx),
153            _ => Err(ErrorCode::InternalError(format!(
154                "Ambiguous column name: {}",
155                column_name
156            ))),
157        }
158    }
159
160    /// If return Vec has len > 1, it means we have an unqualified reference to a column which has
161    /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
162    /// handled in downstream as a `COALESCE` expression
163    pub fn get_column_binding_indices(
164        &self,
165        table_name: &Option<String>,
166        column_name: &String,
167    ) -> LiteResult<Vec<usize>> {
168        match table_name {
169            Some(table_name) => {
170                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
171                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
172                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
173                    self.get_indices_with_group_id(group_id, column_name)
174                } else {
175                    Ok(vec![
176                        self.get_index_with_table_name(column_name, table_name)?,
177                    ])
178                }
179            }
180            None => self.get_unqualified_indices(column_name),
181        }
182    }
183
184    fn get_indices_with_group_id(
185        &self,
186        group_id: u32,
187        column_name: &String,
188    ) -> LiteResult<Vec<usize>> {
189        let group = self.column_group_context.groups.get(&group_id).unwrap();
190        if let Some(name) = &group.column_name {
191            debug_assert_eq!(name, column_name);
192        }
193        if let Some(non_nullable) = &group.non_nullable_column {
194            Ok(vec![*non_nullable])
195        } else {
196            // These will be converted to a `COALESCE(col1, col2, ..., coln)`
197            let mut indices: Vec<_> = group.indices.iter().copied().collect();
198            indices.sort(); // ensure a deterministic result
199            Ok(indices)
200        }
201    }
202
203    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
204        let columns = self
205            .indices_of
206            .get(column_name)
207            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
208        if columns.len() > 1 {
209            // If there is some group containing the columns and the ambiguous columns are all in
210            // the group
211            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
212                let group = self.column_group_context.groups.get(group_id).unwrap();
213                if columns.iter().all(|idx| group.indices.contains(idx)) {
214                    if let Some(non_nullable) = &group.non_nullable_column {
215                        return Ok(vec![*non_nullable]);
216                    } else {
217                        // These will be converted to a `COALESCE(col1, col2, ..., coln)`
218                        return Ok(columns.clone());
219                    }
220                }
221            }
222            Err(ErrorCode::InternalError(format!(
223                "Ambiguous column name: {}",
224                column_name
225            )))
226        } else {
227            Ok(columns.clone())
228        }
229    }
230
231    /// Identifies two columns as being in the same group. Additionally, possibly provides one of
232    /// the columns as being `non_nullable`
233    pub fn add_natural_columns(
234        &mut self,
235        left: usize,
236        right: usize,
237        non_nullable_column: Option<usize>,
238    ) {
239        match (
240            self.column_group_context.mapping.get(&left).copied(),
241            self.column_group_context.mapping.get(&right).copied(),
242        ) {
243            (None, None) => {
244                let group_id = self.column_group_context.next_group_id;
245                self.column_group_context.next_group_id += 1;
246
247                let group = ColumnGroup {
248                    indices: HashSet::from([left, right]),
249                    non_nullable_column,
250                    column_name: Some(self.columns[left].field.name.clone()),
251                };
252                self.column_group_context.groups.insert(group_id, group);
253                self.column_group_context.mapping.insert(left, group_id);
254                self.column_group_context.mapping.insert(right, group_id);
255            }
256            (Some(group_id), None) => {
257                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
258                group.indices.insert(right);
259                if group.non_nullable_column.is_none() {
260                    group.non_nullable_column = non_nullable_column;
261                }
262                self.column_group_context.mapping.insert(right, group_id);
263            }
264            (None, Some(group_id)) => {
265                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
266                group.indices.insert(left);
267                if group.non_nullable_column.is_none() {
268                    group.non_nullable_column = non_nullable_column;
269                }
270                self.column_group_context.mapping.insert(left, group_id);
271            }
272            (Some(l_group_id), Some(r_group_id)) => {
273                if r_group_id == l_group_id {
274                    return;
275                }
276
277                let r_group = self
278                    .column_group_context
279                    .groups
280                    .remove(&r_group_id)
281                    .unwrap();
282                let l_group = self
283                    .column_group_context
284                    .groups
285                    .get_mut(&l_group_id)
286                    .unwrap();
287
288                for idx in &r_group.indices {
289                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
290                    l_group.indices.insert(*idx);
291                }
292                if l_group.non_nullable_column.is_none() {
293                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
294                        non_nullable_column
295                    } else {
296                        r_group.non_nullable_column
297                    };
298                }
299            }
300        }
301    }
302
303    fn get_index_with_table_name(
304        &self,
305        column_name: &String,
306        table_name: &String,
307    ) -> LiteResult<usize> {
308        let column_indexes = self
309            .indices_of
310            .get(column_name)
311            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
312        match column_indexes
313            .iter()
314            .find(|column_index| self.columns[**column_index].table_name == *table_name)
315        {
316            Some(column_index) => Ok(*column_index),
317            None => Err(ErrorCode::ItemNotFound(format!(
318                "missing FROM-clause entry for table \"{}\"",
319                table_name
320            ))),
321        }
322    }
323
324    /// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
325    /// adjacent cross-joined tables.
326    pub fn merge_context(&mut self, other: Self) -> Result<()> {
327        let begin = self.columns.len();
328        self.columns.extend(other.columns.into_iter().map(|mut c| {
329            c.index += begin;
330            c
331        }));
332        for (k, v) in other.indices_of {
333            let entry = self.indices_of.entry(k).or_default();
334            entry.extend(v.into_iter().map(|x| x + begin));
335        }
336        for (k, (x, y)) in other.range_of {
337            match self.range_of.entry(k) {
338                Entry::Occupied(e) => {
339                    return Err(ErrorCode::InternalError(format!(
340                        "Duplicated table name while merging adjacent contexts: {}",
341                        e.key()
342                    ))
343                    .into());
344                }
345                Entry::Vacant(entry) => {
346                    entry.insert((begin + x, begin + y));
347                }
348            }
349        }
350        // To merge the column_group_contexts, we just need to offset RHS
351        // with the next_group_id of LHS.
352        let ColumnGroupContext {
353            mapping,
354            groups,
355            next_group_id,
356        } = other.column_group_context;
357
358        let offset = self.column_group_context.next_group_id;
359        for (idx, group_id) in mapping {
360            self.column_group_context
361                .mapping
362                .insert(begin + idx, offset + group_id);
363        }
364        for (group_id, mut group) in groups {
365            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
366            if let Some(col) = &mut group.non_nullable_column {
367                *col += begin;
368            }
369            self.column_group_context
370                .groups
371                .insert(offset + group_id, group);
372        }
373        self.column_group_context.next_group_id += next_group_id;
374
375        // we assume that the clause is contained in the outer-level context
376        Ok(())
377    }
378}
379
380impl BindContext {
381    pub fn new() -> Self {
382        Self::default()
383    }
384}