// risingwave_frontend/expr/utils.rs
use std::collections::VecDeque;

use fixedbitset::FixedBitSet;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_pb::expr::expr_node::Type;

use super::now::RewriteNowToProcTime;
use super::{Expr, ExprImpl, ExprRewriter, ExprVisitor, FunctionCall, InputRef};
use crate::expr::ExprType;
/// Flattens the nested `op` calls rooted at `expr`, appending every operand that is not
/// itself an `op` call to `rets` in left-to-right order.
///
/// Uses an explicit work stack instead of recursion; children are pushed in reverse so the
/// leftmost one is processed first.
fn split_expr_by(expr: ExprImpl, op: ExprType, rets: &mut Vec<ExprImpl>) {
    let mut pending = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            ExprImpl::FunctionCall(call) if call.func_type() == op => {
                let (_, children, _) = call.decompose();
                pending.extend(children.into_iter().rev());
            }
            leaf => rets.push(leaf),
        }
    }
}
/// Combines `exprs` with the binary operator `op` into a single expression, returning
/// `identity_elem` when `exprs` is empty.
///
/// The operands are combined through a FIFO queue of `(level, expr)` pairs, where `level`
/// is the height of the sub-tree built so far (0 for an original input). The front entry is
/// merged with its neighbour only when it is at least as tall as that neighbour; otherwise
/// it is rotated to the back. This produces a roughly balanced tree of `op` calls.
///
/// # Panics
/// Panics if `FunctionCall::new(op, vec![lhs, rhs])` returns an error for some pair of
/// operands.
pub(super) fn merge_expr_by_logical<I>(exprs: I, op: ExprType, identity_elem: ExprImpl) -> ExprImpl
where
    I: IntoIterator<Item = ExprImpl>,
{
    // Every input starts as a leaf of height 0.
    let mut exprs: VecDeque<_> = exprs.into_iter().map(|e| (0usize, e)).collect();
    while exprs.len() > 1 {
        let (level, lhs) = exprs.pop_front().unwrap();
        // Cannot fail: the loop condition guarantees at least one more entry.
        let rhs_level = exprs.front().unwrap().0;
        if level < rhs_level {
            // `lhs` is shorter than its neighbour: defer it so sub-trees of similar height
            // are paired first.
            exprs.push_back((level, lhs));
        } else {
            let rhs = exprs.pop_front().unwrap().1;
            let new_expr = FunctionCall::new(op, vec![lhs, rhs]).unwrap().into();
            // lhs is at least as tall as rhs, so the merged height is `level + 1`.
            exprs.push_back((level + 1, new_expr));
        }
    }
    exprs.pop_front().map(|(_, e)| e).unwrap_or(identity_elem)
}
/// Splits `expr` into its `AND` operands (conjuncts), flattening nested `AND` calls.
///
/// An expression that is not an `AND` call yields a single-element vector containing it.
pub fn to_conjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
    let mut conjuncts = Vec::new();
    split_expr_by(expr, ExprType::And, &mut conjuncts);
    conjuncts
}
/// Splits `expr` into its `OR` operands (disjuncts), flattening nested `OR` calls.
///
/// An expression that is not an `OR` call yields a single-element vector containing it.
pub fn to_disjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
    let mut disjuncts = Vec::new();
    split_expr_by(expr, ExprType::Or, &mut disjuncts);
    disjuncts
}
/// Folds boolean constants throughout `expr` using the `BooleanConstantFolding` rewriter
/// (e.g. `x AND true` becomes `x`, `NOT false` becomes `true`).
pub fn fold_boolean_constant(expr: ExprImpl) -> ExprImpl {
    BooleanConstantFolding {}.rewrite_expr(expr)
}
/// Replaces a self-comparison `e = e` with an equivalent null check when possible.
///
/// Delegates to `ColumnSelfEqualRewriter::rewrite`; expressions it cannot simplify are
/// returned unchanged.
pub fn column_self_eq_eliminate(expr: ExprImpl) -> ExprImpl {
    ColumnSelfEqualRewriter::rewrite(expr)
}
/// Simplifies self-comparisons of the form `e = e`.
///
/// If both sides of a binary `=` call are structurally equal and the expression references
/// at most one column, the comparison becomes `column IS NOT NULL`, or `true` when no column
/// is referenced. Any other expression is returned unchanged.
///
/// NOTE(review): this assumes the compared expression can only be NULL when its single
/// referenced column is NULL. A NULL literal or a non-strict function (e.g. `coalesce`)
/// inside `e` would break that assumption — confirm callers never hit such shapes.
pub struct ColumnSelfEqualRewriter {}

