risingwave_frontend/binder/
bind_context.rs1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28use crate::expr::ExprImpl;
29
/// Result type used by the lookup helpers in this file: errors are bare [`ErrorCode`]s rather
/// than the crate-level error type returned by [`Result`].
type LiteResult<T> = std::result::Result<T, ErrorCode>;
31
32use super::BoundSetExpr;
33use super::statement::RewriteExprsRecursive;
34use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
35
/// One column visible in a [`BindContext`].
#[derive(Debug, Clone)]
pub struct ColumnBinding {
    /// Name of the relation this column belongs to; a qualified reference `t.col` matches
    /// columns whose `table_name` equals `t`.
    pub table_name: String,
    /// Index of this column. `merge_context` shifts it by the same offset as the column's
    /// position in `BindContext::columns`, so it presumably equals that position (confirm).
    pub index: usize,
    /// NOTE(review): presumably marks columns excluded from wildcard expansion — confirm.
    pub is_hidden: bool,
    /// Schema field of the column; `field.name` is the column's name.
    pub field: Field,
}
43
44impl ColumnBinding {
45 pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
46 ColumnBinding {
47 table_name,
48 index,
49 is_hidden,
50 field,
51 }
52 }
53}
54
/// A SQL clause kind; the clause currently associated with a context is stored in
/// `BindContext::clause`.
///
/// `Display` is derived through `parse_display` with `style = "TITLE CASE"`, so variants are
/// rendered in title case (presumably e.g. `GroupBy` -> `Group By`; confirm).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
#[display(style = "TITLE CASE")]
pub enum Clause {
    Where,
    Values,
    GroupBy,
    JoinOn,
    Having,
    Filter,
    From,
    GeneratedColumn,
    Insert,
}
68
/// A [`BindContext`] paired with a visibility flag.
///
/// NOTE(review): by its name this is used when binding `LATERAL` subqueries; the meaning of
/// `is_visible` is not shown in this file — confirm against its users.
pub struct LateralBindContext {
    /// Whether `context` may currently be used for name resolution (presumably; confirm).
    pub is_visible: bool,
    /// The wrapped binding context.
    pub context: BindContext,
}
75
/// Binding progress of a CTE, shared through `BindingCte::state`.
#[derive(Default, Debug, Clone)]
pub enum BindingCteState {
    /// Not bound yet (the `Default` state).
    #[default]
    Init,
    /// Only the base term has been bound; presumably the intermediate state while binding a
    /// recursive CTE (confirm).
    BaseResolved {
        base: BoundSetExpr,
    },
    /// Fully bound: either an ordinary query or a [`RecursiveUnion`].
    Bound {
        query: Either<BoundQuery, RecursiveUnion>,
    },

