use std::collections::hash_map::Entry;
use std::ops::Deref;
use either::Either;
use itertools::{EitherOrBoth, Itertools};
use risingwave_common::bail;
use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME};
use risingwave_sqlparser::ast::{
AsOf, Expr as ParserExpr, FunctionArg, FunctionArgExpr, Ident, ObjectName, TableAlias,
TableFactor,
};
use thiserror::Error;
use thiserror_ext::AsReport;
use super::bind_context::ColumnBinding;
use super::statement::RewriteExprsRecursive;
use crate::binder::bind_context::{BindingCte, BindingCteState};
use crate::binder::Binder;
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{ExprImpl, InputRef};
mod cte_ref;
mod join;
mod share;
mod subquery;
mod table_function;
mod table_or_source;
mod watermark;
mod window_table_function;
pub use cte_ref::BoundBackCteRef;
pub use join::BoundJoin;
pub use share::{BoundShare, BoundShareInput};
pub use subquery::BoundSubquery;
pub use table_or_source::{BoundBaseTable, BoundSource, BoundSystemTable};
pub use watermark::BoundWatermark;
pub use window_table_function::{BoundWindowTableFunction, WindowTableFunctionKind};
use crate::expr::{CorrelatedId, Depth};
#[derive(Debug, Clone)]
pub enum Relation {
Source(Box<BoundSource>),
BaseTable(Box<BoundBaseTable>),
SystemTable(Box<BoundSystemTable>),
Subquery(Box<BoundSubquery>),
Join(Box<BoundJoin>),
Apply(Box<BoundJoin>),
WindowTableFunction(Box<BoundWindowTableFunction>),
TableFunction {
expr: ExprImpl,
with_ordinality: bool,
},
Watermark(Box<BoundWatermark>),
Share(Box<BoundShare>),
BackCteRef(Box<BoundBackCteRef>),
}
impl RewriteExprsRecursive for Relation {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
match self {
Relation::Subquery(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Join(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Apply(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::TableFunction { expr: inner, .. } => {
*inner = rewriter.rewrite_expr(inner.take())
}
Relation::BackCteRef(inner) => inner.rewrite_exprs_recursive(rewriter),
_ => {}
}
}
}
impl Relation {
pub fn is_correlated(&self, depth: Depth) -> bool {
match self {
Relation::Subquery(subquery) => subquery.query.is_correlated(depth),
Relation::Join(join) | Relation::Apply(join) => {
join.cond.has_correlated_input_ref_by_depth(depth)
|| join.left.is_correlated(depth)
|| join.right.is_correlated(depth)
}
_ => false,
}
}
pub fn collect_correlated_indices_by_depth_and_assign_id(
&mut self,
depth: Depth,
correlated_id: CorrelatedId,
) -> Vec<usize> {
match self {
Relation::Subquery(subquery) => subquery
.query
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
Relation::Join(join) | Relation::Apply(join) => {
let mut correlated_indices = vec![];
correlated_indices.extend(
join.cond
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
);
correlated_indices.extend(
join.left
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
);
correlated_indices.extend(
join.right
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
);
correlated_indices
}
Relation::TableFunction {
expr: table_function,
with_ordinality: _,
} => table_function
.collect_correlated_indices_by_depth_and_assign_id(depth + 1, correlated_id),
Relation::Share(share) => match &mut share.input {
BoundShareInput::Query(query) => match query {
Either::Left(query) => query
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
Either::Right(_) => vec![],
},
BoundShareInput::ChangeLog(change_log) => change_log
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
},
_ => vec![],
}
}
}
#[derive(Debug)]
#[non_exhaustive]
pub enum ResolveQualifiedNameErrorKind {
QualifiedNameTooLong,
NotCurrentDatabase,
}
#[derive(Debug, Error)]
pub struct ResolveQualifiedNameError {
qualified: String,
kind: ResolveQualifiedNameErrorKind,
}
impl std::fmt::Display for ResolveQualifiedNameError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ResolveQualifiedNameErrorKind::QualifiedNameTooLong => write!(
f,
"improper qualified name (too many dotted names): {}",
self.qualified
),
ResolveQualifiedNameErrorKind::NotCurrentDatabase => write!(
f,
"cross-database references are not implemented: \"{}\"",
self.qualified
),
}
}
}
impl ResolveQualifiedNameError {
pub fn new(qualified: String, kind: ResolveQualifiedNameErrorKind) -> Self {
Self { qualified, kind }
}
}
impl From<ResolveQualifiedNameError> for RwError {
fn from(e: ResolveQualifiedNameError) -> Self {
ErrorCode::InvalidInputSyntax(format!("{}", e.as_report())).into()
}
}
impl Binder {
pub fn resolve_schema_qualified_name(
db_name: &str,
name: ObjectName,
) -> std::result::Result<(Option<String>, String), ResolveQualifiedNameError> {
let formatted_name = name.to_string();
let mut identifiers = name.0;
if identifiers.len() > 3 {
return Err(ResolveQualifiedNameError::new(
formatted_name,
ResolveQualifiedNameErrorKind::QualifiedNameTooLong,
));
}
let name = identifiers.pop().unwrap().real_value();
let schema_name = identifiers.pop().map(|ident| ident.real_value());
let database_name = identifiers.pop().map(|ident| ident.real_value());
if let Some(database_name) = database_name
&& database_name != db_name
{
return Err(ResolveQualifiedNameError::new(
formatted_name,
ResolveQualifiedNameErrorKind::NotCurrentDatabase,
));
}
Ok((schema_name, name))
}
fn resolve_single_name(mut identifiers: Vec<Ident>, ident_desc: &str) -> Result<String> {
if identifiers.len() > 1 {
bail!("{} must contain 1 argument", ident_desc);
}
let name = identifiers.pop().unwrap().real_value();
Ok(name)
}
pub fn resolve_database_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "database name")
}
pub fn resolve_schema_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "schema name")
}
pub fn resolve_index_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "index name")
}
pub fn resolve_view_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "view name")
}
pub fn resolve_sink_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "sink name")
}
pub fn resolve_subscription_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "subscription name")
}
pub fn resolve_table_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "table name")
}
pub fn resolve_source_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "source name")
}
pub fn resolve_user_name(name: ObjectName) -> Result<String> {
Self::resolve_single_name(name.0, "user name")
}
pub(super) fn bind_table_to_context(
&mut self,
columns: impl IntoIterator<Item = (bool, Field)>, table_name: String,
alias: Option<TableAlias>,
) -> Result<()> {
let (table_name, column_aliases) = match alias {
None => (table_name, vec![]),
Some(TableAlias { name, columns }) => (name.real_value(), columns),
};
let num_col_aliases = column_aliases.len();
let begin = self.context.columns.len();
let mut alias_iter = column_aliases.into_iter().fuse();
let mut index = 0;
columns.into_iter().for_each(|(is_hidden, mut field)| {
let name = match is_hidden {
true => field.name.to_string(),
false => alias_iter
.next()
.map(|t| t.real_value())
.unwrap_or_else(|| field.name.to_string()),
};
field.name.clone_from(&name);
self.context.columns.push(ColumnBinding::new(
table_name.clone(),
begin + index,
is_hidden,
field,
));
self.context
.indices_of
.entry(name)
.or_default()
.push(self.context.columns.len() - 1);
index += 1;
});
let num_cols = index;
if num_cols < num_col_aliases {
return Err(ErrorCode::BindError(format!(
"table \"{table_name}\" has {num_cols} columns available but {num_col_aliases} column aliases specified",
))
.into());
}
match self.context.range_of.entry(table_name.clone()) {
Entry::Occupied(_) => Err(ErrorCode::InternalError(format!(
"Duplicated table name while binding table to context: {}",
table_name
))
.into()),
Entry::Vacant(entry) => {
entry.insert((begin, self.context.columns.len()));
Ok(())
}
}
}
pub fn bind_relation_by_name(
&mut self,
name: ObjectName,
alias: Option<TableAlias>,
as_of: Option<AsOf>,
) -> Result<Relation> {
let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;
if schema_name.is_none()
&& let Some(item) = self.context.cte_to_relation.get(&table_name)
{
if as_of.is_some() {
return Err(ErrorCode::BindError(
"Right table of a temporal join should not be a CTE. \
It should be a table, index, or materialized view"
.to_string(),
)
.into());
}
let BindingCte {
share_id,
state: cte_state,
alias: mut original_alias,
} = item.deref().borrow().clone();
debug_assert_eq!(original_alias.name.real_value(), table_name);
if let Some(from_alias) = alias {
original_alias.name = from_alias.name;
original_alias.columns = original_alias
.columns
.into_iter()
.zip_longest(from_alias.columns)
.map(EitherOrBoth::into_right)
.collect();
}
match cte_state {
BindingCteState::Init => {
Err(ErrorCode::BindError("Base term of recursive CTE not found, consider writing it to left side of the `UNION ALL` operator".to_string()).into())
}
BindingCteState::BaseResolved { base } => {
self.bind_table_to_context(
base.schema().fields.iter().map(|f| (false, f.clone())),
table_name.clone(),
Some(original_alias),
)?;
Ok(Relation::BackCteRef(Box::new(BoundBackCteRef { share_id, base })))
}
BindingCteState::Bound { query } => {
let input = BoundShareInput::Query(query);
self.bind_table_to_context(
input.fields()?,
table_name.clone(),
Some(original_alias),
)?;
Ok(Relation::Share(Box::new(BoundShare { share_id, input})))
}
BindingCteState::ChangeLog { table } => {
let input = BoundShareInput::ChangeLog(table);
self.bind_table_to_context(
input.fields()?,
table_name.clone(),
Some(original_alias),
)?;
Ok(Relation::Share(Box::new(BoundShare { share_id, input })))
},
}
} else {
self.bind_relation_by_name_inner(schema_name.as_deref(), &table_name, alias, as_of)
}
}
fn bind_relation_by_function_arg(
&mut self,
arg: Option<FunctionArg>,
err_msg: &str,
) -> Result<(Relation, ObjectName)> {
let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))) = arg else {
return Err(ErrorCode::BindError(err_msg.to_string()).into());
};
let table_name = match expr {
ParserExpr::Identifier(ident) => Ok::<_, RwError>(ObjectName(vec![ident])),
ParserExpr::CompoundIdentifier(idents) => Ok(ObjectName(idents)),
_ => Err(ErrorCode::BindError(err_msg.to_string()).into()),
}?;
Ok((
self.bind_relation_by_name(table_name.clone(), None, None)?,
table_name,
))
}
fn bind_column_by_function_args(
&mut self,
arg: Option<FunctionArg>,
err_msg: &str,
) -> Result<Box<InputRef>> {
if let Some(time_col_arg) = arg
&& let Some(ExprImpl::InputRef(time_col)) =
self.bind_function_arg(time_col_arg)?.into_iter().next()
{
Ok(time_col)
} else {
Err(ErrorCode::BindError(err_msg.to_string()).into())
}
}
fn bind_internal_table(
&mut self,
args: Vec<FunctionArg>,
alias: Option<TableAlias>,
) -> Result<Relation> {
if args.is_empty() || args.len() > 2 {
return Err(ErrorCode::BindError(
"usage: rw_table(table_id[,schema_name])".to_string(),
)
.into());
}
let table_id: TableId = args[0]
.to_string()
.parse::<u32>()
.map_err(|err| {
RwError::from(ErrorCode::BindError(format!(
"invalid table id: {}",
err.as_report()
)))
})?
.into();
let schema = args
.get(1)
.map_or(DEFAULT_SCHEMA_NAME.to_string(), |arg| arg.to_string());
let table_name = self.catalog.get_table_name_by_id(table_id)?;
self.bind_relation_by_name_inner(Some(&schema), &table_name, alias, None)
}
pub(super) fn bind_table_factor(&mut self, table_factor: TableFactor) -> Result<Relation> {
match table_factor {
TableFactor::Table { name, alias, as_of } => {
self.bind_relation_by_name(name, alias, as_of)
}
TableFactor::TableFunction {
name,
alias,
args,
with_ordinality,
} => {
self.try_mark_lateral_as_visible();
let result = self.bind_table_function(name, alias, args, with_ordinality);
self.try_mark_lateral_as_invisible();
result
}
TableFactor::Derived {
lateral,
subquery,
alias,
} => {
if lateral {
self.try_mark_lateral_as_visible();
let bound_subquery = self.bind_subquery_relation(*subquery, alias, true)?;
self.try_mark_lateral_as_invisible();
Ok(Relation::Subquery(Box::new(bound_subquery)))
} else {
self.push_lateral_context();
let bound_subquery = self.bind_subquery_relation(*subquery, alias, false)?;
self.pop_and_merge_lateral_context()?;
Ok(Relation::Subquery(Box::new(bound_subquery)))
}
}
TableFactor::NestedJoin(table_with_joins) => {
self.push_lateral_context();
let bound_join = self.bind_table_with_joins(*table_with_joins)?;
self.pop_and_merge_lateral_context()?;
Ok(bound_join)
}
}
}
}