1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
/// Result type used by the column/relation lookups in this module; errors are a bare
/// [`ErrorCode`] rather than the crate-wide error type.
type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
/// One column visible in a [`BindContext`], together with naming information about the
/// relation it was bound from.
#[derive(Debug, Clone)]
pub struct ColumnBinding {
    /// Name of the relation the column belongs to.
    pub table_name: String,
    /// Schema of that relation, if known.
    pub schema_name: Option<String>,
    /// Alias given to the relation, if any; compared against the table qualifier in
    /// `BindContext::get_table_alias`.
    pub table_alias: Option<String>,
    /// Index of this binding; `BindContext::merge_context` offsets it by the number of
    /// columns already present in the receiving context.
    pub index: usize,
    /// Whether the column is hidden (presumably excluded from `*` expansion — TODO confirm).
    pub is_hidden: bool,
    /// Field describing the column; `field.name` is read when creating column groups.
    pub field: Field,
}
43
44impl ColumnBinding {
45 pub fn new(
46 table_name: String,
47 schema_name: Option<String>,
48 table_alias: Option<String>,
49 index: usize,
50 is_hidden: bool,
51 field: Field,
52 ) -> Self {
53 ColumnBinding {
54 table_name,
55 schema_name,
56 table_alias,
57 index,
58 is_hidden,
59 field,
60 }
61 }
62}
63
/// A SQL clause kind; stored in `BindContext::clause`.
///
/// `Display` is derived through `parse_display` with the `TITLE CASE` style, which
/// presumably renders variants as upper-case words separated by spaces
/// (e.g. `GroupBy` → `GROUP BY`) — confirm against the parse_display style rules.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
#[display(style = "TITLE CASE")]
pub enum Clause {
    Where,
    Values,
    GroupBy,
    JoinOn,
    Having,
    Filter,
    From,
    GeneratedColumn,
    Insert,
}
77
/// A [`BindContext`] paired with a visibility flag.
///
/// NOTE(review): presumably used for binding `LATERAL` references — confirm with callers.
pub struct LateralBindContext {
    /// Whether the columns of `context` may currently be referenced (assumption from the
    /// name — verify against callers).
    pub is_visible: bool,
    /// The wrapped binding context.
    pub context: BindContext,
}
84
/// The bound form of a common table expression.
#[derive(Debug, Clone)]
pub enum BindingCteState {
    /// The CTE is a query that has been bound.
    Bound {
        /// The bound CTE query.
        query: BoundQuery,
    },

