risingwave_frontend/optimizer/plan_visitor/
apply_visitor.rs1use super::{DefaultBehavior, LogicalPlanVisitor, Merge};
16use crate::error::{ErrorCode, RwError};
17use crate::optimizer::plan_node::{LogicalApply, LogicalPlanRef as PlanRef, PlanTreeNodeBinary};
18use crate::optimizer::plan_visitor::PlanVisitor;
19
20pub struct HasMaxOneRowApply();
21
22impl LogicalPlanVisitor for HasMaxOneRowApply {
23 type Result = bool;
24
25 type DefaultBehavior = impl DefaultBehavior<Self::Result>;
26
27 fn default_behavior() -> Self::DefaultBehavior {
28 Merge(|a, b| a | b)
29 }
30
31 fn visit_logical_apply(&mut self, plan: &LogicalApply) -> bool {
32 plan.max_one_row() | self.visit(plan.left()) | self.visit(plan.right())
33 }
34}
35
36#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
37enum CheckResult {
38 #[default]
39 Ok,
40 CannotBeUnnested,
41 MoreThanOneRow,
42}
43
44impl From<CheckResult> for Result<(), RwError> {
45 fn from(val: CheckResult) -> Self {
46 let msg = match val {
47 CheckResult::Ok => return Ok(()),
48 CheckResult::CannotBeUnnested => "Subquery can not be unnested.",
49 CheckResult::MoreThanOneRow => "Scalar subquery might produce more than one row.",
50 };
51
52 Err(ErrorCode::InternalError(msg.to_owned()).into())
53 }
54}
55
56#[derive(Default)]
57pub struct CheckApplyElimination {
58 result: CheckResult,
59}
60
61impl LogicalPlanVisitor for CheckApplyElimination {
62 type Result = ();
63
64 type DefaultBehavior = impl DefaultBehavior<Self::Result>;
65
66 fn default_behavior() -> Self::DefaultBehavior {
67 Merge(std::cmp::max)
68 }
69
70 fn visit_logical_apply(&mut self, plan: &LogicalApply) {
71 if plan.right().as_logical_max_one_row().is_some() {
74 self.result = CheckResult::MoreThanOneRow;
75 } else {
76 self.result = CheckResult::CannotBeUnnested;
77 }
78 }
79}
80
81#[easy_ext::ext(PlanCheckApplyEliminationExt)]
82impl PlanRef {
83 pub fn check_apply_elimination(&self) -> Result<(), RwError> {
86 let mut visitor = CheckApplyElimination::default();
87 visitor.visit(self.clone());
88 visitor.result.into()
89 }
90}