use std::iter::Peekable;
use itertools::Itertools;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UserDefinedFunction;
use super::strict::Strict;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
use super::NonStrictExpression;
use crate::expr::{
BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::{bail, Result};
/// Builds a strict expression tree from its protobuf form.
///
/// Every node is wrapped only in `Checked` (the builder has no error reporter), and the
/// root is additionally wrapped in `Strict`.
pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
    ExprBuilder::new_strict()
        .build(prost)
        .map(|root| Strict::new(root).boxed())
}
pub fn build_non_strict_from_prost(
prost: &ExprNode,
error_report: impl EvalErrorReport + 'static,
) -> Result<NonStrictExpression> {
ExprBuilder::new_non_strict(error_report)
.build(prost)
.map(NonStrictExpression)
}
/// Recursive builder that turns an [`ExprNode`] tree into a [`BoxedExpression`] tree,
/// wrapping every built node (see `wrap`).
///
/// `error_report` is `None` for a strict builder (`new_strict`) and `Some` for a
/// non-strict one (`new_non_strict`). When present, each node is additionally wrapped
/// in `NonStrict` with a clone of the reporter.
struct ExprBuilder<R> {
    error_report: Option<R>,
}
impl ExprBuilder<!> {
fn new_strict() -> Self {
Self { error_report: None }
}
}
impl<R> ExprBuilder<R>
where
    R: EvalErrorReport + 'static,
{
    /// Creates a builder whose nodes are wrapped in `NonStrict` with clones of
    /// `error_report`.
    fn new_non_strict(error_report: R) -> Self {
        let error_report = Some(error_report);
        Self { error_report }
    }

    /// Wraps `expr` in `Checked`, and then in `NonStrict` if this builder carries an
    /// error reporter.
    fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression {
        let checked = Checked(expr);
        match &self.error_report {
            None => checked.boxed(),
            Some(report) => NonStrict::new(checked, report.clone()).boxed(),
        }
    }

    /// Builds `prost` recursively and applies this builder's wrappers to the result.
    fn build(&self, prost: &ExprNode) -> Result<BoxedExpression> {
        self.build_inner(prost).map(|inner| self.wrap(inner))
    }

    /// Dispatches on the node kind. Children are built through [`Self::build`], so they
    /// receive the same wrappers as the root.
    ///
    /// # Panics
    /// On a `Now` node, which must be resolved before reaching the backend.
    fn build_inner(&self, prost: &ExprNode) -> Result<BoxedExpression> {
        use PbType as E;
        let child_builder = |node: &ExprNode| self.build(node);
        match prost.get_rex_node()? {
            RexNode::InputRef(_) => InputRefExpression::build_boxed(prost, child_builder),
            RexNode::Constant(_) => LiteralExpression::build_boxed(prost, child_builder),
            RexNode::Udf(_) => UserDefinedFunction::build_boxed(prost, child_builder),
            // SOME / ALL have a dedicated implementation; every other call goes through
            // the function registry.
            RexNode::FuncCall(_) if matches!(prost.function_type(), E::All | E::Some) => {
                SomeAllExpression::build_boxed(prost, child_builder)
            }
            RexNode::FuncCall(_) => FuncCallBuilder::build_boxed(prost, child_builder),
            RexNode::Now(_) => unreachable!("now should not be built at backend"),
        }
    }
}
/// An expression type that can be constructed from its protobuf representation.
pub(crate) trait Build: Expression + Sized {
    /// Builds `Self` from `prost`. `build_child` is the callback used to build nested
    /// `ExprNode`s into boxed expressions.
    fn build(
        prost: &ExprNode,
        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
    ) -> Result<Self>;

