risingwave_frontend/binder/
bind_context.rs1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::TableAlias;
25
26use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use super::BoundSetExpr;
32use super::statement::RewriteExprsRecursive;
33use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
34
35#[derive(Debug, Clone)]
36pub struct ColumnBinding {
37 pub table_name: String,
38 pub index: usize,
39 pub is_hidden: bool,
40 pub field: Field,
41}
42
43impl ColumnBinding {
44 pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
45 ColumnBinding {
46 table_name,
47 index,
48 is_hidden,
49 field,
50 }
51 }
52}
53
54#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
55#[display(style = "TITLE CASE")]
56pub enum Clause {
57 Where,
58 Values,
59 GroupBy,
60 JoinOn,
61 Having,
62 Filter,
63 From,
64 GeneratedColumn,
65 Insert,
66}
67
68pub struct LateralBindContext {
71 pub is_visible: bool,
72 pub context: BindContext,
73}
74
75#[derive(Default, Debug, Clone)]
93pub enum BindingCteState {
94 #[default]
96 Init,
97 BaseResolved {
99 base: BoundSetExpr,
100 },
101 Bound {
103 query: Either<BoundQuery, RecursiveUnion>,
104 },
105
106 ChangeLog {
107 table: Relation,
108 },
109}
110
111#[derive(Debug, Clone)]
114pub struct RecursiveUnion {
115 #[allow(dead_code)]
118 pub all: bool,
119 pub base: Box<BoundSetExpr>,
121 pub recursive: Box<BoundSetExpr>,
123 pub schema: Schema,
127}
128
129impl RewriteExprsRecursive for RecursiveUnion {
130 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
131 self.base.rewrite_exprs_recursive(rewriter);
133 self.recursive.rewrite_exprs_recursive(rewriter);
134 }
135}
136
137#[derive(Clone, Debug)]
138pub struct BindingCte {
139 pub share_id: ShareId,
140 pub state: BindingCteState,
141 pub alias: TableAlias,
142}
143
144#[derive(Default, Debug, Clone)]
145pub struct BindContext {
146 pub columns: Vec<ColumnBinding>,
148 pub indices_of: HashMap<String, Vec<usize>>,
150 pub range_of: HashMap<String, (usize, usize)>,
152 pub clause: Option<Clause>,
154 pub column_group_context: ColumnGroupContext,
156 pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
159 pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
161 pub disable_security_invoker: bool,
163}
164
165#[derive(Default, Debug, Clone)]
167pub struct ColumnGroupContext {
168 pub mapping: HashMap<usize, u32>,
170 pub groups: BTreeMap<u32, ColumnGroup>,
173
174 next_group_id: u32,
175}
176
/// A set of column positions that are resolved together as one column.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Positions (in `BindContext::columns`) of every member column.
    pub indices: HashSet<usize>,
    /// When set, lookups of the group resolve to this single position instead of all members.
    pub non_nullable_column: Option<usize>,

    /// Shared name of the member columns, if recorded.
    pub column_name: Option<String>,
}
190
191impl BindContext {
192 pub fn get_column_binding_index(
193 &self,
194 table_name: &Option<String>,
195 column_name: &String,
196 ) -> LiteResult<usize> {
197 match &self.get_column_binding_indices(table_name, column_name)?[..] {
198 [] => unreachable!(),
199 [idx] => Ok(*idx),
200 _ => Err(ErrorCode::InternalError(format!(
201 "Ambiguous column name: {}",
202 column_name
203 ))),
204 }
205 }
206
207 pub fn get_column_binding_indices(
211 &self,
212 table_name: &Option<String>,
213 column_name: &String,
214 ) -> LiteResult<Vec<usize>> {
215 match table_name {
216 Some(table_name) => {
217 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
218 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
219 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
220 self.get_indices_with_group_id(group_id, column_name)
221 } else {
222 Ok(vec![
223 self.get_index_with_table_name(column_name, table_name)?,
224 ])
225 }
226 }
227 None => self.get_unqualified_indices(column_name),
228 }
229 }
230
231 fn get_indices_with_group_id(
232 &self,
233 group_id: u32,
234 column_name: &String,
235 ) -> LiteResult<Vec<usize>> {
236 let group = self.column_group_context.groups.get(&group_id).unwrap();
237 if let Some(name) = &group.column_name {
238 debug_assert_eq!(name, column_name);
239 }
240 if let Some(non_nullable) = &group.non_nullable_column {
241 Ok(vec![*non_nullable])
242 } else {
243 let mut indices: Vec<_> = group.indices.iter().copied().collect();
245 indices.sort(); Ok(indices)
247 }
248 }
249
250 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
251 let columns = self
252 .indices_of
253 .get(column_name)
254 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
255 if columns.len() > 1 {
256 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
259 let group = self.column_group_context.groups.get(group_id).unwrap();
260 if columns.iter().all(|idx| group.indices.contains(idx)) {
261 if let Some(non_nullable) = &group.non_nullable_column {
262 return Ok(vec![*non_nullable]);
263 } else {
264 return Ok(columns.to_vec());
266 }
267 }
268 }
269 Err(ErrorCode::InternalError(format!(
270 "Ambiguous column name: {}",
271 column_name
272 )))
273 } else {
274 Ok(columns.to_vec())
275 }
276 }
277
278 pub fn add_natural_columns(
281 &mut self,
282 left: usize,
283 right: usize,
284 non_nullable_column: Option<usize>,
285 ) {
286 match (
287 self.column_group_context.mapping.get(&left).copied(),
288 self.column_group_context.mapping.get(&right).copied(),
289 ) {
290 (None, None) => {
291 let group_id = self.column_group_context.next_group_id;
292 self.column_group_context.next_group_id += 1;
293
294 let group = ColumnGroup {
295 indices: HashSet::from([left, right]),
296 non_nullable_column,
297 column_name: Some(self.columns[left].field.name.clone()),
298 };
299 self.column_group_context.groups.insert(group_id, group);
300 self.column_group_context.mapping.insert(left, group_id);
301 self.column_group_context.mapping.insert(right, group_id);
302 }
303 (Some(group_id), None) => {
304 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
305 group.indices.insert(right);
306 if group.non_nullable_column.is_none() {
307 group.non_nullable_column = non_nullable_column;
308 }
309 self.column_group_context.mapping.insert(right, group_id);
310 }
311 (None, Some(group_id)) => {
312 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
313 group.indices.insert(left);
314 if group.non_nullable_column.is_none() {
315 group.non_nullable_column = non_nullable_column;
316 }
317 self.column_group_context.mapping.insert(left, group_id);
318 }
319 (Some(l_group_id), Some(r_group_id)) => {
320 if r_group_id == l_group_id {
321 return;
322 }
323
324 let r_group = self
325 .column_group_context
326 .groups
327 .remove(&r_group_id)
328 .unwrap();
329 let l_group = self
330 .column_group_context
331 .groups
332 .get_mut(&l_group_id)
333 .unwrap();
334
335 for idx in &r_group.indices {
336 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
337 l_group.indices.insert(*idx);
338 }
339 if l_group.non_nullable_column.is_none() {
340 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
341 non_nullable_column
342 } else {
343 r_group.non_nullable_column
344 };
345 }
346 }
347 }
348 }
349
350 fn get_index_with_table_name(
351 &self,
352 column_name: &String,
353 table_name: &String,
354 ) -> LiteResult<usize> {
355 let column_indexes = self
356 .indices_of
357 .get(column_name)
358 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
359 match column_indexes
360 .iter()
361 .find(|column_index| self.columns[**column_index].table_name == *table_name)
362 {
363 Some(column_index) => Ok(*column_index),
364 None => Err(ErrorCode::ItemNotFound(format!(
365 "missing FROM-clause entry for table \"{}\"",
366 table_name
367 ))),
368 }
369 }
370
371 pub fn merge_context(&mut self, other: Self) -> Result<()> {
374 let begin = self.columns.len();
375 self.columns.extend(other.columns.into_iter().map(|mut c| {
376 c.index += begin;
377 c
378 }));
379 for (k, v) in other.indices_of {
380 let entry = self.indices_of.entry(k).or_default();
381 entry.extend(v.into_iter().map(|x| x + begin));
382 }
383 for (k, (x, y)) in other.range_of {
384 match self.range_of.entry(k) {
385 Entry::Occupied(e) => {
386 return Err(ErrorCode::InternalError(format!(
387 "Duplicated table name while merging adjacent contexts: {}",
388 e.key()
389 ))
390 .into());
391 }
392 Entry::Vacant(entry) => {
393 entry.insert((begin + x, begin + y));
394 }
395 }
396 }
397 let ColumnGroupContext {
400 mapping,
401 groups,
402 next_group_id,
403 } = other.column_group_context;
404
405 let offset = self.column_group_context.next_group_id;
406 for (idx, group_id) in mapping {
407 self.column_group_context
408 .mapping
409 .insert(begin + idx, offset + group_id);
410 }
411 for (group_id, mut group) in groups {
412 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
413 if let Some(col) = &mut group.non_nullable_column {
414 *col += begin;
415 }
416 self.column_group_context
417 .groups
418 .insert(offset + group_id, group);
419 }
420 self.column_group_context.next_group_id += next_group_id;
421
422 Ok(())
424 }
425}
426
427impl BindContext {
428 pub fn new() -> Self {
429 Self::default()
430 }
431}