risingwave_frontend/binder/
bind_context.rs1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use super::BoundSetExpr;
32use super::statement::RewriteExprsRecursive;
33use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
34
35#[derive(Debug, Clone)]
36pub struct ColumnBinding {
37 pub table_name: String,
38 pub index: usize,
39 pub is_hidden: bool,
40 pub field: Field,
41}
42
43impl ColumnBinding {
44 pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
45 ColumnBinding {
46 table_name,
47 index,
48 is_hidden,
49 field,
50 }
51 }
52}
53
54#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
55#[display(style = "TITLE CASE")]
56pub enum Clause {
57 Where,
58 Values,
59 GroupBy,
60 JoinOn,
61 Having,
62 Filter,
63 From,
64 GeneratedColumn,
65 Insert,
66}
67
68pub struct LateralBindContext {
71 pub is_visible: bool,
72 pub context: BindContext,
73}
74
75#[derive(Default, Debug, Clone)]
93pub enum BindingCteState {
94 #[default]
96 Init,
97 BaseResolved {
99 base: BoundSetExpr,
100 },
101 Bound {
103 query: Either<BoundQuery, RecursiveUnion>,
104 },
105
106 ChangeLog {
107 table: Relation,
108 },
109}
110
111#[derive(Debug, Clone)]
114pub struct RecursiveUnion {
115 #[allow(dead_code)]
118 pub all: bool,
119 pub base: Box<BoundSetExpr>,
121 pub recursive: Box<BoundSetExpr>,
123 pub schema: Schema,
127}
128
129impl RewriteExprsRecursive for RecursiveUnion {
130 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
131 self.base.rewrite_exprs_recursive(rewriter);
133 self.recursive.rewrite_exprs_recursive(rewriter);
134 }
135}
136
137#[derive(Clone, Debug)]
138pub struct BindingCte {
139 pub share_id: ShareId,
140 pub state: BindingCteState,
141 pub alias: TableAlias,
142}
143
144#[derive(Default, Debug, Clone)]
145pub struct BindContext {
146 pub columns: Vec<ColumnBinding>,
148 pub indices_of: HashMap<String, Vec<usize>>,
150 pub range_of: HashMap<String, (usize, usize)>,
152 pub clause: Option<Clause>,
154 pub column_group_context: ColumnGroupContext,
156 pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
159 pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
161 pub disable_security_invoker: bool,
163 pub named_windows: HashMap<String, WindowSpec>,
165}
166
167#[derive(Default, Debug, Clone)]
169pub struct ColumnGroupContext {
170 pub mapping: HashMap<usize, u32>,
172 pub groups: BTreeMap<u32, ColumnGroup>,
175
176 next_group_id: u32,
177}
178
/// A set of naturally-joined columns that share one name.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices (into `BindContext::columns`) of the member columns.
    pub indices: HashSet<usize>,
    /// The group's non-nullable column, if known; when set, references to the group resolve
    /// to this single index.
    pub non_nullable_column: Option<usize>,
    /// Name shared by the member columns, if recorded.
    pub column_name: Option<String>,
}
192
193impl BindContext {
194 pub fn get_column_binding_index(
195 &self,
196 table_name: &Option<String>,
197 column_name: &String,
198 ) -> LiteResult<usize> {
199 match &self.get_column_binding_indices(table_name, column_name)?[..] {
200 [] => unreachable!(),
201 [idx] => Ok(*idx),
202 _ => Err(ErrorCode::InternalError(format!(
203 "Ambiguous column name: {}",
204 column_name
205 ))),
206 }
207 }
208
209 pub fn get_column_binding_indices(
213 &self,
214 table_name: &Option<String>,
215 column_name: &String,
216 ) -> LiteResult<Vec<usize>> {
217 match table_name {
218 Some(table_name) => {
219 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
220 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
221 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
222 self.get_indices_with_group_id(group_id, column_name)
223 } else {
224 Ok(vec![
225 self.get_index_with_table_name(column_name, table_name)?,
226 ])
227 }
228 }
229 None => self.get_unqualified_indices(column_name),
230 }
231 }
232
233 fn get_indices_with_group_id(
234 &self,
235 group_id: u32,
236 column_name: &String,
237 ) -> LiteResult<Vec<usize>> {
238 let group = self.column_group_context.groups.get(&group_id).unwrap();
239 if let Some(name) = &group.column_name {
240 debug_assert_eq!(name, column_name);
241 }
242 if let Some(non_nullable) = &group.non_nullable_column {
243 Ok(vec![*non_nullable])
244 } else {
245 let mut indices: Vec<_> = group.indices.iter().copied().collect();
247 indices.sort(); Ok(indices)
249 }
250 }
251
252 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
253 let columns = self
254 .indices_of
255 .get(column_name)
256 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
257 if columns.len() > 1 {
258 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
261 let group = self.column_group_context.groups.get(group_id).unwrap();
262 if columns.iter().all(|idx| group.indices.contains(idx)) {
263 if let Some(non_nullable) = &group.non_nullable_column {
264 return Ok(vec![*non_nullable]);
265 } else {
266 return Ok(columns.to_vec());
268 }
269 }
270 }
271 Err(ErrorCode::InternalError(format!(
272 "Ambiguous column name: {}",
273 column_name
274 )))
275 } else {
276 Ok(columns.to_vec())
277 }
278 }
279
280 pub fn add_natural_columns(
283 &mut self,
284 left: usize,
285 right: usize,
286 non_nullable_column: Option<usize>,
287 ) {
288 match (
289 self.column_group_context.mapping.get(&left).copied(),
290 self.column_group_context.mapping.get(&right).copied(),
291 ) {
292 (None, None) => {
293 let group_id = self.column_group_context.next_group_id;
294 self.column_group_context.next_group_id += 1;
295
296 let group = ColumnGroup {
297 indices: HashSet::from([left, right]),
298 non_nullable_column,
299 column_name: Some(self.columns[left].field.name.clone()),
300 };
301 self.column_group_context.groups.insert(group_id, group);
302 self.column_group_context.mapping.insert(left, group_id);
303 self.column_group_context.mapping.insert(right, group_id);
304 }
305 (Some(group_id), None) => {
306 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
307 group.indices.insert(right);
308 if group.non_nullable_column.is_none() {
309 group.non_nullable_column = non_nullable_column;
310 }
311 self.column_group_context.mapping.insert(right, group_id);
312 }
313 (None, Some(group_id)) => {
314 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
315 group.indices.insert(left);
316 if group.non_nullable_column.is_none() {
317 group.non_nullable_column = non_nullable_column;
318 }
319 self.column_group_context.mapping.insert(left, group_id);
320 }
321 (Some(l_group_id), Some(r_group_id)) => {
322 if r_group_id == l_group_id {
323 return;
324 }
325
326 let r_group = self
327 .column_group_context
328 .groups
329 .remove(&r_group_id)
330 .unwrap();
331 let l_group = self
332 .column_group_context
333 .groups
334 .get_mut(&l_group_id)
335 .unwrap();
336
337 for idx in &r_group.indices {
338 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
339 l_group.indices.insert(*idx);
340 }
341 if l_group.non_nullable_column.is_none() {
342 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
343 non_nullable_column
344 } else {
345 r_group.non_nullable_column
346 };
347 }
348 }
349 }
350 }
351
352 fn get_index_with_table_name(
353 &self,
354 column_name: &String,
355 table_name: &String,
356 ) -> LiteResult<usize> {
357 let column_indexes = self
358 .indices_of
359 .get(column_name)
360 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
361 match column_indexes
362 .iter()
363 .find(|column_index| self.columns[**column_index].table_name == *table_name)
364 {
365 Some(column_index) => Ok(*column_index),
366 None => Err(ErrorCode::ItemNotFound(format!(
367 "missing FROM-clause entry for table \"{}\"",
368 table_name
369 ))),
370 }
371 }
372
373 pub fn merge_context(&mut self, other: Self) -> Result<()> {
376 let begin = self.columns.len();
377 self.columns.extend(other.columns.into_iter().map(|mut c| {
378 c.index += begin;
379 c
380 }));
381 for (k, v) in other.indices_of {
382 let entry = self.indices_of.entry(k).or_default();
383 entry.extend(v.into_iter().map(|x| x + begin));
384 }
385 for (k, (x, y)) in other.range_of {
386 match self.range_of.entry(k) {
387 Entry::Occupied(e) => {
388 return Err(ErrorCode::InternalError(format!(
389 "Duplicated table name while merging adjacent contexts: {}",
390 e.key()
391 ))
392 .into());
393 }
394 Entry::Vacant(entry) => {
395 entry.insert((begin + x, begin + y));
396 }
397 }
398 }
399 let ColumnGroupContext {
402 mapping,
403 groups,
404 next_group_id,
405 } = other.column_group_context;
406
407 let offset = self.column_group_context.next_group_id;
408 for (idx, group_id) in mapping {
409 self.column_group_context
410 .mapping
411 .insert(begin + idx, offset + group_id);
412 }
413 for (group_id, mut group) in groups {
414 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
415 if let Some(col) = &mut group.non_nullable_column {
416 *col += begin;
417 }
418 self.column_group_context
419 .groups
420 .insert(offset + group_id, group);
421 }
422 self.column_group_context.next_group_id += next_group_id;
423
424 Ok(())
426 }
427}
428
429impl BindContext {
430 pub fn new() -> Self {
431 Self::default()
432 }
433}