use std::collections::{BTreeMap, HashSet};
use std::fmt::{self, Debug};
use std::ops::Bound;
use std::rc::Rc;
use std::sync::LazyLock;
use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
use risingwave_common::util::scan_range::{is_full_range, ScanRange};
use crate::error::Result;
use crate::expr::{
collect_input_refs, column_self_eq_eliminate, factorization_expr, fold_boolean_constant,
push_down_not, to_conjunctions, try_get_bool_constant, ExprDisplay, ExprImpl, ExprMutator,
ExprRewriter, ExprType, ExprVisitor, FunctionCall, InequalityInputPair, InputRef,
};
use crate::utils::condition::cast_compare::{ResultForCmp, ResultForEq};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
pub conjunctions: Vec<ExprImpl>,
}
impl IntoIterator for Condition {
type IntoIter = std::vec::IntoIter<ExprImpl>;
type Item = ExprImpl;
fn into_iter(self) -> Self::IntoIter {
self.conjunctions.into_iter()
}
}
impl fmt::Display for Condition {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut conjunctions = self.conjunctions.iter();
if let Some(expr) = conjunctions.next() {
write!(f, "{:?}", expr)?;
}
if self.always_true() {
write!(f, "true")?;
} else {
for expr in conjunctions {
write!(f, " AND {:?}", expr)?;
}
}
Ok(())
}
}
impl Condition {
pub fn with_expr(expr: ExprImpl) -> Self {
let conjunctions = to_conjunctions(expr);
Self { conjunctions }.simplify()
}
pub fn true_cond() -> Self {
Self {
conjunctions: vec![],
}
}
pub fn false_cond() -> Self {
Self {
conjunctions: vec![ExprImpl::literal_bool(false)],
}
}
pub fn always_true(&self) -> bool {
self.conjunctions.is_empty()
}
pub fn always_false(&self) -> bool {
static FALSE: LazyLock<ExprImpl> = LazyLock::new(|| ExprImpl::literal_bool(false));
!self.conjunctions.is_empty() && self.conjunctions.iter().any(|e| *e == *FALSE)
}
pub fn as_expr_unless_true(&self) -> Option<ExprImpl> {
if self.always_true() {
None
} else {
Some(self.clone().into())
}
}
#[must_use]
pub fn and(self, other: Self) -> Self {
let mut ret = self;
ret.conjunctions.extend(other.conjunctions);
ret.simplify()
}
#[must_use]
pub fn or(self, other: Self) -> Self {
let or_expr = ExprImpl::FunctionCall(
FunctionCall::new_unchecked(
ExprType::Or,
vec![self.into(), other.into()],
DataType::Boolean,
)
.into(),
);
let ret = Self::with_expr(or_expr);
ret.simplify()
}
#[must_use]
pub fn split(self, left_col_num: usize, right_col_num: usize) -> (Self, Self, Self) {
let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
self.group_by::<_, 3>(|expr| {
let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
if input_bits.is_subset(&left_bit_map) {
0
} else if input_bits.is_subset(&right_bit_map) {
1
} else {
2
}
})
.into_iter()
.next_tuple()
.unwrap()
}
pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
collect_input_refs(input_col_num, &self.conjunctions)
}
#[must_use]
pub fn split_by_input_col_nums(
self,
input_col_nums: &[usize],
only_eq: bool,
) -> (BTreeMap<(usize, usize), Self>, Self) {
let mut bitmaps = Vec::with_capacity(input_col_nums.len());
let mut cols_seen = 0;
for cols in input_col_nums {
bitmaps.push(FixedBitSet::from_iter(cols_seen..cols_seen + cols));
cols_seen += cols;
}
let mut pairwise_conditions = BTreeMap::new();
let mut non_eq_join = vec![];
for expr in self.conjunctions {
let input_bits = expr.collect_input_refs(cols_seen);
let mut subset_indices = Vec::with_capacity(input_col_nums.len());
for (idx, bitmap) in bitmaps.iter().enumerate() {
if !input_bits.is_disjoint(bitmap) {
subset_indices.push(idx);
}
}
if subset_indices.len() != 2 || (only_eq && expr.as_eq_cond().is_none()) {
non_eq_join.push(expr);
} else {
let key = if subset_indices[0] < subset_indices[1] {
(subset_indices[0], subset_indices[1])
} else {
(subset_indices[1], subset_indices[0])
};
let e = pairwise_conditions
.entry(key)
.or_insert_with(Condition::true_cond);
e.conjunctions.push(expr);
}
}
(
pairwise_conditions,
Condition {
conjunctions: non_eq_join,
},
)
}
#[must_use]
pub fn split_eq_keys(
self,
left_col_num: usize,
right_col_num: usize,
) -> (Vec<(InputRef, InputRef, bool)>, Self) {
let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
let (mut eq_keys, mut others) = (vec![], vec![]);
self.conjunctions.into_iter().for_each(|expr| {
let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
others.push(expr)
} else if let Some(columns) = expr.as_eq_cond() {
eq_keys.push((columns.0, columns.1, false));
} else if let Some(columns) = expr.as_is_not_distinct_from_cond() {
eq_keys.push((columns.0, columns.1, true));
} else {
others.push(expr)
}
});
(
eq_keys,
Condition {
conjunctions: others,
},
)
}
pub(crate) fn extract_inequality_keys(
&self,
left_col_num: usize,
right_col_num: usize,
) -> Vec<(usize, InequalityInputPair)> {
let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
self.conjunctions
.iter()
.enumerate()
.filter_map(|(conjunction_idx, expr)| {
let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
None
} else {
expr.as_input_comparison_cond()
.map(|inequality_pair| (conjunction_idx, inequality_pair))
}
})
.collect_vec()
}
#[must_use]
pub fn split_disjoint(self, columns: &FixedBitSet) -> (Self, Self) {
self.group_by::<_, 2>(|expr| {
let input_bits = expr.collect_input_refs(columns.len());
input_bits.is_disjoint(columns) as usize
})
.into_iter()
.next_tuple()
.unwrap()
}
fn disjunctions_to_scan_ranges(
table_desc: Rc<TableDesc>,
max_split_range_gap: u64,
disjunctions: Vec<ExprImpl>,
) -> Result<Option<(Vec<ScanRange>, Self)>> {
let disjunctions_result: Result<Vec<(Vec<ScanRange>, Self)>> = disjunctions
.into_iter()
.map(|x| {
Condition {
conjunctions: to_conjunctions(x),
}
.split_to_scan_ranges(table_desc.clone(), max_split_range_gap)
})
.collect();
let disjunctions_result = disjunctions_result?;
let all_equal = disjunctions_result
.iter()
.all(|(scan_ranges, other_condition)| {
other_condition.always_true()
&& scan_ranges
.iter()
.all(|x| !x.eq_conds.is_empty() && is_full_range(&x.range))
});
if all_equal {
let scan_ranges = disjunctions_result
.into_iter()
.flat_map(|(scan_ranges, _)| scan_ranges)
.sorted_by(|a, b| a.eq_conds.len().cmp(&b.eq_conds.len()))
.collect_vec();
let mut non_overlap_scan_ranges: Vec<ScanRange> = vec![];
for s1 in &scan_ranges {
let overlap = non_overlap_scan_ranges.iter().any(|s2| {
#[allow(clippy::disallowed_methods)]
s1.eq_conds
.iter()
.zip(s2.eq_conds.iter())
.all(|(a, b)| a == b)
});
if !overlap {
non_overlap_scan_ranges.push(s1.clone());
}
}
Ok(Some((non_overlap_scan_ranges, Condition::true_cond())))
} else {
Ok(None)
}
}
pub fn split_to_scan_ranges(
self,
table_desc: Rc<TableDesc>,
max_split_range_gap: u64,
) -> Result<(Vec<ScanRange>, Self)> {
fn false_cond() -> (Vec<ScanRange>, Condition) {
(vec![], Condition::false_cond())
}
if self.conjunctions.len() == 1 {
if let Some(disjunctions) = self.conjunctions[0].as_or_disjunctions() {
if let Some((scan_ranges, other_condition)) = Self::disjunctions_to_scan_ranges(
table_desc,
max_split_range_gap,
disjunctions,
)? {
return Ok((scan_ranges, other_condition));
} else {
return Ok((vec![], self));
}
}
}
let mut groups = Self::classify_conjunctions_by_pk(self.conjunctions, &table_desc);
let mut scan_range = ScanRange::full_table_scan();
let mut other_conds = groups.pop().unwrap();
for i in 0..table_desc.order_column_indices().len() {
let group = std::mem::take(&mut groups[i]);
if group.is_empty() {
groups.push(other_conds);
return Ok((
if scan_range.is_full_table_scan() {
vec![]
} else {
vec![scan_range]
},
Self {
conjunctions: groups[i + 1..].concat(),
},
));
}
let Some((
lower_bound_conjunctions,
upper_bound_conjunctions,
eq_conds,
part_of_other_conds,
)) = Self::analyze_group(group)?
else {
return Ok(false_cond());
};
other_conds.extend(part_of_other_conds.into_iter());
let lower_bound = Self::merge_lower_bound_conjunctions(lower_bound_conjunctions);
let upper_bound = Self::merge_upper_bound_conjunctions(upper_bound_conjunctions);
if Self::is_invalid_range(&lower_bound, &upper_bound) {
return Ok(false_cond());
}
match eq_conds.len() {
1 => {
let eq_conds =
Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
if eq_conds.is_empty() {
return Ok(false_cond());
}
scan_range.eq_conds.extend(eq_conds.into_iter());
}
0 => {
scan_range.range = (lower_bound, upper_bound);
other_conds.extend(groups[i + 1..].iter().flatten().cloned());
break;
}
_ => {
let eq_conds =
Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
if eq_conds.is_empty() {
return Ok(false_cond());
}
other_conds.extend(groups[i + 1..].iter().flatten().cloned());
let scan_ranges = eq_conds
.into_iter()
.map(|lit| {
let mut scan_range = scan_range.clone();
scan_range.eq_conds.push(lit);
scan_range
})
.collect();
return Ok((
scan_ranges,
Self {
conjunctions: other_conds,
},
));
}
}
}
Ok((
if scan_range.is_full_table_scan() {
vec![]
} else if table_desc.columns[table_desc.order_column_indices()[0]]
.data_type
.is_int()
{
match scan_range.split_small_range(max_split_range_gap) {
Some(scan_ranges) => scan_ranges,
None => vec![scan_range],
}
} else {
vec![scan_range]
},
Self {
conjunctions: other_conds,
},
))
}
fn classify_conjunctions_by_pk(
conjunctions: Vec<ExprImpl>,
table_desc: &Rc<TableDesc>,
) -> Vec<Vec<ExprImpl>> {
let pk_column_ids = &table_desc.order_column_indices();
let pk_cols_num = pk_column_ids.len();
let cols_num = table_desc.columns.len();
let mut col_idx_to_pk_idx = vec![None; cols_num];
pk_column_ids.iter().enumerate().for_each(|(idx, pk_idx)| {
col_idx_to_pk_idx[*pk_idx] = Some(idx);
});
let mut groups = vec![vec![]; pk_cols_num + 1];
for (key, group) in &conjunctions.into_iter().chunk_by(|expr| {
let input_bits = expr.collect_input_refs(cols_num);
if input_bits.count_ones(..) == 1 {
let col_idx = input_bits.ones().next().unwrap();
col_idx_to_pk_idx[col_idx].unwrap_or(pk_cols_num)
} else {
pk_cols_num
}
}) {
groups[key].extend(group);
}
groups
}
#[allow(clippy::type_complexity)]
fn analyze_group(
group: Vec<ExprImpl>,
) -> Result<
Option<(
Vec<Bound<ScalarImpl>>,
Vec<Bound<ScalarImpl>>,
Vec<Option<ScalarImpl>>,
Vec<ExprImpl>,
)>,
> {
let mut lower_bound_conjunctions = vec![];
let mut upper_bound_conjunctions = vec![];
let mut eq_conds = vec![];
let mut other_conds = vec![];
for expr in group {
if let Some((input_ref, const_expr)) = expr.as_eq_const() {
let new_expr = if let Ok(expr) = const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
expr
} else {
match self::cast_compare::cast_compare_for_eq(const_expr, input_ref.data_type) {
Ok(ResultForEq::Success(expr)) => expr,
Ok(ResultForEq::NeverEqual) => {
return Ok(None);
}
Err(_) => {
other_conds.push(expr);
continue;
}
}
};
let Some(new_cond) = new_expr.fold_const()? else {
return Ok(None);
};
if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
return Ok(None);
}
eq_conds = vec![Some(new_cond)];
} else if expr.as_is_null().is_some() {
if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
return Ok(None);
}
eq_conds = vec![None];
} else if let Some((input_ref, in_const_list)) = expr.as_in_const_list() {
let mut scalars = HashSet::new();
for const_expr in in_const_list {
let const_expr = const_expr
.cast_implicit(input_ref.data_type.clone())
.unwrap();
let value = const_expr.fold_const()?;
let Some(value) = value else {
continue;
};
scalars.insert(Some(value));
}
if scalars.is_empty() {
return Ok(None);
}
if !eq_conds.is_empty() {
scalars = scalars
.intersection(&HashSet::from_iter(eq_conds))
.cloned()
.collect();
if scalars.is_empty() {
return Ok(None);
}
}
eq_conds = scalars
.into_iter()
.sorted_by(DefaultOrd::default_cmp)
.collect();
} else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
let new_expr = if let Ok(expr) = const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
expr
} else {
match self::cast_compare::cast_compare_for_cmp(
const_expr,
input_ref.data_type,
op,
) {
Ok(ResultForCmp::Success(expr)) => expr,
_ => {
other_conds.push(expr);
continue;
}
}
};
let Some(value) = new_expr.fold_const()? else {
return Ok(None);
};
match op {
ExprType::LessThan => {
upper_bound_conjunctions.push(Bound::Excluded(value));
}
ExprType::LessThanOrEqual => {
upper_bound_conjunctions.push(Bound::Included(value));
}
ExprType::GreaterThan => {
lower_bound_conjunctions.push(Bound::Excluded(value));
}
ExprType::GreaterThanOrEqual => {
lower_bound_conjunctions.push(Bound::Included(value));
}
_ => unreachable!(),
}
} else {
other_conds.push(expr);
}
}
Ok(Some((
lower_bound_conjunctions,
upper_bound_conjunctions,
eq_conds,
other_conds,
)))
}
fn mutual_exclusive_with_eq_conds(
new_conds: &ScalarImpl,
eq_conds: &[Option<ScalarImpl>],
) -> bool {
!eq_conds.is_empty()
&& eq_conds.iter().all(|l| {
if let Some(l) = l {
l != new_conds
} else {
true
}
})
}
fn merge_lower_bound_conjunctions(lb: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
lb.into_iter()
.max_by(|a, b| {
match (a, b) {
(Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
(Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
(Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Less,
(Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Less,
(Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
(Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
(Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
(Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
other => other,
},
(Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
other => other,
},
}
})
.unwrap_or(Bound::Unbounded)
}
fn merge_upper_bound_conjunctions(ub: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
ub.into_iter()
.min_by(|a, b| {
match (a, b) {
(Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Less,
(Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Less,
(Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Greater,
(Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Greater,
(Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
(Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
(Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
(Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
other => other,
},
(Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
other => other,
},
}
})
.unwrap_or(Bound::Unbounded)
}
fn is_invalid_range(lower_bound: &Bound<ScalarImpl>, upper_bound: &Bound<ScalarImpl>) -> bool {
match (lower_bound, upper_bound) {
(Bound::Included(l), Bound::Included(u)) => l.default_cmp(u).is_gt(), (Bound::Included(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Included(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), _ => false,
}
}
fn extract_eq_conds_within_range(
eq_conds: Vec<Option<ScalarImpl>>,
upper_bound: &Bound<ScalarImpl>,
lower_bound: &Bound<ScalarImpl>,
) -> Vec<Option<ScalarImpl>> {
if Self::is_invalid_range(lower_bound, upper_bound) {
return vec![];
}
let is_extract_null = upper_bound == &Bound::Unbounded && lower_bound == &Bound::Unbounded;
eq_conds
.into_iter()
.filter(|cond| {
if let Some(cond) = cond {
match lower_bound {
Bound::Included(val) => {
if cond.default_cmp(val).is_lt() {
return false;
}
}
Bound::Excluded(val) => {
if cond.default_cmp(val).is_le() {
return false;
}
}
Bound::Unbounded => {}
}
match upper_bound {
Bound::Included(val) => {
if cond.default_cmp(val).is_gt() {
return false;
}
}
Bound::Excluded(val) => {
if cond.default_cmp(val).is_ge() {
return false;
}
}
Bound::Unbounded => {}
}
true
} else {
is_extract_null
}
})
.collect()
}
#[must_use]
pub fn group_by<F, const N: usize>(self, f: F) -> [Self; N]
where
F: Fn(&ExprImpl) -> usize,
{
const EMPTY: Vec<ExprImpl> = vec![];
let mut groups = [EMPTY; N];
for (key, group) in &self.conjunctions.into_iter().chunk_by(|expr| {
let i = f(expr);
assert!(i < N);
i
}) {
groups[key].extend(group);
}
groups.map(|group| Condition {
conjunctions: group,
})
}
#[must_use]
pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
Self {
conjunctions: self
.conjunctions
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect(),
}
.simplify()
}
pub fn visit_expr<V: ExprVisitor + ?Sized>(&self, visitor: &mut V) {
self.conjunctions
.iter()
.for_each(|expr| visitor.visit_expr(expr));
}
pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
self.conjunctions
.iter_mut()
.for_each(|expr| mutator.visit_expr(expr))
}
fn simplify(self) -> Self {
let conjunctions: Vec<_> = self
.conjunctions
.into_iter()
.map(push_down_not)
.map(fold_boolean_constant)
.map(column_self_eq_eliminate)
.flat_map(to_conjunctions)
.collect();
let mut res: Vec<ExprImpl> = Vec::new();
let mut visited: HashSet<ExprImpl> = HashSet::new();
for expr in conjunctions {
if !expr.has_subquery() {
let results_of_factorization = factorization_expr(expr);
res.extend(
results_of_factorization
.clone()
.into_iter()
.filter(|expr| !visited.contains(expr)),
);
visited.extend(results_of_factorization);
} else {
res.push(expr);
}
}
res.retain(|expr| {
if let Some(v) = try_get_bool_constant(expr)
&& v
{
false
} else {
true
}
});
for expr in &mut res {
if let Some(v) = try_get_bool_constant(expr) {
if !v {
res.clear();
res.push(ExprImpl::literal_bool(false));
break;
}
}
}
Self { conjunctions: res }
}
}
pub struct ConditionDisplay<'a> {
pub condition: &'a Condition,
pub input_schema: &'a Schema,
}
impl ConditionDisplay<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.condition.always_true() {
write!(f, "true")
} else {
write!(
f,
"{}",
self.condition
.conjunctions
.iter()
.format_with(" AND ", |expr, f| {
f(&ExprDisplay {
expr,
input_schema: self.input_schema,
})
})
)
}
}
}
impl fmt::Display for ConditionDisplay<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt(f)
}
}
impl fmt::Debug for ConditionDisplay<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt(f)
}
}
mod cast_compare {
use risingwave_common::types::DataType;
use crate::expr::{Expr, ExprImpl, ExprType};
enum ShrinkResult {
OutUpperBound,
OutLowerBound,
InRange(ExprImpl),
}
pub enum ResultForEq {
Success(ExprImpl),
NeverEqual,
}
pub enum ResultForCmp {
Success(ExprImpl),
OutUpperBound,
OutLowerBound,
}
pub fn cast_compare_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
Ok(ResultForEq::NeverEqual)
}
},
_ => Err(()),
}
}
pub fn cast_compare_for_cmp(
const_expr: ExprImpl,
target: DataType,
_op: ExprType,
) -> Result<ResultForCmp, ()> {
match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
},
_ => Err(()),
}
}
fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
let (upper_bound, lower_bound) = match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
(DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
(i16::MAX as i64, i16::MIN as i64)
}
_ => unreachable!(),
};
match const_expr.fold_const().map_err(|_| ())? {
Some(scalar) => {
let value = scalar.as_integral();
if value > upper_bound {
Ok(ShrinkResult::OutUpperBound)
} else if value < lower_bound {
Ok(ShrinkResult::OutLowerBound)
} else {
Ok(ShrinkResult::InRange(
const_expr.cast_explicit(target).unwrap(),
))
}
}
None => Ok(ShrinkResult::InRange(
const_expr.cast_explicit(target).unwrap(),
)),
}
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use super::*;
#[test]
fn test_split() {
let left_col_num = 3;
let right_col_num = 2;
let ty = DataType::Int32;
let mut rng = rand::thread_rng();
let left: ExprImpl = FunctionCall::new(
ExprType::LessThanOrEqual,
vec![
InputRef::new(rng.gen_range(0..left_col_num), ty.clone()).into(),
InputRef::new(rng.gen_range(0..left_col_num), ty.clone()).into(),
],
)
.unwrap()
.into();
let right: ExprImpl = FunctionCall::new(
ExprType::LessThan,
vec![
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty.clone(),
)
.into(),
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty.clone(),
)
.into(),
],
)
.unwrap()
.into();
let other: ExprImpl = FunctionCall::new(
ExprType::GreaterThan,
vec![
InputRef::new(rng.gen_range(0..left_col_num), ty.clone()).into(),
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty,
)
.into(),
],
)
.unwrap()
.into();
let cond = Condition::with_expr(other.clone())
.and(Condition::with_expr(right.clone()))
.and(Condition::with_expr(left.clone()));
let res = cond.split(left_col_num, right_col_num);
assert_eq!(res.0.conjunctions, vec![left]);
assert_eq!(res.1.conjunctions, vec![right]);
assert_eq!(res.2.conjunctions, vec![other]);
}
#[test]
fn test_self_eq_eliminate() {
let left_col_num = 3;
let right_col_num = 2;
let ty = DataType::Int32;
let mut rng = rand::thread_rng();
let x: ExprImpl = InputRef::new(rng.gen_range(0..left_col_num), ty.clone()).into();
let left: ExprImpl = FunctionCall::new(ExprType::Equal, vec![x.clone(), x.clone()])
.unwrap()
.into();
let right: ExprImpl = FunctionCall::new(
ExprType::LessThan,
vec![
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty.clone(),
)
.into(),
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty.clone(),
)
.into(),
],
)
.unwrap()
.into();
let other: ExprImpl = FunctionCall::new(
ExprType::GreaterThan,
vec![
InputRef::new(rng.gen_range(0..left_col_num), ty.clone()).into(),
InputRef::new(
rng.gen_range(left_col_num..left_col_num + right_col_num),
ty,
)
.into(),
],
)
.unwrap()
.into();
let cond = Condition::with_expr(other.clone())
.and(Condition::with_expr(right.clone()))
.and(Condition::with_expr(left.clone()));
let res = cond.split(left_col_num, right_col_num);
let left_res = FunctionCall::new(ExprType::IsNotNull, vec![x])
.unwrap()
.into();
assert_eq!(res.0.conjunctions, vec![left_res]);
assert_eq!(res.1.conjunctions, vec![right]);
assert_eq!(res.2.conjunctions, vec![other]);
}
}