    /// Test-only helper: builds `Self` with nested nodes built by [`build_from_prost`].
    #[cfg(test)]
    fn build_for_test(prost: &ExprNode) -> Result<Self> {
        Self::build(prost, build_from_prost)
    }
}
/// Like [`Build`], but produces a type-erased [`BoxedExpression`].
///
/// Every `Build + 'static` type gets this through the blanket impl in this module.
/// Builders such as `FuncCallBuilder` implement it directly because they produce a
/// `BoxedExpression` themselves (via `build_func`).
pub(crate) trait BuildBoxed: 'static {
    /// Builds a boxed expression from `prost`. `build_child` is the callback used to
    /// build nested `ExprNode`s.
    fn build_boxed(
        prost: &ExprNode,
        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
    ) -> Result<BoxedExpression>;
}
impl<E: Build + 'static> BuildBoxed for E {
    /// Builds the concrete expression with [`Build::build`] and boxes it.
    fn build_boxed(
        prost: &ExprNode,
        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
    ) -> Result<BoxedExpression> {
        let concrete = <Self as Build>::build(prost, build_child)?;
        Ok(concrete.boxed())
    }
}
/// Builds ordinary function calls by looking them up in `FUNCTION_REGISTRY` through
/// [`build_func`]. `SOME`/`ALL` calls are routed to `SomeAllExpression` instead by
/// `ExprBuilder::build_inner`.
struct FuncCallBuilder;
impl BuildBoxed for FuncCallBuilder {
    /// Builds every child with `build_child`, then resolves and builds the scalar
    /// function for `prost.function_type()` via [`build_func`].
    ///
    /// # Errors
    /// Returns an error if `return_type` or `rex_node` is unset, if a child fails to
    /// build, or if no matching function signature is registered.
    ///
    /// # Panics
    /// If `prost` is not a function call. The dispatcher only routes `FuncCall` nodes
    /// here, so this is an invariant violation.
    fn build_boxed(
        prost: &ExprNode,
        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
    ) -> Result<BoxedExpression> {
        let func_type = prost.function_type();
        // A missing return type is malformed input. Report it as an error, the same way
        // a missing `rex_node` is reported just below, instead of panicking.
        let ret_type = DataType::from(prost.get_return_type()?);
        let func_call = prost
            .get_rex_node()?
            .as_func_call()
            .expect("not a func call");
        let children = func_call
            .get_children()
            .iter()
            .map(build_child)
            .try_collect()?;
        build_func(func_type, ret_type, children)
    }
}
/// Looks up the scalar function `func` for the children's return types and `ret_type`
/// in `FUNCTION_REGISTRY`, then builds it over `children`.
///
/// # Errors
/// Propagates the registry lookup error and any error from building the function.
pub fn build_func(
    func: PbType,
    ret_type: DataType,
    children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
    let arg_types: Vec<_> = children.iter().map(|child| child.return_type()).collect();
    FUNCTION_REGISTRY
        .get(func, &arg_types, &ret_type)?
        .build_scalar(ret_type, children)
}
/// Like [`build_func`], but wraps the result in `Checked` and `NonStrict` (with
/// `error_report`) and returns it as a [`NonStrictExpression`].
pub fn build_func_non_strict(
    func: PbType,
    ret_type: DataType,
    children: Vec<BoxedExpression>,
    error_report: impl EvalErrorReport + 'static,
) -> Result<NonStrictExpression> {
    build_func(func, ret_type, children).map(|built| {
        let builder = ExprBuilder::new_non_strict(error_report);
        NonStrictExpression(builder.wrap(built))
    })
}
pub(super) fn get_children_and_return_type(prost: &ExprNode) -> Result<(&[ExprNode], DataType)> {
let ret_type = DataType::from(prost.get_return_type().unwrap());
if let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() {
Ok((func_call.get_children(), ret_type))
} else {
bail!("Expected RexNode::FuncCall");
}
}
/// Builds an expression from a compact textual form (mainly for tests):
///
/// - `$<index>:<type>` is an input reference;
/// - `(<function>:<type> <arg>...)` is a function call;
/// - `<literal>:<type>` is a constant (`null`/`NULL` is a null value).
///
/// In type names, `_` stands for a space.
///
/// # Panics
/// On any malformed input.
pub fn build_from_pretty(s: impl AsRef<str>) -> BoxedExpression {
    let mut parser = Parser::new(lexer(s.as_ref()).into_iter());
    parser.parse_expression()
}
/// Recursive-descent parser over the tokens produced by [`lexer`], used by
/// [`build_from_pretty`]. The token stream is `Peekable` so the parser can look one
/// token ahead, e.g. to detect the closing `)` of a function call.
struct Parser<Iter: Iterator> {
    tokens: Peekable<Iter>,
}
impl<Iter: Iterator<Item = Token>> Parser<Iter> {
    /// Creates a parser over `tokens`.
    fn new(tokens: Iter) -> Self {
        Self {
            tokens: tokens.peekable(),
        }
    }