impl ColumnSelfEqualRewriter {
    /// Collects every distinct `InputRef` in `expr` into `columns`, in first-seen order.
    ///
    /// Sub-trees rooted at a predicate that never yields NULL (see `is_not_null`) are
    /// skipped. Literals and other leaf kinds contribute nothing.
    fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
        match &expr {
            // Skip the whole sub-tree: its result can never be NULL.
            ExprImpl::FunctionCall(call) if Self::is_not_null(call.func_type()) => {}
            ExprImpl::FunctionCall(call) => {
                for child in call.inputs() {
                    Self::extract_column(child.clone(), columns);
                }
            }
            ExprImpl::InputRef(_) if !columns.contains(&expr) => columns.push(expr.clone()),
            _ => {}
        }
    }

    /// Returns `true` for the `IS [NOT] NULL/TRUE/FALSE` predicates, which always return a
    /// non-NULL boolean.
    fn is_not_null(func_type: ExprType) -> bool {
        matches!(
            func_type,
            ExprType::IsNull
                | ExprType::IsNotNull
                | ExprType::IsTrue
                | ExprType::IsFalse
                | ExprType::IsNotTrue
                | ExprType::IsNotFalse
        )
    }

    /// Rewrites `expr` if it is a self-comparison over at most one column; otherwise
    /// returns it unchanged.
    ///
    /// # Panics
    /// Panics if a binary `=` call does not have a boolean return type.
    pub fn rewrite(expr: ExprImpl) -> ExprImpl {
        let mut columns = vec![];
        Self::extract_column(expr.clone(), &mut columns);
        // With several columns, `e = e` cannot be reduced to a single null check.
        if columns.len() > 1 {
            return expr;
        }

        let call = match expr.clone() {
            ExprImpl::FunctionCall(call) => call,
            _ => return expr,
        };
        let is_binary_equal = call.func_type() == ExprType::Equal && call.inputs().len() == 2;
        if !is_binary_equal {
            return expr;
        }
        assert_eq!(call.return_type(), DataType::Boolean);

        let inputs = call.inputs();
        if inputs[0] != inputs[1] {
            return expr;
        }
        match columns.first() {
            None => ExprImpl::literal_bool(true),
            Some(column) => match FunctionCall::new(ExprType::IsNotNull, vec![column.clone()]) {
                Ok(not_null) => not_null.into(),
                Err(_) => expr,
            },
        }
    }
}
struct BooleanConstantFolding {}
impl ExprRewriter for BooleanConstantFolding {
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
let (func_type, inputs, ret) = func_call.decompose();
let inputs: Vec<_> = inputs
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
let bool_constant_values: Vec<Option<bool>> =
inputs.iter().map(try_get_bool_constant).collect();
let contains_bool_constant = bool_constant_values.iter().any(|x| x.is_some());
let prepare_binary_function_inputs = |mut inputs: Vec<ExprImpl>| -> (ExprImpl, ExprImpl) {
assert_eq!(inputs.len(), 2);
let rhs = inputs.pop().unwrap();
let lhs = inputs.pop().unwrap();
if bool_constant_values[0].is_some() {
(lhs, rhs)
} else {
(rhs, lhs)
}
};
match func_type {
Type::Not => {
let input = inputs.first().unwrap();
if let Some(v) = try_get_bool_constant(input) {
return ExprImpl::literal_bool(!v);
}
}
Type::IsFalse => {
let input = inputs.first().unwrap();
if input.is_null() {
return ExprImpl::literal_bool(false);
}
if let Some(v) = try_get_bool_constant(input) {
return ExprImpl::literal_bool(!v);
}
}
Type::IsTrue => {
let input = inputs.first().unwrap();
if input.is_null() {
return ExprImpl::literal_bool(false);
}
if let Some(v) = try_get_bool_constant(input) {
return ExprImpl::literal_bool(v);
}
}
Type::IsNull => {
let input = inputs.first().unwrap();
if input.is_null() {
return ExprImpl::literal_bool(true);
}
}
Type::IsNotTrue => {
let input = inputs.first().unwrap();
if input.is_null() {
return ExprImpl::literal_bool(true);
}
if let Some(v) = try_get_bool_constant(input) {
return ExprImpl::literal_bool(!v);
}
}
Type::IsNotFalse => {
let input = inputs.first().unwrap();
if input.is_null() {
return ExprImpl::literal_bool(true);
}
if let Some(v) = try_get_bool_constant(input) {
return ExprImpl::literal_bool(v);
}
}
Type::IsNotNull => {
let input = inputs.first().unwrap();
if let ExprImpl::Literal(lit) = input {
return ExprImpl::literal_bool(lit.get_data().is_some());
}
}
Type::And if contains_bool_constant => {
let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
return boolean_constant_fold_and(constant_lhs, rhs);
}
Type::Or if contains_bool_constant => {
let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
return boolean_constant_fold_or(constant_lhs, rhs);
}
_ => {}
}
FunctionCall::new_unchecked(func_type, inputs, ret).into()
}
}
/// Returns the value of `expr` if it is a non-NULL boolean literal.
///
/// Returns `None` for every other expression, including a NULL literal and literals of
/// non-boolean type.
pub fn try_get_bool_constant(expr: &ExprImpl) -> Option<bool> {
    match expr {
        ExprImpl::Literal(literal) => match literal.get_data() {
            Some(ScalarImpl::Bool(value)) => Some(*value),
            _ => None,
        },
        _ => None,
    }
}
/// Simplifies `constant_lhs AND rhs`, where `constant_lhs` is a non-NULL boolean literal.
///
/// `true AND rhs` reduces to `rhs`. Under SQL three-valued logic, `false AND rhs` is `false`
/// for any `rhs`, so the constant itself is returned.
///
/// # Panics
/// Panics if `constant_lhs` is not a non-NULL boolean literal.
fn boolean_constant_fold_and(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
    let lhs_is_true = try_get_bool_constant(&constant_lhs).unwrap();
    if lhs_is_true {
        rhs
    } else {
        constant_lhs
    }
}
/// Simplifies `constant_lhs OR rhs`, where `constant_lhs` is a non-NULL boolean literal.
///
/// Under SQL three-valued logic, `true OR rhs` is `true` for any `rhs`, so the constant is
/// returned. `false OR rhs` reduces to `rhs`.
///
/// # Panics
/// Panics if `constant_lhs` is not a non-NULL boolean literal.
fn boolean_constant_fold_or(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
    let lhs_is_true = try_get_bool_constant(&constant_lhs).unwrap();
    if lhs_is_true {
        constant_lhs
    } else {
        rhs
    }
}
/// Pushes every `NOT` in `expr` as far towards the leaves as the `NotPushDown` rules allow
/// (double negation elimination and De Morgan's laws).
pub fn push_down_not(expr: ExprImpl) -> ExprImpl {
    NotPushDown {}.rewrite_expr(expr)
}
/// Rewriter that pushes `NOT` down the expression tree.
///
/// `NOT NOT a` becomes `a`. By De Morgan's laws, `NOT (a AND b)` becomes `NOT a OR NOT b`,
/// and `NOT (a OR b)` becomes `NOT a AND NOT b`. A `NOT` that cannot be pushed further stays
/// in place, and its operand is rewritten recursively.
struct NotPushDown {}

