risingwave_frontend/binder/
bind_context.rs1use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
19
20use parse_display::Display;
21use risingwave_common::catalog::Field;
22use risingwave_common::types::DataType;
23use risingwave_sqlparser::ast::{TableAlias, WindowSpec};
24
25use crate::binder::Relation;
26use crate::error::{ErrorCode, Result};
27use crate::expr::ExprImpl;
28
29type LiteResult<T> = std::result::Result<T, ErrorCode>;
30
31use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
32
33#[derive(Debug, Clone)]
34pub struct ColumnBinding {
35 pub table_name: String,
36 pub index: usize,
37 pub is_hidden: bool,
38 pub field: Field,
39}
40
41impl ColumnBinding {
42 pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
43 ColumnBinding {
44 table_name,
45 index,
46 is_hidden,
47 field,
48 }
49 }
50}
51
52#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
53#[display(style = "TITLE CASE")]
54pub enum Clause {
55 Where,
56 Values,
57 GroupBy,
58 JoinOn,
59 Having,
60 Filter,
61 From,
62 GeneratedColumn,
63 Insert,
64}
65
66pub struct LateralBindContext {
69 pub is_visible: bool,
70 pub context: BindContext,
71}
72
73#[derive(Debug, Clone)]
74pub enum BindingCteState {
75 Bound {
77 query: BoundQuery,
78 },
79
80 ChangeLog {
81 table: Relation,
82 },
83}
84
85#[derive(Clone, Debug)]
86pub struct BindingCte {
87 pub share_id: ShareId,
88 pub state: BindingCteState,
89 pub alias: TableAlias,
90}
91
92#[derive(Default, Debug, Clone)]
93pub struct BindContext {
94 pub columns: Vec<ColumnBinding>,
96 pub indices_of: HashMap<String, Vec<usize>>,
98 pub range_of: HashMap<String, (usize, usize)>,
100 pub clause: Option<Clause>,
102 pub column_group_context: ColumnGroupContext,
104 pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
107 pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
109 pub disable_security_invoker: bool,
111 pub named_windows: HashMap<String, WindowSpec>,
113 pub sql_udf_arguments: Option<HashMap<String, ExprImpl>>,
116}
117
118#[derive(Default, Debug, Clone)]
120pub struct ColumnGroupContext {
121 pub mapping: HashMap<usize, u32>,
123 pub groups: BTreeMap<u32, ColumnGroup>,
126
127 next_group_id: u32,
128}
129
/// A set of columns that are resolved together under one name.
#[derive(Clone, Debug, Default)]
pub struct ColumnGroup {
    /// Positions of every column that belongs to the group.
    pub indices: HashSet<usize>,
    /// When set, lookups that hit this group resolve to this single column instead of
    /// to all members.
    pub non_nullable_column: Option<usize>,