    /// The CTE reads the change log of a relation (presumably a `changelog` CTE — confirm).
    ChangeLog {
        /// The relation whose change log is read.
        table: Relation,
    },
}
96
/// A CTE registered in a [`BindContext`] (see `BindContext::cte_to_relation`).
#[derive(Clone, Debug)]
pub struct BindingCte {
    /// Share id for this CTE (presumably so every reference reuses the same plan — confirm).
    pub share_id: ShareId,
    /// The bound form of the CTE.
    pub state: BindingCteState,
    /// The alias (name and optional column names) the CTE was declared with.
    pub alias: TableAlias,
}
103
/// Name-resolution state for the relations and columns visible while binding.
#[derive(Default, Debug, Clone)]
pub struct BindContext {
    /// Every visible column binding; `merge_context` appends another context's columns.
    pub columns: Vec<ColumnBinding>,
    /// Column name → indices into `columns` of every column with that name.
    pub indices_of: HashMap<String, Vec<usize>>,
    /// `(schema name, relation name)` → half-open range `[begin, end)` of indices into
    /// `columns` that belong to that relation.
    pub range_of: HashMap<(Option<String>, String), (usize, usize)>,
    /// The clause currently being bound, if any (assumption from the name — confirm).
    pub clause: Option<Clause>,
    /// Groups of same-named columns that resolve together (see `add_natural_columns`).
    pub column_group_context: ColumnGroupContext,
    /// CTE name → its binding state (presumably keyed by CTE alias — confirm).
    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
    /// Names registered through `add_cte_name`; checked for duplicates by
    /// `check_catalog_name` and `check_cte_relation_name_conflict`.
    pub cte_relation_names: HashSet<String>,
    /// Lambda argument name → (index, type), presumably set while binding a lambda body.
    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
    /// NOTE(review): meaning not visible in this file; presumably turns off
    /// SECURITY INVOKER handling — confirm with callers.
    pub disable_security_invoker: bool,
    /// Window name → window specification (presumably from a `WINDOW` clause).
    pub named_windows: HashMap<String, WindowSpec>,
    /// Argument name → bound expression, presumably set while binding a SQL UDF body.
    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
}
131
/// Bookkeeping for the column groups created by `BindContext::add_natural_columns`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroupContext {
    /// Column index → id of the group that contains it.
    pub mapping: HashMap<usize, u32>,
    /// Group id → group.
    pub groups: BTreeMap<u32, ColumnGroup>,

    /// Id that will be assigned to the next newly created group.
    next_group_id: u32,
}
143
/// A set of columns that an unqualified reference to their shared name resolves to
/// together (presumably the join columns of a NATURAL/USING join — confirm).
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices into `BindContext::columns` of the grouped columns.
    pub indices: HashSet<usize>,
    /// If set, references to the group resolve to this single column instead of to all
    /// of `indices`.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the grouped columns; taken from the `left` column when the group is
    /// created by `add_natural_columns`.
    pub column_name: Option<String>,
}
157
158impl BindContext {
159 pub fn get_column_binding_index(
160 &self,
161 schema_name: &Option<String>,
162 table_name: &Option<String>,
163 column_name: &String,
164 ) -> LiteResult<usize> {
165 match &self.get_column_binding_indices(schema_name, table_name, column_name)?[..] {
166 [] => unreachable!(),
167 [idx] => Ok(*idx),
168 _ => Err(ErrorCode::InternalError(format!(
169 "Ambiguous column name: {}",
170 column_name
171 ))),
172 }
173 }
174
175 pub fn get_column_binding_indices(
179 &self,
180 schema_name: &Option<String>,
181 table_name: &Option<String>,
182 column_name: &String,
183 ) -> LiteResult<Vec<usize>> {
184 match table_name {
185 Some(table_name) => {
186 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
187 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
188 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
189 self.get_indices_with_group_id(group_id, column_name)
190 } else {
191 Ok(vec![self.get_index_with_table_name(
192 column_name,
193 table_name,
194 schema_name,
195 )?])
196 }
197 }
198 None => self.get_unqualified_indices(column_name),
199 }
200 }
201
202 pub fn get_table_alias(
203 &self,
204 schema_name: &String,
205 table_name: &String,
206 column_name: &String,
207 ) -> LiteResult<Option<usize>> {
208 let column_indexes = self
209 .indices_of
210 .get(column_name)
211 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
212 for index in column_indexes {
213 let column = &self.columns[*index];
214 if let (Some(schema), Some(table_alias)) = (&column.schema_name, &column.table_alias)
215 && schema == schema_name
216 && table_alias == table_name
217 {
218 return Ok(Some(*index));
219 }
220 }
221 Ok(None)
222 }
223
224 fn get_indices_with_group_id(
225 &self,
226 group_id: u32,
227 column_name: &String,
228 ) -> LiteResult<Vec<usize>> {
229 let group = self.column_group_context.groups.get(&group_id).unwrap();
230 if let Some(name) = &group.column_name {
231 debug_assert_eq!(name, column_name);
232 }
233 if let Some(non_nullable) = &group.non_nullable_column {
234 Ok(vec![*non_nullable])
235 } else {
236 let mut indices: Vec<_> = group.indices.iter().copied().collect();
238 indices.sort(); Ok(indices)
240 }
241 }
242
243 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
244 let columns = self
245 .indices_of
246 .get(column_name)
247 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
248 if columns.len() > 1 {
249 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
252 let group = self.column_group_context.groups.get(group_id).unwrap();
253 if columns.iter().all(|idx| group.indices.contains(idx)) {
254 if let Some(non_nullable) = &group.non_nullable_column {
255 return Ok(vec![*non_nullable]);
256 } else {
257 return Ok(columns.clone());
259 }
260 }
261 }
262 Err(ErrorCode::InternalError(format!(
263 "Ambiguous column name: {}",
264 column_name
265 )))
266 } else {
267 Ok(columns.clone())
268 }
269 }
270
271 pub fn check_catalog_name(&self, exposed_relation_name: &str) -> Result<()> {
272 if self.cte_relation_names.contains(exposed_relation_name) {
273 return Err(ErrorCode::DuplicateRelationName(format!(
274 "table name \"{}\" specified more than once",
275 exposed_relation_name
276 ))
277 .into());
278 }
279 Ok(())
280 }
281
282 pub fn add_cte_name(&mut self, exposed_relation_name: String) {
283 self.cte_relation_names.insert(exposed_relation_name);
284 }
285
286 fn has_relation_name(&self, relation_name: &str) -> bool {
287 self.range_of
288 .keys()
289 .any(|(_, existing_name)| existing_name == relation_name)
290 }
291
292 pub fn check_relation_name_conflict(&self, relation_name: &str) -> Result<()> {
293 if self.has_relation_name(relation_name) {
294 return Err(ErrorCode::DuplicateRelationName(format!(
295 "table name \"{}\" specified more than once",
296 relation_name
297 ))
298 .into());
299 }
300 Ok(())
301 }
302
303 fn check_cte_relation_name_conflict(&self, other: &Self) -> Result<()> {
304 for cte_relation_name in &self.cte_relation_names {
307 other.check_relation_name_conflict(cte_relation_name)?;
308 }
309 for cte_relation_name in &other.cte_relation_names {
312 self.check_relation_name_conflict(cte_relation_name)?;
313 }
314 Ok(())
315 }
316
317 pub fn add_natural_columns(
320 &mut self,
321 left: usize,
322 right: usize,
323 non_nullable_column: Option<usize>,
324 ) {
325 match (
326 self.column_group_context.mapping.get(&left).copied(),
327 self.column_group_context.mapping.get(&right).copied(),
328 ) {
329 (None, None) => {
330 let group_id = self.column_group_context.next_group_id;
331 self.column_group_context.next_group_id += 1;
332
333 let group = ColumnGroup {
334 indices: HashSet::from([left, right]),
335 non_nullable_column,
336 column_name: Some(self.columns[left].field.name.clone()),
337 };
338 self.column_group_context.groups.insert(group_id, group);
339 self.column_group_context.mapping.insert(left, group_id);
340 self.column_group_context.mapping.insert(right, group_id);
341 }
342 (Some(group_id), None) => {
343 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
344 group.indices.insert(right);
345 if group.non_nullable_column.is_none() {
346 group.non_nullable_column = non_nullable_column;
347 }
348 self.column_group_context.mapping.insert(right, group_id);
349 }
350 (None, Some(group_id)) => {
351 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
352 group.indices.insert(left);
353 if group.non_nullable_column.is_none() {
354 group.non_nullable_column = non_nullable_column;
355 }
356 self.column_group_context.mapping.insert(left, group_id);
357 }
358 (Some(l_group_id), Some(r_group_id)) => {
359 if r_group_id == l_group_id {
360 return;
361 }
362
363 let r_group = self
364 .column_group_context
365 .groups
366 .remove(&r_group_id)
367 .unwrap();
368 let l_group = self
369 .column_group_context
370 .groups
371 .get_mut(&l_group_id)
372 .unwrap();
373
374 for idx in &r_group.indices {
375 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
376 l_group.indices.insert(*idx);
377 }
378 if l_group.non_nullable_column.is_none() {
379 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
380 non_nullable_column
381 } else {
382 r_group.non_nullable_column
383 };
384 }
385 }
386 }
387 }
388
389 fn get_index_with_table_name(
396 &self,
397 column_name: &String,
398 table_name: &String,
399 schema_name: &Option<String>,
400 ) -> LiteResult<usize> {
401 let chosen_range = self.resolve_relation_range(table_name, schema_name)?;
402 self.resolve_column_in_range(column_name, table_name, chosen_range)
403 }
404
405 fn resolve_relation_range(
409 &self,
410 table_name: &String,
411 schema_name: &Option<String>,
412 ) -> LiteResult<(usize, usize)> {
413 let mut alias_ranges = Vec::new();
414 let mut base_ranges = Vec::new();
415
416 for ((_, _), range) in self.range_of.iter().filter(|((rel_schema, rel_name), _)| {
417 rel_name == table_name
418 && schema_name
419 .as_deref()
420 .is_none_or(|s| rel_schema.as_deref() == Some(s))
421 }) {
422 match self.columns[range.0].table_alias {
423 Some(_) => alias_ranges.push(*range),
424 None => base_ranges.push(*range),
425 }
426 }
427
428 let chosen_ranges = if alias_ranges.is_empty() {
429 base_ranges
430 } else {
431 alias_ranges
432 };
433
434 match chosen_ranges.as_slice() {
435 [] => Err(ErrorCode::ItemNotFound(format!(
436 "missing FROM-clause entry for table \"{}\"",
437 table_name
438 ))),
439 [range] => Ok(*range),
440 _ => Err(ErrorCode::InvalidReference(format!(
441 "table reference \"{}\" is ambiguous",
442 table_name
443 ))),
444 }
445 }
446
447 fn resolve_column_in_range(
449 &self,
450 column_name: &String,
451 table_name: &String,
452 range: (usize, usize),
453 ) -> LiteResult<usize> {
454 let idxs = self
455 .indices_of
456 .get(column_name)
457 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
458
459 let matched: Vec<_> = idxs
460 .iter()
461 .copied()
462 .filter(|index| (range.0..range.1).contains(index))
463 .collect();
464
465 match matched.as_slice() {
466 [] => Err(ErrorCode::ItemNotFound(format!(
467 "missing FROM-clause entry for table \"{}\"",
468 table_name
469 ))),
470 [column_index] => Ok(*column_index),
471 _ => Err(ErrorCode::InvalidReference(format!(
472 "column reference \"{}\" is ambiguous",
473 column_name
474 ))),
475 }
476 }
477
478 pub fn merge_context(&mut self, other: Self) -> Result<()> {
481 self.check_cte_relation_name_conflict(&other)?;
482
483 let begin = self.columns.len();
484 self.columns.extend(other.columns.into_iter().map(|mut c| {
485 c.index += begin;
486 c
487 }));
488 for (k, v) in other.indices_of {
489 let entry = self.indices_of.entry(k).or_default();
490 entry.extend(v.into_iter().map(|x| x + begin));
491 }
492 for (k, (x, y)) in other.range_of {
493 match self.range_of.entry(k) {
494 Entry::Occupied(e) => {
495 return Err(ErrorCode::InternalError(format!(
496 "Duplicated table name while merging adjacent contexts: {}",
497 e.key().1
498 ))
499 .into());
500 }
501 Entry::Vacant(entry) => {
502 entry.insert((begin + x, begin + y));
503 }
504 }
505 }
506 self.cte_relation_names.extend(other.cte_relation_names);
507 let ColumnGroupContext {
510 mapping,
511 groups,
512 next_group_id,
513 } = other.column_group_context;
514
515 let offset = self.column_group_context.next_group_id;
516 for (idx, group_id) in mapping {
517 self.column_group_context
518 .mapping
519 .insert(begin + idx, offset + group_id);
520 }
521 for (group_id, mut group) in groups {
522 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
523 if let Some(col) = &mut group.non_nullable_column {
524 *col += begin;
525 }
526 self.column_group_context
527 .groups
528 .insert(offset + group_id, group);
529 }
530 self.column_group_context.next_group_id += next_group_id;
531
532 Ok(())
534 }
535}
536
537impl BindContext {
538 pub fn new() -> Self {
539 Self::default()
540 }
541}