impl ExprRewriter for NotPushDown {
    /// Rewrites one function call, recursing into the result.
    ///
    /// # Panics
    /// Panics if a `NOT` call does not have exactly one input, if a negated `AND`/`OR`
    /// does not have exactly two inputs, or if building a `NOT`/`AND`/`OR` call fails.
    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
        /// Wraps `expr` in a unary `NOT` call.
        fn negate(expr: ExprImpl) -> ExprImpl {
            FunctionCall::new(Type::Not, vec![expr]).unwrap().into()
        }

        let (func_type, mut inputs, ret) = func_call.decompose();

        // Not a negation: rebuild the call around rewritten children.
        if func_type != Type::Not {
            let children = inputs
                .into_iter()
                .map(|child| self.rewrite_expr(child))
                .collect();
            return FunctionCall::new_unchecked(func_type, children, ret).into();
        }

        assert_eq!(inputs.len(), 1);
        let operand = inputs.pop().unwrap();
        // `Ok` holds the expression after eliminating / pushing this `NOT` one level down.
        // `Err` hands back the operand untouched when no rule applies.
        let pushed: Result<ExprImpl, ExprImpl> = match operand {
            ExprImpl::FunctionCall(inner) => {
                let (inner_type, mut inner_inputs, inner_ret) = inner.decompose();
                match inner_type {
                    Type::Not => {
                        assert_eq!(inner_inputs.len(), 1);
                        Ok(inner_inputs.pop().unwrap())
                    }
                    Type::And | Type::Or => {
                        assert_eq!(inner_inputs.len(), 2);
                        let rhs = inner_inputs.pop().unwrap();
                        let lhs = inner_inputs.pop().unwrap();
                        let dual = if inner_type == Type::And {
                            Type::Or
                        } else {
                            Type::And
                        };
                        let rhs_not = negate(rhs);
                        let lhs_not = negate(lhs);
                        Ok(FunctionCall::new(dual, vec![lhs_not, rhs_not])
                            .unwrap()
                            .into())
                    }
                    _ => Err(FunctionCall::new_unchecked(inner_type, inner_inputs, inner_ret).into()),
                }
            }
            other => Err(other),
        };