    /// Name of the grouped column, taken from the left-hand column when the group is
    /// first created.
    pub column_name: Option<String>,
}
143
144impl BindContext {
145 pub fn get_column_binding_index(
146 &self,
147 table_name: &Option<String>,
148 column_name: &String,
149 ) -> LiteResult<usize> {
150 match &self.get_column_binding_indices(table_name, column_name)?[..] {
151 [] => unreachable!(),
152 [idx] => Ok(*idx),
153 _ => Err(ErrorCode::InternalError(format!(
154 "Ambiguous column name: {}",
155 column_name
156 ))),
157 }
158 }
159
160 pub fn get_column_binding_indices(
164 &self,
165 table_name: &Option<String>,
166 column_name: &String,
167 ) -> LiteResult<Vec<usize>> {
168 match table_name {
169 Some(table_name) => {
170 if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
171 let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
172 format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
173 self.get_indices_with_group_id(group_id, column_name)
174 } else {
175 Ok(vec![
176 self.get_index_with_table_name(column_name, table_name)?,
177 ])
178 }
179 }
180 None => self.get_unqualified_indices(column_name),
181 }
182 }
183
184 fn get_indices_with_group_id(
185 &self,
186 group_id: u32,
187 column_name: &String,
188 ) -> LiteResult<Vec<usize>> {
189 let group = self.column_group_context.groups.get(&group_id).unwrap();
190 if let Some(name) = &group.column_name {
191 debug_assert_eq!(name, column_name);
192 }
193 if let Some(non_nullable) = &group.non_nullable_column {
194 Ok(vec![*non_nullable])
195 } else {
196 let mut indices: Vec<_> = group.indices.iter().copied().collect();
198 indices.sort(); Ok(indices)
200 }
201 }
202
203 pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
204 let columns = self
205 .indices_of
206 .get(column_name)
207 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
208 if columns.len() > 1 {
209 if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
212 let group = self.column_group_context.groups.get(group_id).unwrap();
213 if columns.iter().all(|idx| group.indices.contains(idx)) {
214 if let Some(non_nullable) = &group.non_nullable_column {
215 return Ok(vec![*non_nullable]);
216 } else {
217 return Ok(columns.clone());
219 }
220 }
221 }
222 Err(ErrorCode::InternalError(format!(
223 "Ambiguous column name: {}",
224 column_name
225 )))
226 } else {
227 Ok(columns.clone())
228 }
229 }
230
231 pub fn add_natural_columns(
234 &mut self,
235 left: usize,
236 right: usize,
237 non_nullable_column: Option<usize>,
238 ) {
239 match (
240 self.column_group_context.mapping.get(&left).copied(),
241 self.column_group_context.mapping.get(&right).copied(),
242 ) {
243 (None, None) => {
244 let group_id = self.column_group_context.next_group_id;
245 self.column_group_context.next_group_id += 1;
246
247 let group = ColumnGroup {
248 indices: HashSet::from([left, right]),
249 non_nullable_column,
250 column_name: Some(self.columns[left].field.name.clone()),
251 };
252 self.column_group_context.groups.insert(group_id, group);
253 self.column_group_context.mapping.insert(left, group_id);
254 self.column_group_context.mapping.insert(right, group_id);
255 }
256 (Some(group_id), None) => {
257 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
258 group.indices.insert(right);
259 if group.non_nullable_column.is_none() {
260 group.non_nullable_column = non_nullable_column;
261 }
262 self.column_group_context.mapping.insert(right, group_id);
263 }
264 (None, Some(group_id)) => {
265 let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
266 group.indices.insert(left);
267 if group.non_nullable_column.is_none() {
268 group.non_nullable_column = non_nullable_column;
269 }
270 self.column_group_context.mapping.insert(left, group_id);
271 }
272 (Some(l_group_id), Some(r_group_id)) => {
273 if r_group_id == l_group_id {
274 return;
275 }
276
277 let r_group = self
278 .column_group_context
279 .groups
280 .remove(&r_group_id)
281 .unwrap();
282 let l_group = self
283 .column_group_context
284 .groups
285 .get_mut(&l_group_id)
286 .unwrap();
287
288 for idx in &r_group.indices {
289 *self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
290 l_group.indices.insert(*idx);
291 }
292 if l_group.non_nullable_column.is_none() {
293 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
294 non_nullable_column
295 } else {
296 r_group.non_nullable_column
297 };
298 }
299 }
300 }
301 }
302
303 fn get_index_with_table_name(
304 &self,
305 column_name: &String,
306 table_name: &String,
307 ) -> LiteResult<usize> {
308 let column_indexes = self
309 .indices_of
310 .get(column_name)
311 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
312 match column_indexes
313 .iter()
314 .find(|column_index| self.columns[**column_index].table_name == *table_name)
315 {
316 Some(column_index) => Ok(*column_index),
317 None => Err(ErrorCode::ItemNotFound(format!(
318 "missing FROM-clause entry for table \"{}\"",
319 table_name
320 ))),
321 }
322 }
323
324 pub fn merge_context(&mut self, other: Self) -> Result<()> {
327 let begin = self.columns.len();
328 self.columns.extend(other.columns.into_iter().map(|mut c| {
329 c.index += begin;
330 c
331 }));
332 for (k, v) in other.indices_of {
333 let entry = self.indices_of.entry(k).or_default();
334 entry.extend(v.into_iter().map(|x| x + begin));
335 }
336 for (k, (x, y)) in other.range_of {
337 match self.range_of.entry(k) {
338 Entry::Occupied(e) => {
339 return Err(ErrorCode::InternalError(format!(
340 "Duplicated table name while merging adjacent contexts: {}",
341 e.key()
342 ))
343 .into());
344 }
345 Entry::Vacant(entry) => {
346 entry.insert((begin + x, begin + y));
347 }
348 }
349 }
350 let ColumnGroupContext {
353 mapping,
354 groups,
355 next_group_id,
356 } = other.column_group_context;
357
358 let offset = self.column_group_context.next_group_id;
359 for (idx, group_id) in mapping {
360 self.column_group_context
361 .mapping
362 .insert(begin + idx, offset + group_id);
363 }
364 for (group_id, mut group) in groups {
365 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
366 if let Some(col) = &mut group.non_nullable_column {
367 *col += begin;
368 }
369 self.column_group_context
370 .groups
371 .insert(offset + group_id, group);
372 }
373 self.column_group_context.next_group_id += next_group_id;
374
375 Ok(())
377 }
378}
379
380impl BindContext {
381 pub fn new() -> Self {
382 Self::default()
383 }
384}