risingwave_frontend/optimizer/plan_visitor/
input_ref_validator.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use paste::paste;
16use risingwave_common::catalog::{Field, Schema};
17
18use super::{BatchPlanVisitor, DefaultBehavior, LogicalPlanVisitor, Merge, StreamPlanVisitor};
19use crate::expr::ExprVisitor;
20use crate::optimizer::plan_node::generic::GenericPlanRef;
21use crate::optimizer::plan_node::{ConventionMarker, Explain, PlanRef, PlanTreeNodeUnary};
22use crate::optimizer::plan_visitor::PlanVisitor;
23
/// Expression visitor that checks every `InputRef` it encounters against `schema`.
///
/// Any inconsistency found is recorded as a human-readable message in `string`.
struct ExprVis<'a> {
    /// Schema that `InputRef` indices are resolved against.
    schema: &'a Schema,
    /// Message describing an inconsistency, or `None` while none has been recorded.
    string: Option<String>,
}
28
29impl ExprVisitor for ExprVis<'_> {
30    fn visit_input_ref(&mut self, input_ref: &crate::expr::InputRef) {
31        if input_ref.data_type != self.schema[input_ref.index].data_type {
32            self.string.replace(format!(
33                "InputRef#{} has type {}, but its type is {} in the input schema",
34                input_ref.index, input_ref.data_type, self.schema[input_ref.index].data_type
35            ));
36        }
37    }
38}
39
/// Validates that input references are consistent with the input schema.
///
/// Filter predicates, project expressions and logical scan predicates are checked:
/// every `InputRef` must have the same data type as the field it points to in the
/// schema it is resolved against. Other plan nodes merge the results of their
/// children.
///
/// Use `InputRefValidator::validate` as an assertion: it panics if any inconsistency
/// is reported.
#[derive(Debug, Clone, Default)]
pub struct InputRefValidator;
45
46impl InputRefValidator {
47    #[track_caller]
48    pub fn validate<C: ConventionMarker>(mut self, plan: PlanRef<C>)
49    where
50        Self: PlanVisitor<C, Result = Option<String>>,
51    {
52        if let Some(err) = self.visit(plan.clone()) {
53            panic!(
54                "Input references are inconsistent with the input schema: {}, plan:\n{}",
55                err,
56                plan.explain_to_string()
57            );
58        }
59    }
60}
61
/// Generates `visit_<convention>_filter` methods that check the filter predicate
/// against the schema of the filter's input. They descend into the input only when
/// the predicate itself is consistent.
macro_rules! visit_filter {
    ($($convention:ident),*) => {
        $(
            paste! {
                fn [<visit_ $convention _filter>](&mut self, plan: &crate::optimizer::plan_node:: [<$convention:camel Filter>]) -> Option<String> {
                    let input = plan.input();
                    let mut checker = ExprVis {
                        string: None,
                        schema: input.schema(),
                    };
                    plan.predicate().visit_expr(&mut checker);
                    // A predicate error short-circuits: the input subtree is not visited.
                    let found = checker.string;
                    match found {
                        Some(err) => Some(err),
                        None => self.[<visit_$convention>](input),
                    }
                }
            }
        )*
    };
}
81
/// Generates `visit_<convention>_project` methods that check each project expression
/// against the schema of the project's input. They stop at the first expression that
/// reports an inconsistency and descend into the input only when all expressions pass.
macro_rules! visit_project {
    ($($convention:ident),*) => {
        $(
            paste! {
                fn [<visit_ $convention _project>](&mut self, plan: &crate::optimizer::plan_node:: [<$convention:camel Project>]) -> Option<String> {
                    let input = plan.input();
                    let mut checker = ExprVis {
                        string: None,
                        schema: input.schema(),
                    };
                    let found = plan.exprs().iter().find_map(|expr| {
                        checker.visit_expr(expr);
                        checker.string.take()
                    });
                    match found {
                        Some(err) => Some(err),
                        None => self.[<visit_$convention>](input),
                    }
                }
            }
        )*
    };
}
104
105impl StreamPlanVisitor for InputRefValidator {
106    type Result = Option<String>;
107
108    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
109
110    visit_filter!(stream);
111
112    visit_project!(stream);
113
114    fn default_behavior() -> Self::DefaultBehavior {
115        Merge(|a: Option<String>, b| a.or(b))
116    }
117}
118
119impl BatchPlanVisitor for InputRefValidator {
120    type Result = Option<String>;
121
122    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
123
124    visit_filter!(batch);
125
126    visit_project!(batch);
127
128    fn default_behavior() -> Self::DefaultBehavior {
129        Merge(|a: Option<String>, b| a.or(b))
130    }
131}
132
133impl LogicalPlanVisitor for InputRefValidator {
134    type Result = Option<String>;
135
136    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
137
138    visit_filter!(logical);
139
140    visit_project!(logical);
141
142    fn default_behavior() -> Self::DefaultBehavior {
143        Merge(|a: Option<String>, b| a.or(b))
144    }
145
146    fn visit_logical_scan(
147        &mut self,
148        plan: &crate::optimizer::plan_node::LogicalScan,
149    ) -> Option<String> {
150        let fields = plan
151            .table()
152            .columns
153            .iter()
154            .map(|col| Field::from_with_table_name_prefix(col, plan.table_name()))
155            .collect();
156        let input_schema = Schema { fields };
157        let mut vis = ExprVis {
158            schema: &input_schema,
159            string: None,
160        };
161        plan.predicate().visit_expr(&mut vis);
162        vis.string
163    }
164
165    // TODO: add more checks
166}