        match pushed {
            // The new expression may expose further `NOT`s; keep pushing.
            Ok(rewritten) => self.rewrite_expr(rewritten),
            // Cannot push further: keep this `NOT` and rewrite beneath it.
            Err(operand) => negate(self.rewrite_expr(operand)),
        }
    }
}
pub fn factorization_expr(expr: ExprImpl) -> Vec<ExprImpl> {
let disjunctions: Vec<ExprImpl> = to_disjunctions(expr);
if disjunctions.len() == 1 {
return disjunctions;
}
let mut disjunctions: Vec<Vec<_>> = disjunctions
.into_iter()
.map(|x| to_conjunctions(x).into_iter().collect())
.collect();
let (last, remaining) = disjunctions.split_last_mut().unwrap();
let greatest_common_divider: Vec<_> = last
.extract_if(|factor| remaining.iter().all(|expr| expr.contains(factor)))
.collect();
for disjunction in remaining {
disjunction.retain(|factor| !greatest_common_divider.contains(factor));
}
let remaining = ExprImpl::or(disjunctions.into_iter().map(ExprImpl::and));
greatest_common_divider
.into_iter()
.chain(std::iter::once(remaining))
.map(fold_boolean_constant)
.collect()
}
/// Runs `$expr.collect_input_refs($input_col_num)` and discards the result.
///
/// NOTE(review): the check presumably relies on `collect_input_refs` panicking when an
/// `InputRef` index is `>= $input_col_num` (as `FixedBitSet::insert` does for an
/// out-of-range bit) — confirm against that method's implementation.
macro_rules! assert_input_ref {
    ($expr:expr, $input_col_num:expr) => {{
        let _ = $expr.collect_input_refs($input_col_num);
    }};
}
pub(crate) use assert_input_ref;
/// Visitor that records the index of every `InputRef` it visits in a `FixedBitSet`.
///
/// Visiting an `InputRef` whose index is not below the bit set's length panics, because
/// `FixedBitSet::insert` panics on an out-of-bounds bit.
#[derive(Clone)]
pub struct CollectInputRef {
    input_bits: FixedBitSet,
}

impl ExprVisitor for CollectInputRef {
    fn visit_input_ref(&mut self, input_ref: &InputRef) {
        self.input_bits.insert(input_ref.index());
    }
}

impl CollectInputRef {
    /// Creates a collector that starts from the bits already set in `initial_input_bits`.
    pub fn new(initial_input_bits: FixedBitSet) -> Self {
        Self {
            input_bits: initial_input_bits,
        }
    }

    /// Creates a collector with an empty bit set able to hold indices `0..capacity`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::new(FixedBitSet::with_capacity(capacity))
    }
}

impl From<CollectInputRef> for FixedBitSet {
    fn from(collector: CollectInputRef) -> Self {
        collector.input_bits
    }
}

impl Extend<usize> for CollectInputRef {
    /// Sets the given indices, delegating to `FixedBitSet`'s own `Extend` impl.
    fn extend<I: IntoIterator<Item = usize>>(&mut self, indices: I) {
        self.input_bits.extend(indices);
    }
}
/// Returns the set of `InputRef` indices referenced by any of `exprs`.
///
/// The returned bit set has length `input_col_num`.
///
/// # Panics
/// Panics if some `InputRef` index is `>= input_col_num` (via `CollectInputRef`).
pub fn collect_input_refs<'a>(
    input_col_num: usize,
    exprs: impl IntoIterator<Item = &'a ExprImpl>,
) -> FixedBitSet {
    let collector = exprs.into_iter().fold(
        CollectInputRef::with_capacity(input_col_num),
        |mut collector, expr| {
            collector.visit_expr(expr);
            collector
        },
    );
    collector.into()
}
/// Visitor that counts the `Now` expression nodes it visits.
#[derive(Clone, Default)]
pub struct CountNow {
    count: usize,
}

