// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
1415use std::cell::RefCell;
16use std::collections::hash_map::Entry;
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::rc::Rc;
1920use either::Either;
21use parse_display::Display;
22use risingwave_common::catalog::{Field, Schema};
23use risingwave_common::types::DataType;
24use risingwave_sqlparser::ast::TableAlias;
2526use crate::binder::Relation;
27use crate::error::{ErrorCode, Result};
2829type LiteResult<T> = std::result::Result<T, ErrorCode>;
3031use super::BoundSetExpr;
32use super::statement::RewriteExprsRecursive;
33use crate::binder::{BoundQuery, COLUMN_GROUP_PREFIX, ShareId};
3435#[derive(Debug, Clone)]
36pub struct ColumnBinding {
37pub table_name: String,
38pub index: usize,
39pub is_hidden: bool,
40pub field: Field,
41}
4243impl ColumnBinding {
44pub fn new(table_name: String, index: usize, is_hidden: bool, field: Field) -> Self {
45 ColumnBinding {
46 table_name,
47 index,
48 is_hidden,
49 field,
50 }
51 }
52}
5354#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Display)]
55#[display(style = "TITLE CASE")]
56pub enum Clause {
57 Where,
58 Values,
59 GroupBy,
60 JoinOn,
61 Having,
62 Filter,
63 From,
64 GeneratedColumn,
65 Insert,
66}
6768/// A `BindContext` that is only visible if the `LATERAL` keyword
69/// is provided.
70pub struct LateralBindContext {
71pub is_visible: bool,
72pub context: BindContext,
73}
7475/// For recursive CTE, we may need to store it in `cte_to_relation` first,
76/// and then bind it *step by step*.
77///
78/// note: the below sql example is to illustrate when we get the
79/// corresponding binding state when handling a recursive CTE like this.
80///
81/// ```sql
82/// WITH RECURSIVE t(n) AS (
83/// # -------------^ => Init
84/// VALUES (1)
85/// UNION ALL
86/// SELECT n + 1 FROM t WHERE n < 100
87/// # --------------------^ => BaseResolved (after binding the base term, this relation will be bound to `Relation::BackCteRef`)
88/// )
89/// SELECT sum(n) FROM t;
90/// # -----------------^ => Bound (we know exactly what the entire `RecursiveUnion` looks like, and this relation will be bound to `Relation::Share`)
91/// ```
92#[derive(Default, Debug, Clone)]
93pub enum BindingCteState {
94/// We know nothing about the CTE before resolving the body.
95#[default]
96Init,
97/// We know the schema form after the base term resolved.
98BaseResolved {
99 base: BoundSetExpr,
100 },
101/// We get the whole bound result of the (recursive) CTE.
102Bound {
103 query: Either<BoundQuery, RecursiveUnion>,
104 },
105106 ChangeLog {
107 table: Relation,
108 },
109}
110111/// the entire `RecursiveUnion` represents a *bound* recursive cte.
112/// reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
113#[derive(Debug, Clone)]
114pub struct RecursiveUnion {
115/// currently this *must* be true,
116 /// otherwise binding will fail.
117#[allow(dead_code)]
118pub all: bool,
119/// lhs part of the `UNION ALL` operator
120pub base: Box<BoundSetExpr>,
121/// rhs part of the `UNION ALL` operator
122pub recursive: Box<BoundSetExpr>,
123/// the aligned schema for this union
124 /// will be the *same* schema as recursive's
125 /// this is just for a better readability
126pub schema: Schema,
127}
128129impl RewriteExprsRecursive for RecursiveUnion {
130fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
131// rewrite `base` and `recursive` separately
132self.base.rewrite_exprs_recursive(rewriter);
133self.recursive.rewrite_exprs_recursive(rewriter);
134 }
135}
136137#[derive(Clone, Debug)]
138pub struct BindingCte {
139pub share_id: ShareId,
140pub state: BindingCteState,
141pub alias: TableAlias,
142}
143144#[derive(Default, Debug, Clone)]
145pub struct BindContext {
146// Columns of all tables.
147pub columns: Vec<ColumnBinding>,
148// Mapping column name to indices in `columns`.
149pub indices_of: HashMap<String, Vec<usize>>,
150// Mapping table name to [begin, end) of its columns.
151pub range_of: HashMap<String, (usize, usize)>,
152// `clause` identifies in what clause we are binding.
153pub clause: Option<Clause>,
154// The `BindContext`'s data on its column groups
155pub column_group_context: ColumnGroupContext,
156/// Map the cte's name to its binding state.
157 /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
158pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
159/// Current lambda functions's arguments
160pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
161/// Whether the security invoker is set, currently only used for views.
162pub disable_security_invoker: bool,
163}
164165/// Holds the context for the `BindContext`'s `ColumnGroup`s.
166#[derive(Default, Debug, Clone)]
167pub struct ColumnGroupContext {
168// Maps naturally-joined/USING columns to their column group id
169pub mapping: HashMap<usize, u32>,
170// Maps column group ids to their column group data
171 // We use a BTreeMap to ensure that iteration over the groups is ordered.
172pub groups: BTreeMap<u32, ColumnGroup>,
173174 next_group_id: u32,
175}
/// The set of same-named columns produced by binding a natural join or a join with `USING`.
#[derive(Default, Debug, Clone)]
pub struct ColumnGroup {
    /// Indices of the columns in the group.
    pub indices: HashSet<usize>,
    /// A non-nullable column is never NULL.
    /// If `None`, ambiguous references to the column name will be resolved to a
    /// `COALESCE(col1, col2, ..., coln)` over each column in the group.
    pub non_nullable_column: Option<usize>,