    /// Parses and builds one expression.
    ///
    /// Accepted forms:
    /// - `$<index>:<type>` is an input reference;
    /// - `(<function>:<type> <arg>...)` is a function call built via [`build_func`];
    /// - `<literal>:<type>` is a constant, where `null`/`NULL` denotes a null value.
    ///
    /// # Panics
    /// On malformed input (missing `:`, unexpected token, premature end of input), on
    /// an unknown type or function name, on an unparsable literal, or if the function
    /// cannot be built for the given argument types.
    fn parse_expression(&mut self) -> BoxedExpression {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Index(index) => {
                assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
                let ty = self.parse_type();
                InputRefExpression::new(ty, index).boxed()
            }
            Token::LParen => {
                let func = self.parse_function();
                assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
                let ty = self.parse_type();
                let mut children = Vec::new();
                // Arguments run until the matching `)`. If input ends first, the
                // recursive call panics with "Unexpected end of input".
                while self.tokens.peek() != Some(&Token::RParen) {
                    children.push(self.parse_expression());
                }
                // Consume the `)` that terminated the loop above.
                self.tokens.next();
                build_func(func, ty, children).expect("Failed to build")
            }
            Token::Literal(value) => {
                assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
                let ty = self.parse_type();
                let value = match value.as_str() {
                    "null" | "NULL" => None,
                    _ => Some(ScalarImpl::from_text(&value, &ty).expect_str("value", &value)),
                };
                LiteralExpression::new(ty, value).boxed()
            }
            // Report the offending token, consistent with `parse_type`/`parse_function`.
            t => panic!("Unexpected token {t:?}"),
        }
    }

    /// Parses a type name. Underscores are replaced by spaces before parsing, so
    /// multi-word type names can be written as a single token.
    ///
    /// # Panics
    /// If the next token is not a literal or the name is not a valid `DataType`.
    fn parse_type(&mut self) -> DataType {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Literal(name) => name
                .replace('_', " ")
                .parse::<DataType>()
                .expect_str("type", &name),
            t => panic!("Expected a Literal, got {t:?}"),
        }
    }

    /// Parses a function name. The name is uppercased and then looked up as a `PbType`
    /// variant name, so matching is case-insensitive.
    ///
    /// # Panics
    /// If the next token is not a literal or the name is not a known function.
    fn parse_function(&mut self) -> PbType {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Literal(name) => {
                PbType::from_str_name(&name.to_uppercase()).expect_str("function", &name)
            }
            t => panic!("Expected a Literal, got {t:?}"),
        }
    }
}
/// A token of the compact expression syntax accepted by `build_from_pretty`.
#[derive(Debug, PartialEq, Clone)]
pub(crate) enum Token {
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `:`
    Colon,
    /// `$` followed by decimal digits: an input column index.
    Index(usize),
    /// Any other run of characters, ending before `(`, `)`, `:`, space, tab, CR or LF.
    Literal(String),
}

/// Splits `input` into [`Token`]s.
///
/// Space, tab, CR and LF separate tokens and are discarded.
///
/// # Panics
/// With "Invalid number" if a `$` is not followed by digits forming a valid `usize`.
pub(crate) fn lexer(input: &str) -> Vec<Token> {
    let mut chars = input.chars().peekable();
    let mut out = Vec::new();
    while let Some(ch) = chars.next() {
        match ch {
            ' ' | '\t' | '\r' | '\n' => {}
            '(' => out.push(Token::LParen),
            ')' => out.push(Token::RParen),
            ':' => out.push(Token::Colon),
            '$' => {
                let digits: String =
                    std::iter::from_fn(|| chars.next_if(char::is_ascii_digit)).collect();
                let index = digits.parse::<usize>().expect("Invalid number");
                out.push(Token::Index(index));
            }
            first => {
                let mut text = String::from(first);
                while let Some(next) =
                    chars.next_if(|&c| !matches!(c, '(' | ')' | ':' | ' ' | '\t' | '\r' | '\n'))
                {
                    text.push(next);
                }
                out.push(Token::Literal(text));
            }
        }
    }
    out
}
/// Unwraps an `Option`/`Result`, panicking with `expect {what} in {s:?}` on
/// `None`/`Err`. The impls are `#[track_caller]`, so the panic points at the caller.
pub(crate) trait ExpectExt<T> {
    fn expect_str(self, what: &str, s: &str) -> T;
}

impl<T> ExpectExt<T> for Option<T> {
    #[track_caller]
    fn expect_str(self, what: &str, s: &str) -> T {
        // No closure here: a closure would not inherit `#[track_caller]`, and the
        // reported panic location would move into this file.
        let Some(inner) = self else {
            panic!("expect {what} in {s:?}");
        };
        inner
    }
}

impl<T, E> ExpectExt<T> for std::result::Result<T, E> {
    #[track_caller]
    fn expect_str(self, what: &str, s: &str) -> T {
        match self {
            Err(_) => panic!("expect {what} in {s:?}"),
            Ok(inner) => inner,
        }
    }
}