risingwave_frontend/binder/
bind_context.rs1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28use crate::expr::ExprImpl;
29
30type LiteResult<T> = std::result::Result<T, ErrorCode>;
31
32use super::BoundSetExpr;
33use super::statement::RewriteExprsRecursive;
34use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
35
36#[derive(Debug, Clone)]
37pub struct ColumnBinding {
38    pub table_name: String,
39    pub index: usize,
40    pub is_hidden: bool,
41    pub field: Field,
42}
43
44impl ColumnBinding {
45    pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
46        ColumnBinding {
47            table_name,
48            index,
49            is_hidden,
50            field,
51        }
52    }
53}
54
55#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
56#[display(style = "TITLE CASE")]
57pub enum Clause {
58    Where,
59    Values,
60    GroupBy,
61    JoinOn,
62    Having,
63    Filter,
64    From,
65    GeneratedColumn,
66    Insert,
67}
68
69pub struct LateralBindContext {
72    pub is_visible: bool,
73    pub context: BindContext,
74}
75
76#[derive(Default, Debug, Clone)]
94pub enum BindingCteState {
95    #[default]
97    Init,
98    BaseResolved {
100        base: BoundSetExpr,
101    },
102    Bound {
104        query: Either<BoundQuery, RecursiveUnion>,
105    },
106
107    ChangeLog {
108        table: Relation,
109    },
110}
111
112#[derive(Debug, Clone)]
115pub struct RecursiveUnion {
116    #[allow(dead_code)]
119    pub all: bool,
120    pub base: Box<BoundSetExpr>,
122    pub recursive: Box<BoundSetExpr>,
124    pub schema: Schema,
128}
129
130impl RewriteExprsRecursive for RecursiveUnion {
131    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
132        self.base.rewrite_exprs_recursive(rewriter);
134        self.recursive.rewrite_exprs_recursive(rewriter);
135    }
136}
137
138#[derive(Clone, Debug)]
139pub struct BindingCte {
140    pub share_id: ShareId,
141    pub state: BindingCteState,
142    pub alias: TableAlias,
143}
144
145#[derive(Default, Debug, Clone)]
146pub struct BindContext {
147    pub columns: Vec<ColumnBinding>,
149    pub indices_of: HashMap<String, Vec<usize>>,
151    pub range_of: HashMap<String, (usize, usize)>,
153    pub clause: Option<Clause>,
155    pub column_group_context: ColumnGroupContext,
157    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
160    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
162    pub disable_security_invoker: bool,
164    pub named_windows: HashMap<String, WindowSpec>,
166    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
169}
170
171#[derive(Default, Debug, Clone)]
173pub struct ColumnGroupContext {
174    pub mapping: HashMap<usize, u32>,
176    pub groups: BTreeMap<u32, ColumnGroup>,
179
180    next_group_id: u32,
181}
182
/// A set of same-named columns that a reference to the name may resolve to as a unit.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Indices (into `BindContext::columns`) of the member columns.
    pub indices: HashSet<usize>,
    /// When set, the single member that references to the group resolve to; when unset,
    /// references resolve to all members.
    pub non_nullable_column: Option<usize>,
    /// Name shared by the members, copied from the first column when the group is created.
    pub column_name: Option<String>,
}
196
197impl BindContext {
198    pub fn get_column_binding_index(
199        &self,
200        table_name: &Option<String>,
201        column_name: &String,
202    ) -> LiteResult<usize> {
203        match &self.get_column_binding_indices(table_name, column_name)?[..] {
204            [] => unreachable!(),
205            [idx] => Ok(*idx),
206            _ => Err(ErrorCode::InternalError(format!(
207                "Ambiguous column name: {}",
208                column_name
209            ))),
210        }
211    }
212
213    pub fn get_column_binding_indices(
217        &self,
218        table_name: &Option<String>,
219        column_name: &String,
220    ) -> LiteResult<Vec<usize>> {
221        match table_name {
222            Some(table_name) => {
223                if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
224                    let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
225                        format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
226                    self.get_indices_with_group_id(group_id, column_name)
227                } else {
228                    Ok(vec![
229                        self.get_index_with_table_name(column_name, table_name)?,
230                    ])
231                }
232            }
233            None => self.get_unqualified_indices(column_name),
234        }
235    }
236
237    fn get_indices_with_group_id(
238        &self,
239        group_id: u32,
240        column_name: &String,
241    ) -> LiteResult<Vec<usize>> {
242        let group = self.column_group_context.groups.get(&group_id).unwrap();
243        if let Some(name) = &group.column_name {
244            debug_assert_eq!(name, column_name);
245        }
246        if let Some(non_nullable) = &group.non_nullable_column {
247            Ok(vec![*non_nullable])
248        } else {
249            let mut indices: Vec<_> = group.indices.iter().copied().collect();
251            indices.sort(); Ok(indices)
253        }
254    }
255
256    pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
257        let columns = self
258            .indices_of
259            .get(column_name)
260            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
261        if columns.len() > 1 {
262            if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
265                let group = self.column_group_context.groups.get(group_id).unwrap();
266                if columns.iter().all(|idx| group.indices.contains(idx)) {
267                    if let Some(non_nullable) = &group.non_nullable_column {
268                        return Ok(vec![*non_nullable]);
269                    } else {
270                        return Ok(columns.clone());
272                    }
273                }
274            }
275            Err(ErrorCode::InternalError(format!(
276                "Ambiguous column name: {}",
277                column_name
278            )))
279        } else {
280            Ok(columns.clone())
281        }
282    }
283
284    pub fn add_natural_columns(
287        &mut self,
288        left: usize,
289        right: usize,
290        non_nullable_column: Option<usize>,
291    ) {
292        match (
293            self.column_group_context.mapping.get(&left).copied(),
294            self.column_group_context.mapping.get(&right).copied(),
295        ) {
296            (None, None) => {
297                let group_id = self.column_group_context.next_group_id;
298                self.column_group_context.next_group_id += 1;
299
300                let group = ColumnGroup {
301                    indices: HashSet::from([left, right]),
302                    non_nullable_column,
303                    column_name: Some(self.columns[left].field.name.clone()),
304                };
305                self.column_group_context.groups.insert(group_id, group);
306                self.column_group_context.mapping.insert(left, group_id);
307                self.column_group_context.mapping.insert(right, group_id);
308            }
309            (Some(group_id), None) => {
310                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
311                group.indices.insert(right);
312                if group.non_nullable_column.is_none() {
313                    group.non_nullable_column = non_nullable_column;
314                }
315                self.column_group_context.mapping.insert(right, group_id);
316            }
317            (None, Some(group_id)) => {
318                let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
319                group.indices.insert(left);
320                if group.non_nullable_column.is_none() {
321                    group.non_nullable_column = non_nullable_column;
322                }
323                self.column_group_context.mapping.insert(left, group_id);
324            }
325            (Some(l_group_id), Some(r_group_id)) => {
326                if r_group_id == l_group_id {
327                    return;
328                }
329
330                let r_group = self
331                    .column_group_context
332                    .groups
333                    .remove(&r_group_id)
334                    .unwrap();
335                let l_group = self
336                    .column_group_context
337                    .groups
338                    .get_mut(&l_group_id)
339                    .unwrap();
340
341                for idx in &r_group.indices {
342                    *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
343                    l_group.indices.insert(*idx);
344                }
345                if l_group.non_nullable_column.is_none() {
346                    l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
347                        non_nullable_column
348                    } else {
349                        r_group.non_nullable_column
350                    };
351                }
352            }
353        }
354    }
355
356    fn get_index_with_table_name(
357        &self,
358        column_name: &String,
359        table_name: &String,
360    ) -> LiteResult<usize> {
361        let column_indexes = self
362            .indices_of
363            .get(column_name)
364            .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
365        match column_indexes
366            .iter()
367            .find(|column_index| self.columns[**column_index].table_name == *table_name)
368        {
369            Some(column_index) => Ok(*column_index),
370            None => Err(ErrorCode::ItemNotFound(format!(
371                "missing FROM-clause entry for table \"{}\"",
372                table_name
373            ))),
374        }
375    }
376
377    pub fn merge_context(&mut self, other: Self) -> Result<()> {
380        let begin = self.columns.len();
381        self.columns.extend(other.columns.into_iter().map(|mut c| {
382            c.index += begin;
383            c
384        }));
385        for (k, v) in other.indices_of {
386            let entry = self.indices_of.entry(k).or_default();
387            entry.extend(v.into_iter().map(|x| x + begin));
388        }
389        for (k, (x, y)) in other.range_of {
390            match self.range_of.entry(k) {
391                Entry::Occupied(e) => {
392                    return Err(ErrorCode::InternalError(format!(
393                        "Duplicated table name while merging adjacent contexts: {}",
394                        e.key()
395                    ))
396                    .into());
397                }
398                Entry::Vacant(entry) => {
399                    entry.insert((begin + x, begin + y));
400                }
401            }
402        }
403        let ColumnGroupContext {
406            mapping,
407            groups,
408            next_group_id,
409        } = other.column_group_context;
410
411        let offset = self.column_group_context.next_group_id;
412        for (idx, group_id) in mapping {
413            self.column_group_context
414                .mapping
415                .insert(begin + idx, offset + group_id);
416        }
417        for (group_id, mut group) in groups {
418            group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
419            if let Some(col) = &mut group.non_nullable_column {
420                *col += begin;
421            }
422            self.column_group_context
423                .groups
424                .insert(offset + group_id, group);
425        }
426        self.column_group_context.next_group_id += next_group_id;
427
428        Ok(())
430    }
431}
432
433impl BindContext {
434    pub fn new() -> Self {
435        Self::default()
436    }
437}