risingwave_frontend/optimizer/plan_visitor/input_ref_validator.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use paste::paste;
16use risingwave_common::catalog::{Field, Schema};
17
18use super::{BatchPlanVisitor, DefaultBehavior, LogicalPlanVisitor, Merge, StreamPlanVisitor};
19use crate::expr::ExprVisitor;
20use crate::optimizer::plan_node::generic::GenericPlanRef;
21use crate::optimizer::plan_node::{ConventionMarker, Explain, PlanRef, PlanTreeNodeUnary};
22use crate::optimizer::plan_visitor::PlanVisitor;
23
24struct ExprVis<'a> {
25    schema: &'a Schema,
26    string: Option<String>,
27}
28
29impl ExprVisitor for ExprVis<'_> {
30    fn visit_input_ref(&mut self, input_ref: &crate::expr::InputRef) {
31        if !input_ref
32            .data_type
33            .equals_datatype(&self.schema[input_ref.index].data_type)
34        {
35            self.string.replace(format!(
36                "InputRef#{} has type {}, but its type is {} in the input schema",
37                input_ref.index, input_ref.data_type, self.schema[input_ref.index].data_type
38            ));
39        }
40    }
41}
42
43/// Validates that input references are consistent with the input schema.
44///
45/// Use `InputRefValidator::validate` as an assertion.
46#[derive(Debug, Clone, Default)]
47pub struct InputRefValidator;
48
49impl InputRefValidator {
50    #[track_caller]
51    pub fn validate<C: ConventionMarker>(mut self, plan: PlanRef<C>)
52    where
53        Self: PlanVisitor<C, Result = Option<String>>,
54    {
55        if let Some(err) = self.visit(plan.clone()) {
56            panic!(
57                "Input references are inconsistent with the input schema: {}, plan:\n{}",
58                err,
59                plan.explain_to_string()
60            );
61        }
62    }
63}
64
/// Generates a `visit_<convention>_filter` method for each listed convention.
///
/// The generated method checks the filter's predicate against the schema of the filter's
/// input. If that check reports a problem, the problem is returned; otherwise the visitor
/// descends into the input.
macro_rules! visit_filter {
    ($($convention:ident),*) => {
        $(
            paste! {
                fn [<visit_ $convention _filter>](
                    &mut self,
                    plan: &crate::optimizer::plan_node::[<$convention:camel Filter>],
                ) -> Option<String> {
                    let input = plan.input();
                    let mut checker = ExprVis {
                        schema: input.schema(),
                        string: None,
                    };
                    plan.predicate().visit_expr(&mut checker);
                    match checker.string {
                        Some(err) => Some(err),
                        None => self.[<visit_$convention>](input),
                    }
                }
            }
        )*
    };
}
84
/// Generates a `visit_<convention>_project` method for each listed convention.
///
/// Each project expression is checked, in order, against the schema of the project's input.
/// The first reported problem is returned immediately. If there is none, the visitor
/// descends into the input.
macro_rules! visit_project {
    ($($convention:ident),*) => {
        $(
            paste! {
                fn [<visit_ $convention _project>](
                    &mut self,
                    plan: &crate::optimizer::plan_node::[<$convention:camel Project>],
                ) -> Option<String> {
                    let input = plan.input();
                    let mut checker = ExprVis {
                        schema: input.schema(),
                        string: None,
                    };
                    for expr in plan.exprs() {
                        checker.visit_expr(expr);
                        if let Some(err) = checker.string.take() {
                            return Some(err);
                        }
                    }
                    self.[<visit_$convention>](input)
                }
            }
        )*
    };
}
107
108impl StreamPlanVisitor for InputRefValidator {
109    type Result = Option<String>;
110
111    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
112
113    visit_filter!(stream);
114
115    visit_project!(stream);
116
117    fn default_behavior() -> Self::DefaultBehavior {
118        Merge(|a: Option<String>, b| a.or(b))
119    }
120}
121
122impl BatchPlanVisitor for InputRefValidator {
123    type Result = Option<String>;
124
125    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
126
127    visit_filter!(batch);
128
129    visit_project!(batch);
130
131    fn default_behavior() -> Self::DefaultBehavior {
132        Merge(|a: Option<String>, b| a.or(b))
133    }
134}
135
136impl LogicalPlanVisitor for InputRefValidator {
137    type Result = Option<String>;
138
139    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
140
141    visit_filter!(logical);
142
143    visit_project!(logical);
144
145    fn default_behavior() -> Self::DefaultBehavior {
146        Merge(|a: Option<String>, b| a.or(b))
147    }
148
149    fn visit_logical_scan(
150        &mut self,
151        plan: &crate::optimizer::plan_node::LogicalScan,
152    ) -> Option<String> {
153        let fields = plan
154            .table()
155            .columns
156            .iter()
157            .map(|col| Field::from_with_table_name_prefix(col, plan.table_name()))
158            .collect();
159        let input_schema = Schema { fields };
160        let mut vis = ExprVis {
161            schema: &input_schema,
162            string: None,
163        };
164        plan.predicate().visit_expr(&mut vis);
165        vis.string
166    }
167
168    // TODO: add more checks
169}