1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
/// Result alias whose error type is a bare [`ErrorCode`].
///
/// The lookup helpers in this file return it; `merge_context` instead returns the crate-wide
/// `Result` and converts an `ErrorCode` into that error type with `.into()`.
type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
/// One column registered in a [`BindContext`], together with the relation it came from.
#[derive(Debug, Clone)]
pub struct ColumnBinding {
    /// Name of the table this column belongs to.
    pub table_name: String,
    /// Schema of that table, if one is recorded; compared against the requested schema in
    /// `get_table_alias`.
    pub schema_name: Option<String>,
    /// Alias of the table, if any. `resolve_relation_range` uses its presence on a relation's
    /// first column to tell aliased relations from base ones.
    pub table_alias: Option<String>,
    /// Position of this binding; `merge_context` shifts it by the number of columns already
    /// present in the receiving context.
    pub index: usize,
    /// NOTE(review): presumably marks columns that are not exposed to the user (e.g. skipped by
    /// `*` expansion) — not read in this part of the file, confirm with callers.
    pub is_hidden: bool,
    /// Field description of the column; `add_natural_columns` reads `field.name` to name a new
    /// column group.
    pub field: Field,
}
43
44impl ColumnBinding {
45 pub fn new(
46 table_name: String,
47 schema_name: Option<String>,
48 table_alias: Option<String>,
49 index: usize,
50 is_hidden: bool,
51 field: Field,
52 ) -> Self {
53 ColumnBinding {
54 table_name,
55 schema_name,
56 table_alias,
57 index,
58 is_hidden,
59 field,
60 }
61 }
62}
63
/// Identifies a part of a SQL statement; stored in `BindContext::clause`.
///
/// `Display` is derived through `parse_display` with the `TITLE CASE` style.
/// NOTE(review): with that style a variant such as `GroupBy` presumably renders as `GROUP BY` —
/// confirm against the `parse_display` documentation before relying on the exact text.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
#[display(style = "TITLE CASE")]
pub enum Clause {
    Where,
    Values,
    GroupBy,
    JoinOn,
    Having,
    Filter,
    From,
    GeneratedColumn,
    Insert,
}
77
/// A [`BindContext`] paired with a visibility flag.
///
/// NOTE(review): presumably used for lateral references, with `is_visible` deciding whether the
/// wrapped context may be consulted — nothing in this part of the file reads it, confirm with
/// the binder code that does.
pub struct LateralBindContext {
    /// Whether the wrapped context is currently visible (see the note on the struct).
    pub is_visible: bool,
    /// The wrapped binding context.
    pub context: BindContext,
}
84
/// What a CTE entry (see [`BindingCte`]) currently holds.
#[derive(Debug, Clone)]
pub enum BindingCteState {
    /// The CTE carries a bound query.
    Bound {
        /// The bound query.
        query: BoundQuery,
    },