impl CountNow {
    /// Number of `Now` nodes visited so far.
    pub fn count(&self) -> usize {
        self.count
    }
}

impl ExprVisitor for CountNow {
    fn visit_now(&mut self, _now: &super::Now) {
        self.count += 1;
    }
}
/// Applies the `RewriteNowToProcTime` rewriter to `expr` and returns the result.
pub fn rewrite_now_to_proctime(expr: ExprImpl) -> ExprImpl {
    RewriteNowToProcTime.rewrite_expr(expr)
}
#[cfg(test)]
mod tests {
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;

    use super::{fold_boolean_constant, push_down_not};
    use crate::expr::{ExprImpl, FunctionCall, InputRef};

    /// Boolean column reference `$idx`.
    fn bool_col(idx: usize) -> ExprImpl {
        InputRef::new(idx, DataType::Boolean).into()
    }

    /// Type-checked function call; panics if it does not type-check.
    fn call(func_type: Type, inputs: Vec<ExprImpl>) -> ExprImpl {
        FunctionCall::new(func_type, inputs).unwrap().into()
    }

    /// Asserts that `expr` is the column reference `$idx`.
    fn assert_col(expr: &ExprImpl, idx: usize) {
        assert!(expr.as_input_ref().is_some());
        assert_eq!(expr.as_input_ref().unwrap().index(), idx);
    }

    /// Asserts that `expr` is the boolean literal `expected`.
    fn assert_bool_literal(expr: &ExprImpl, expected: bool) {
        assert!(expr.as_literal().is_some());
        assert_eq!(
            *expr.as_literal().unwrap().get_data(),
            Some(ScalarImpl::Bool(expected))
        );
    }

    #[test]
    fn constant_boolean_folding_basic_and() {
        // `$0 AND true` => `$0`
        let folded = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        assert_col(&folded, 0);

        // `$0 AND false` => `false`
        let folded = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        assert_bool_literal(&folded, false);
    }

    #[test]
    fn constant_boolean_folding_basic_or() {
        // `$0 OR true` => `true`
        let folded = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        assert_bool_literal(&folded, true);

        // `$0 OR false` => `$0`
        let folded = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        assert_col(&folded, 0);
    }

    #[test]
    fn constant_boolean_folding_complex() {
        // `(false AND true) AND (true OR 1 = 2)` => `false`
        let lhs = call(
            Type::And,
            vec![ExprImpl::literal_bool(false), ExprImpl::literal_bool(true)],
        );
        let one_eq_two = call(
            Type::Equal,
            vec![ExprImpl::literal_int(1), ExprImpl::literal_int(2)],
        );
        let rhs = call(Type::Or, vec![ExprImpl::literal_bool(true), one_eq_two]);
        let folded = fold_boolean_constant(call(Type::And, vec![lhs, rhs]));
        assert_bool_literal(&folded, false);
    }

    #[test]
    fn not_push_down_test() {
        // `NOT NOT $0` => `$0`
        let res = push_down_not(call(Type::Not, vec![call(Type::Not, vec![bool_col(0)])]));
        assert!(res.as_input_ref().is_some());

        // `NOT ($0 AND NOT $1)` => `NOT $0 OR $1`
        let res = push_down_not(call(
            Type::Not,
            vec![call(
                Type::And,
                vec![bool_col(0), call(Type::Not, vec![bool_col(1)])],
            )],
        ));
        assert!(res.as_function_call().is_some());
        let (func, lhs, rhs) = res.as_function_call().unwrap().clone().decompose_as_binary();
        assert_eq!(func, Type::Or);
        assert!(rhs.as_input_ref().is_some());
        assert!(lhs.as_function_call().is_some());
        let (func, input) = lhs.as_function_call().unwrap().clone().decompose_as_unary();
        assert_eq!(func, Type::Not);
        assert!(input.as_input_ref().is_some());

        // `NOT ($0 OR $1)` => `NOT $0 AND NOT $1`
        let res = push_down_not(call(
            Type::Not,
            vec![call(Type::Or, vec![bool_col(0), bool_col(1)])],
        ));
        assert!(res.as_function_call().is_some());
        let (func_type, lhs, rhs) = res.as_function_call().unwrap().clone().decompose_as_binary();
        assert_eq!(func_type, Type::And);
        for operand in [lhs, rhs] {
            let (operand_type, operand_input) = operand
                .as_function_call()
                .unwrap()
                .clone()
                .decompose_as_unary();
            assert_eq!(operand_type, Type::Not);
            assert!(operand_input.as_input_ref().is_some());
        }
    }
}