    /// Name shared by the columns of the group, when known.
    pub column_name: Option<String>,
}
190191impl BindContext {
192pub fn get_column_binding_index(
193&self,
194 table_name: &Option<String>,
195 column_name: &String,
196 ) -> LiteResult<usize> {
197match &self.get_column_binding_indices(table_name, column_name)?[..] {
198 [] => unreachable!(),
199 [idx] => Ok(*idx),
200_ => Err(ErrorCode::InternalError(format!(
201"Ambiguous column name: {}",
202 column_name
203 ))),
204 }
205 }
206207/// If return Vec has len > 1, it means we have an unqualified reference to a column which has
208 /// been naturally joined upon, wherein none of the columns are min-nullable. This will be
209 /// handled in downstream as a `COALESCE` expression
210pub fn get_column_binding_indices(
211&self,
212 table_name: &Option<String>,
213 column_name: &String,
214 ) -> LiteResult<Vec<usize>> {
215match table_name {
216Some(table_name) => {
217if let Some(group_id_str) = table_name.strip_prefix(COLUMN_GROUP_PREFIX) {
218let group_id = group_id_str.parse::<u32>().map_err(|_|ErrorCode::InternalError(
219format!("Could not parse {:?} as virtual table name `{COLUMN_GROUP_PREFIX}[group_id]`", table_name)))?;
220self.get_indices_with_group_id(group_id, column_name)
221 } else {
222Ok(vec![
223self.get_index_with_table_name(column_name, table_name)?,
224 ])
225 }
226 }
227None => self.get_unqualified_indices(column_name),
228 }
229 }
230231fn get_indices_with_group_id(
232&self,
233 group_id: u32,
234 column_name: &String,
235 ) -> LiteResult<Vec<usize>> {
236let group = self.column_group_context.groups.get(&group_id).unwrap();
237if let Some(name) = &group.column_name {
238debug_assert_eq!(name, column_name);
239 }
240if let Some(non_nullable) = &group.non_nullable_column {
241Ok(vec![*non_nullable])
242 } else {
243// These will be converted to a `COALESCE(col1, col2, ..., coln)`
244let mut indices: Vec<_> = group.indices.iter().copied().collect();
245 indices.sort(); // ensure a deterministic result
246Ok(indices)
247 }
248 }
249250pub fn get_unqualified_indices(&self, column_name: &String) -> LiteResult<Vec<usize>> {
251let columns = self
252.indices_of
253 .get(column_name)
254 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {column_name}")))?;
255if columns.len() > 1 {
256// If there is some group containing the columns and the ambiguous columns are all in
257 // the group
258if let Some(group_id) = self.column_group_context.mapping.get(&columns[0]) {
259let group = self.column_group_context.groups.get(group_id).unwrap();
260if columns.iter().all(|idx| group.indices.contains(idx)) {
261if let Some(non_nullable) = &group.non_nullable_column {
262return Ok(vec![*non_nullable]);
263 } else {
264// These will be converted to a `COALESCE(col1, col2, ..., coln)`
265return Ok(columns.to_vec());
266 }
267 }
268 }
269Err(ErrorCode::InternalError(format!(
270"Ambiguous column name: {}",
271 column_name
272 )))
273 } else {
274Ok(columns.to_vec())
275 }
276 }
277278/// Identifies two columns as being in the same group. Additionally, possibly provides one of
279 /// the columns as being `non_nullable`
280pub fn add_natural_columns(
281&mut self,
282 left: usize,
283 right: usize,
284 non_nullable_column: Option<usize>,
285 ) {
286match (
287self.column_group_context.mapping.get(&left).copied(),
288self.column_group_context.mapping.get(&right).copied(),
289 ) {
290 (None, None) => {
291let group_id = self.column_group_context.next_group_id;
292self.column_group_context.next_group_id += 1;
293294let group = ColumnGroup {
295 indices: HashSet::from([left, right]),
296 non_nullable_column,
297 column_name: Some(self.columns[left].field.name.clone()),
298 };
299self.column_group_context.groups.insert(group_id, group);
300self.column_group_context.mapping.insert(left, group_id);
301self.column_group_context.mapping.insert(right, group_id);
302 }
303 (Some(group_id), None) => {
304let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
305 group.indices.insert(right);
306if group.non_nullable_column.is_none() {
307 group.non_nullable_column = non_nullable_column;
308 }
309self.column_group_context.mapping.insert(right, group_id);
310 }
311 (None, Some(group_id)) => {
312let group = self.column_group_context.groups.get_mut(&group_id).unwrap();
313 group.indices.insert(left);
314if group.non_nullable_column.is_none() {
315 group.non_nullable_column = non_nullable_column;
316 }
317self.column_group_context.mapping.insert(left, group_id);
318 }
319 (Some(l_group_id), Some(r_group_id)) => {
320if r_group_id == l_group_id {
321return;
322 }
323324let r_group = self
325.column_group_context
326 .groups
327 .remove(&r_group_id)
328 .unwrap();
329let l_group = self
330.column_group_context
331 .groups
332 .get_mut(&l_group_id)
333 .unwrap();
334335for idx in &r_group.indices {
336*self.column_group_context.mapping.get_mut(idx).unwrap() = l_group_id;
337 l_group.indices.insert(*idx);
338 }
339if l_group.non_nullable_column.is_none() {
340 l_group.non_nullable_column = if r_group.non_nullable_column.is_none() {
341 non_nullable_column
342 } else {
343 r_group.non_nullable_column
344 };
345 }
346 }
347 }
348 }
349350fn get_index_with_table_name(
351&self,
352 column_name: &String,
353 table_name: &String,
354 ) -> LiteResult<usize> {
355let column_indexes = self
356.indices_of
357 .get(column_name)
358 .ok_or_else(|| ErrorCode::ItemNotFound(format!("Invalid column: {}", column_name)))?;
359match column_indexes
360 .iter()
361 .find(|column_index| self.columns[**column_index].table_name == *table_name)
362 {
363Some(column_index) => Ok(*column_index),
364None => Err(ErrorCode::ItemNotFound(format!(
365"missing FROM-clause entry for table \"{}\"",
366 table_name
367 ))),
368 }
369 }
370371/// Merges two `BindContext`s which are adjacent. For instance, the `BindContext` of two
372 /// adjacent cross-joined tables.
373pub fn merge_context(&mut self, other: Self) -> Result<()> {
374let begin = self.columns.len();
375self.columns.extend(other.columns.into_iter().map(|mut c| {
376 c.index += begin;
377 c
378 }));
379for (k, v) in other.indices_of {
380let entry = self.indices_of.entry(k).or_default();
381 entry.extend(v.into_iter().map(|x| x + begin));
382 }
383for (k, (x, y)) in other.range_of {
384match self.range_of.entry(k) {
385 Entry::Occupied(e) => {
386return Err(ErrorCode::InternalError(format!(
387"Duplicated table name while merging adjacent contexts: {}",
388 e.key()
389 ))
390 .into());
391 }
392 Entry::Vacant(entry) => {
393 entry.insert((begin + x, begin + y));
394 }
395 }
396 }
397// To merge the column_group_contexts, we just need to offset RHS
398 // with the next_group_id of LHS.
399let ColumnGroupContext {
400 mapping,
401 groups,
402 next_group_id,
403 } = other.column_group_context;
404405let offset = self.column_group_context.next_group_id;
406for (idx, group_id) in mapping {
407self.column_group_context
408 .mapping
409 .insert(begin + idx, offset + group_id);
410 }
411for (group_id, mut group) in groups {
412 group.indices = group.indices.into_iter().map(|idx| idx + begin).collect();
413if let Some(col) = &mut group.non_nullable_column {
414*col += begin;
415 }
416self.column_group_context
417 .groups
418 .insert(offset + group_id, group);
419 }
420self.column_group_context.next_group_id += next_group_id;
421422// we assume that the clause is contained in the outer-level context
423Ok(())
424 }
425}
426427impl BindContext {
428pub fn new() -> Self {
429Self::default()
430 }
431}