    /// The CTE carries a relation.
    /// NOTE(review): by name this is the changelog form of a table — confirm with the code that
    /// constructs it.
    ChangeLog {
        /// The relation carried by this entry.
        table: Relation,
    },
}
96
/// A CTE entry as stored (behind `Rc<RefCell<_>>`) in `BindContext::cte_to_relation`.
#[derive(Clone, Debug)]
pub struct BindingCte {
    /// Share identifier associated with this CTE.
    /// NOTE(review): presumably lets several references to the CTE share one plan — confirm.
    pub share_id: ShareId,
    /// What the CTE currently holds.
    pub state: BindingCteState,
    /// Alias of the CTE as parsed from the SQL text.
    pub alias: TableAlias,
}
103
/// Name-resolution state for one binding scope: the columns in scope plus the lookup tables used
/// to resolve qualified and unqualified column references.
#[derive(Default, Debug, Clone)]
pub struct BindContext {
    /// Every column binding in this scope; the indices stored in the other fields point into
    /// this vector.
    pub columns: Vec<ColumnBinding>,
    /// Column name -> indices (into `columns`) of all columns carrying that name.
    pub indices_of: HashMap<String, Vec<usize>>,
    /// `(schema name, relation name)` -> half-open `[start, end)` range of `columns` occupied by
    /// that relation.
    pub range_of: HashMap<(Option<String>, String), (usize, usize)>,
    /// NOTE(review): presumably the clause currently being bound — not read in this part of the
    /// file, confirm with callers.
    pub clause: Option<Clause>,
    /// Bookkeeping for columns that were unified by `add_natural_columns`.
    pub column_group_context: ColumnGroupContext,
    /// NOTE(review): presumably CTE name -> its shared, mutable binding — confirm with the code
    /// that populates it.
    pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
    /// NOTE(review): presumably lambda parameter name -> (position, type) while a lambda body is
    /// being bound, `None` otherwise — confirm.
    pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
    /// NOTE(review): semantics are not visible here; by name, a switch that turns off
    /// security-invoker handling — confirm.
    pub disable_security_invoker: bool,
    /// NOTE(review): presumably window name -> specification for named windows — confirm.
    pub named_windows: HashMap<String, WindowSpec>,
    /// NOTE(review): presumably argument name -> bound expression while a SQL UDF body is being
    /// bound, `None` otherwise — confirm.
    pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
}
129
/// Tracks the column groups created by `BindContext::add_natural_columns`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroupContext {
    /// Column index -> id of the group that column belongs to.
    pub mapping: HashMap<usize, u32>,
    /// Group id -> the group itself.
    pub groups: BTreeMap<u32, ColumnGroup>,

    /// Id handed to the next newly created group; also used by `merge_context` as the offset
    /// applied to the ids of the merged-in context.
    next_group_id: u32,
}
141
/// A set of columns that `BindContext::add_natural_columns` has placed in the same group.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices (into `BindContext::columns`) of the columns in this group.
    pub indices: HashSet<usize>,
    /// When set, lookups that resolve to this group return only this column index instead of
    /// every member.
    pub non_nullable_column: Option<usize>,

    /// Name of the grouped column; set from the left column's field name when the group is
    /// created.
    pub column_name: Option<String>,
}
155
156impl BindContext {
157 pub fn get_column_binding_index(
158 &self,
159 schema_name: &Option<String>,
160 table_name: &Option<String>,
161 column_name: &String,
162 ) -> LiteResult<usize> {
163 match &self.get_column_binding_indices(schema_name, table_name, column_name)?[..] {
164 [] => unreachable!(),
165 [idx] => Ok(*idx),
166 _ => Err(ErrorCode::InternalError(format!(
167 "Ambiguous column name: {}",
168 column_name
169 ))),
170 }
171 }
172
173 pub fn get_column_binding_indices(
177 &self,
178 schema_name: &Option<String>,
179 table_name: &Option<String>,
180 column_name: &String,
181 ) -> LiteResult<Vec<usize>> {
182 match table_name {
183 Some(table_name) => {
184 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
185 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
186 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
187 self.get_indices_with_group_id(group_id, column_name)
188 } else {
189 Ok(vec![self.get_index_with_table_name(
190 column_name,
191 table_name,
192 schema_name,
193 )?])
194 }
195 }
196 None => self.get_unqualified_indices(column_name),
197 }
198 }
199
200 pub fn get_table_alias(
201 &self,
202 schema_name: &String,
203 table_name: &String,
204 column_name: &String,
205 ) -> LiteResult<Option<usize>> {
206 let column_indexes = self
207 .indices_of
208 .get(column_name)
209 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
210 for index in column_indexes {
211 let column = &self.columns[*index];
212 if let (Some(schema), Some(table_alias)) = (&column.schema_name, &column.table_alias)
213 && schema == schema_name
214 && table_alias == table_name
215 {
216 return Ok(Some(*index));
217 }
218 }
219 Ok(None)
220 }
221
222 fn get_indices_with_group_id(
223 &self,
224 group_id: u32,
225 column_name: &String,
226 ) -> LiteResult<Vec<usize>> {
227 let group = self.column_group_context.groups.get(&group_id).unwrap();
228 if let Some(name) = &group.column_name {
229 debug_assert_eq!(name, column_name);
230 }
231 if let Some(non_nullable) = &group.non_nullable_column {
232 Ok(vec![*non_nullable])
233 } else {
234 let mut indices: Vec<_> = group.indices.iter().copied().collect();
236 indices.sort(); Ok(indices)
238 }
239 }
240
241 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
242 let columns = self
243 .indices_of
244 .get(column_name)
245 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
246 if columns.len() > 1 {
247 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
250 let group = self.column_group_context.groups.get(group_id).unwrap();
251 if columns.iter().all(|idx| group.indices.contains(idx)) {
252 if let Some(non_nullable) = &group.non_nullable_column {
253 return Ok(vec![*non_nullable]);
254 } else {
255 return Ok(columns.clone());
257 }
258 }
259 }
260 Err(ErrorCode::InternalError(format!(
261 "Ambiguous column name: {}",
262 column_name
263 )))
264 } else {
265 Ok(columns.clone())
266 }
267 }
268
269 pub fn add_natural_columns(
272 &mut self,
273 left: usize,
274 right: usize,
275 non_nullable_column: Option<usize>,
276 ) {
277 match (
278 self.column_group_context.mapping.get(&left).copied(),
279 self.column_group_context.mapping.get(&right).copied(),
280 ) {
281 (None, None) => {
282 let group_id = self.column_group_context.next_group_id;
283 self.column_group_context.next_group_id += 1;
284
285 let group = ColumnGroup {
286 indices: HashSet::from([left, right]),
287 non_nullable_column,
288 column_name: Some(self.columns[left].field.name.clone()),
289 };
290 self.column_group_context.groups.insert(group_id, group);
291 self.column_group_context.mapping.insert(left, group_id);
292 self.column_group_context.mapping.insert(right, group_id);
293 }
294 (Some(group_id), None) => {
295 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
296 group.indices.insert(right);
297 if group.non_nullable_column.is_none() {
298 group.non_nullable_column = non_nullable_column;
299 }
300 self.column_group_context.mapping.insert(right, group_id);
301 }
302 (None, Some(group_id)) => {
303 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
304 group.indices.insert(left);
305 if group.non_nullable_column.is_none() {
306 group.non_nullable_column = non_nullable_column;
307 }
308 self.column_group_context.mapping.insert(left, group_id);
309 }
310 (Some(l_group_id), Some(r_group_id)) => {
311 if r_group_id == l_group_id {
312 return;
313 }
314
315 let r_group = self
316 .column_group_context
317 .groups
318 .remove(&r_group_id)
319 .unwrap();
320 let l_group = self
321 .column_group_context
322 .groups
323 .get_mut(&l_group_id)
324 .unwrap();
325
326 for idx in &r_group.indices {
327 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
328 l_group.indices.insert(*idx);
329 }
330 if l_group.non_nullable_column.is_none() {
331 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
332 non_nullable_column
333 } else {
334 r_group.non_nullable_column
335 };
336 }
337 }
338 }
339 }
340
341 fn get_index_with_table_name(
348 &self,
349 column_name: &String,
350 table_name: &String,
351 schema_name: &Option<String>,
352 ) -> LiteResult<usize> {
353 let chosen_range = self.resolve_relation_range(table_name, schema_name)?;
354 self.resolve_column_in_range(column_name, table_name, chosen_range)
355 }
356
357 fn resolve_relation_range(
361 &self,
362 table_name: &String,
363 schema_name: &Option<String>,
364 ) -> LiteResult<(usize, usize)> {
365 let mut alias_ranges = Vec::new();
366 let mut base_ranges = Vec::new();
367
368 for ((_, _), range) in self.range_of.iter().filter(|((rel_schema, rel_name), _)| {
369 rel_name == table_name
370 && schema_name
371 .as_deref()
372 .is_none_or(|s| rel_schema.as_deref() == Some(s))
373 }) {
374 match self.columns[range.0].table_alias {
375 Some(_) => alias_ranges.push(*range),
376 None => base_ranges.push(*range),
377 }
378 }
379
380 let chosen_ranges = if alias_ranges.is_empty() {
381 base_ranges
382 } else {
383 alias_ranges
384 };
385
386 match chosen_ranges.as_slice() {
387 [] => Err(ErrorCode::ItemNotFound(format!(
388 "missing FROM-clause entry for table \"{}\"",
389 table_name
390 ))),
391 [range] => Ok(*range),
392 _ => Err(ErrorCode::InvalidReference(format!(
393 "table reference \"{}\" is ambiguous",
394 table_name
395 ))),
396 }
397 }
398
399 fn resolve_column_in_range(
401 &self,
402 column_name: &String,
403 table_name: &String,
404 range: (usize, usize),
405 ) -> LiteResult<usize> {
406 let idxs = self
407 .indices_of
408 .get(column_name)
409 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
410
411 let matched: Vec<_> = idxs
412 .iter()
413 .copied()
414 .filter(|index| (range.0..range.1).contains(index))
415 .collect();
416
417 match matched.as_slice() {
418 [] => Err(ErrorCode::ItemNotFound(format!(
419 "missing FROM-clause entry for table \"{}\"",
420 table_name
421 ))),
422 [column_index] => Ok(*column_index),
423 _ => Err(ErrorCode::InvalidReference(format!(
424 "table reference \"{}\" is ambiguous",
425 table_name
426 ))),
427 }
428 }
429
430 pub fn merge_context(&mut self, other: Self) -> Result<()> {
433 let begin = self.columns.len();
434 self.columns.extend(other.columns.into_iter().map(|mut c| {
435 c.index += begin;
436 c
437 }));
438 for (k, v) in other.indices_of {
439 let entry = self.indices_of.entry(k).or_default();
440 entry.extend(v.into_iter().map(|x| x + begin));
441 }
442 for (k, (x, y)) in other.range_of {
443 match self.range_of.entry(k) {
444 Entry::Occupied(e) => {
445 return Err(ErrorCode::InternalError(format!(
446 "Duplicated table name while merging adjacent contexts: {}",
447 e.key().1
448 ))
449 .into());
450 }
451 Entry::Vacant(entry) => {
452 entry.insert((begin + x, begin + y));
453 }
454 }
455 }
456 let ColumnGroupContext {
459 mapping,
460 groups,
461 next_group_id,
462 } = other.column_group_context;
463
464 let offset = self.column_group_context.next_group_id;
465 for (idx, group_id) in mapping {
466 self.column_group_context
467 .mapping
468 .insert(begin + idx, offset + group_id);
469 }
470 for (group_id, mut group) in groups {
471 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
472 if let Some(col) = &mut group.non_nullable_column {
473 *col += begin;
474 }
475 self.column_group_context
476 .groups
477 .insert(offset + group_id, group);
478 }
479 self.column_group_context.next_group_id += next_group_id;
480
481 Ok(())
483 }
484}
485
486impl BindContext {
487 pub fn new() -> Self {
488 Self::default()
489 }
490}