    /// The CTE is bound to a relation; NOTE(review): presumably a changelog CTE — confirm.
    ChangeLog {
        table: Relation,
    },
}
111
/// The bound form of a recursive CTE: a base term and a recursive term combined by a union.
#[derive(Debug, Clone)]
pub struct RecursiveUnion {
    // NOTE(review): the dead-code lint is silenced for `all`; nothing in this file reads it.
    // Document why it is kept (presumably `UNION ALL` vs `UNION`) or remove the allow.
    #[allow(dead_code)]
    pub all: bool,
    /// The non-recursive (base) term.
    pub base: Box<BoundSetExpr>,
    /// The recursive term.
    pub recursive: Box<BoundSetExpr>,
    /// Output schema of the union.
    pub schema: Schema,
}
129
130impl RewriteExprsRecursive for RecursiveUnion {
131 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
132 self.base.rewrite_exprs_recursive(rewriter);
134 self.recursive.rewrite_exprs_recursive(rewriter);
135 }
136}
137
/// A CTE visible in a [`BindContext`] (stored in `BindContext::cte_to_relation`).
#[derive(Clone, Debug)]
pub struct BindingCte {
    /// Identifier of the shared plan node for this CTE (presumably; confirm).
    pub share_id: ShareId,
    /// How far binding of the CTE has progressed.
    pub state: BindingCteState,
    /// The CTE's name and optional column aliases as written in the query.
    pub alias: TableAlias,
}
144
/// Name-resolution state for one scope of the query being bound.
#[derive(Default, Debug, Clone)]
pub struct BindContext {
    /// All columns visible in this context. `indices_of`, `range_of` and
    /// `column_group_context` store indices into this list.
    pub columns: Vec<ColumnBinding>,
    /// Column name -> indices into `columns` of every column with that name.
    pub indices_of: HashMap<String, Vec<usize>>,
    /// Table name -> `(begin, end)` position range of the table's columns in `columns`.
    /// NOTE(review): inclusive/exclusive `end` is not visible in this file — confirm.
    pub range_of: HashMap<String, (usize, usize)>,
    /// The clause currently being bound, if any (presumably; confirm).
    pub clause: Option<Clause>,
    /// Groups of columns merged by `add_natural_columns`.
    pub column_group_context: ColumnGroupContext,
    /// CTE name -> shared, mutable binding state of that CTE.
    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
    /// Lambda argument name -> `(index, type)`; presumably `Some` only while binding a lambda
    /// body (confirm).
    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
    /// NOTE(review): semantics not visible in this file — presumably turns off
    /// security-invoker handling; confirm.
    pub disable_security_invoker: bool,
    /// Window name -> window specification (presumably from a `WINDOW` clause).
    pub named_windows: HashMap<String, WindowSpec>,
    /// SQL UDF argument name -> bound expression; presumably `Some` only while binding a SQL
    /// UDF body (confirm).
    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
}
170
/// Bookkeeping for column groups created by `BindContext::add_natural_columns`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroupContext {
    /// Column index (into `BindContext::columns`) -> id of the group containing it.
    pub mapping: HashMap<usize, u32>,
    /// Group id -> group.
    pub groups: BTreeMap<u32, ColumnGroup>,

    /// Id to assign to the next newly created group; every existing id is below it.
    next_group_id: u32,
}
182
/// A set of columns that were joined into one logical column by `add_natural_columns`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices (into `BindContext::columns`) of the member columns.
    pub indices: HashSet<usize>,
    /// If set, references to the group resolve to this single column. Otherwise every member
    /// is returned (presumably to be combined, e.g. via `COALESCE`, by the caller — confirm).
    pub non_nullable_column: Option<usize>,

    /// Name shared by the grouped columns; set from the left column's name when the group is
    /// created.
    pub column_name: Option<String>,
}
196
197impl BindContext {
198 pub fn get_column_binding_index(
199 &self,
200 table_name: &Option<String>,
201 column_name: &String,
202 ) -> LiteResult<usize> {
203 match &self.get_column_binding_indices(table_name, column_name)?[..] {
204 [] => unreachable!(),
205 [idx] => Ok(*idx),
206 _ => Err(ErrorCode::InternalError(format!(
207 "Ambiguous column name: {}",
208 column_name
209 ))),
210 }
211 }
212
213 pub fn get_column_binding_indices(
217 &self,
218 table_name: &Option<String>,
219 column_name: &String,
220 ) -> LiteResult<Vec<usize>> {
221 match table_name {
222 Some(table_name) => {
223 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
224 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
225 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
226 self.get_indices_with_group_id(group_id, column_name)
227 } else {
228 Ok(vec![
229 self.get_index_with_table_name(column_name, table_name)?,
230 ])
231 }
232 }
233 None => self.get_unqualified_indices(column_name),
234 }
235 }
236
237 fn get_indices_with_group_id(
238 &self,
239 group_id: u32,
240 column_name: &String,
241 ) -> LiteResult<Vec<usize>> {
242 let group = self.column_group_context.groups.get(&group_id).unwrap();
243 if let Some(name) = &group.column_name {
244 debug_assert_eq!(name, column_name);
245 }
246 if let Some(non_nullable) = &group.non_nullable_column {
247 Ok(vec![*non_nullable])
248 } else {
249 let mut indices: Vec<_> = group.indices.iter().copied().collect();
251 indices.sort(); Ok(indices)
253 }
254 }
255
256 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
257 let columns = self
258 .indices_of
259 .get(column_name)
260 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
261 if columns.len() > 1 {
262 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
265 let group = self.column_group_context.groups.get(group_id).unwrap();
266 if columns.iter().all(|idx| group.indices.contains(idx)) {
267 if let Some(non_nullable) = &group.non_nullable_column {
268 return Ok(vec![*non_nullable]);
269 } else {
270 return Ok(columns.to_vec());
272 }
273 }
274 }
275 Err(ErrorCode::InternalError(format!(
276 "Ambiguous column name: {}",
277 column_name
278 )))
279 } else {
280 Ok(columns.to_vec())
281 }
282 }
283
284 pub fn add_natural_columns(
287 &mut self,
288 left: usize,
289 right: usize,
290 non_nullable_column: Option<usize>,
291 ) {
292 match (
293 self.column_group_context.mapping.get(&left).copied(),
294 self.column_group_context.mapping.get(&right).copied(),
295 ) {
296 (None, None) => {
297 let group_id = self.column_group_context.next_group_id;
298 self.column_group_context.next_group_id += 1;
299
300 let group = ColumnGroup {
301 indices: HashSet::from([left, right]),
302 non_nullable_column,
303 column_name: Some(self.columns[left].field.name.clone()),
304 };
305 self.column_group_context.groups.insert(group_id, group);
306 self.column_group_context.mapping.insert(left, group_id);
307 self.column_group_context.mapping.insert(right, group_id);
308 }
309 (Some(group_id), None) => {
310 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
311 group.indices.insert(right);
312 if group.non_nullable_column.is_none() {
313 group.non_nullable_column = non_nullable_column;
314 }
315 self.column_group_context.mapping.insert(right, group_id);
316 }
317 (None, Some(group_id)) => {
318 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
319 group.indices.insert(left);
320 if group.non_nullable_column.is_none() {
321 group.non_nullable_column = non_nullable_column;
322 }
323 self.column_group_context.mapping.insert(left, group_id);
324 }
325 (Some(l_group_id), Some(r_group_id)) => {
326 if r_group_id == l_group_id {
327 return;
328 }
329
330 let r_group = self
331 .column_group_context
332 .groups
333 .remove(&r_group_id)
334 .unwrap();
335 let l_group = self
336 .column_group_context
337 .groups
338 .get_mut(&l_group_id)
339 .unwrap();
340
341 for idx in &r_group.indices {
342 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
343 l_group.indices.insert(*idx);
344 }
345 if l_group.non_nullable_column.is_none() {
346 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
347 non_nullable_column
348 } else {
349 r_group.non_nullable_column
350 };
351 }
352 }
353 }
354 }
355
356 fn get_index_with_table_name(
357 &self,
358 column_name: &String,
359 table_name: &String,
360 ) -> LiteResult<usize> {
361 let column_indexes = self
362 .indices_of
363 .get(column_name)
364 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
365 match column_indexes
366 .iter()
367 .find(|column_index| self.columns[**column_index].table_name == *table_name)
368 {
369 Some(column_index) => Ok(*column_index),
370 None => Err(ErrorCode::ItemNotFound(format!(
371 "missing FROM-clause entry for table \"{}\"",
372 table_name
373 ))),
374 }
375 }
376
377 pub fn merge_context(&mut self, other: Self) -> Result<()> {
380 let begin = self.columns.len();
381 self.columns.extend(other.columns.into_iter().map(|mut c| {
382 c.index += begin;
383 c
384 }));
385 for (k, v) in other.indices_of {
386 let entry = self.indices_of.entry(k).or_default();
387 entry.extend(v.into_iter().map(|x| x + begin));
388 }
389 for (k, (x, y)) in other.range_of {
390 match self.range_of.entry(k) {
391 Entry::Occupied(e) => {
392 return Err(ErrorCode::InternalError(format!(
393 "Duplicated table name while merging adjacent contexts: {}",
394 e.key()
395 ))
396 .into());
397 }
398 Entry::Vacant(entry) => {
399 entry.insert((begin + x, begin + y));
400 }
401 }
402 }
403 let ColumnGroupContext {
406 mapping,
407 groups,
408 next_group_id,
409 } = other.column_group_context;
410
411 let offset = self.column_group_context.next_group_id;
412 for (idx, group_id) in mapping {
413 self.column_group_context
414 .mapping
415 .insert(begin + idx, offset + group_id);
416 }
417 for (group_id, mut group) in groups {
418 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
419 if let Some(col) = &mut group.non_nullable_column {
420 *col += begin;
421 }
422 self.column_group_context
423 .groups
424 .insert(offset + group_id, group);
425 }
426 self.column_group_context.next_group_id += next_group_id;
427
428 Ok(())
430 }
431}
432
433impl BindContext {
434 pub fn new() -> Self {
435 Self::default()
436 }
437}