risingwave_frontend/optimizer/plan_visitor/
cardinality_visitor.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::ops::{Mul, Sub};
17
18use risingwave_pb::plan_common::JoinType;
19
20use super::{DefaultBehavior, DefaultValue, PlanVisitor};
21use crate::optimizer::plan_node::generic::TopNLimit;
22use crate::optimizer::plan_node::{
23    self, PlanNode, PlanTreeNode, PlanTreeNodeBinary, PlanTreeNodeUnary,
24};
25use crate::optimizer::plan_visitor::PlanRef;
26use crate::optimizer::property::Cardinality;
27
/// A visitor that computes the cardinality of a plan node.
///
/// The result is a [`Cardinality`]: a range of possible row counts. Plan nodes without a
/// dedicated `visit_*` method fall back to the visitor's default behavior.
pub struct CardinalityVisitor;
30
31impl CardinalityVisitor {
32    /// Used for `Filter` and `Scan` with predicate.
33    fn visit_predicate(
34        input: &dyn PlanNode,
35        input_card: Cardinality,
36        eq_set: HashSet<usize>,
37    ) -> Cardinality {
38        // TODO: there could be more unique keys than the stream key after we support it.
39        let unique_keys: Vec<HashSet<_>> = input
40            .stream_key()
41            .into_iter()
42            .map(|s| s.iter().copied().collect())
43            .collect();
44
45        if unique_keys
46            .iter()
47            .any(|unique_key| eq_set.is_superset(unique_key))
48        {
49            input_card.min(0..=1)
50        } else {
51            input_card.min(0..)
52        }
53    }
54}
55
56impl PlanVisitor for CardinalityVisitor {
57    type Result = Cardinality;
58
59    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
60
61    fn default_behavior() -> Self::DefaultBehavior {
62        // returns unknown cardinality for default behavior, which is always correct
63        DefaultValue
64    }
65
66    fn visit_logical_values(&mut self, plan: &plan_node::LogicalValues) -> Cardinality {
67        plan.rows().len().into()
68    }
69
70    fn visit_logical_share(&mut self, plan: &plan_node::LogicalShare) -> Cardinality {
71        self.visit(plan.input())
72    }
73
74    fn visit_logical_dedup(&mut self, plan: &plan_node::LogicalDedup) -> Cardinality {
75        let input = self.visit(plan.input());
76        if plan.dedup_cols().is_empty() {
77            input.min(1)
78        } else {
79            input
80        }
81    }
82
83    fn visit_logical_over_window(&mut self, plan: &super::LogicalOverWindow) -> Self::Result {
84        self.visit(plan.input())
85    }
86
87    fn visit_logical_agg(&mut self, plan: &plan_node::LogicalAgg) -> Cardinality {
88        let input = self.visit(plan.input());
89
90        if plan.group_key().is_empty() {
91            input.min(1)
92        } else {
93            input.min(1..)
94        }
95    }
96
97    fn visit_logical_limit(&mut self, plan: &plan_node::LogicalLimit) -> Cardinality {
98        self.visit(plan.input()).min(plan.limit() as usize)
99    }
100
101    fn visit_logical_max_one_row(&mut self, plan: &plan_node::LogicalMaxOneRow) -> Cardinality {
102        self.visit(plan.input()).min(1)
103    }
104
105    fn visit_logical_project(&mut self, plan: &plan_node::LogicalProject) -> Cardinality {
106        self.visit(plan.input())
107    }
108
109    fn visit_logical_top_n(&mut self, plan: &plan_node::LogicalTopN) -> Cardinality {
110        let input = self.visit(plan.input());
111
112        let each_group = match plan.limit_attr() {
113            TopNLimit::Simple(limit) => input.sub(plan.offset() as usize).min(limit as usize),
114            TopNLimit::WithTies(limit) => {
115                assert_eq!(plan.offset(), 0, "ties with offset is not supported yet");
116                input.min((limit as usize)..)
117            }
118        };
119
120        if plan.group_key().is_empty() {
121            each_group
122        } else {
123            let group_number = input.min(1..);
124            each_group
125                .mul(group_number)
126                // the output cardinality will never be more than the input, thus `.min(input)`
127                .min(input)
128        }
129    }
130
131    fn visit_logical_filter(&mut self, plan: &plan_node::LogicalFilter) -> Cardinality {
132        let eq_set = plan
133            .predicate()
134            .collect_input_refs(plan.input().schema().len())
135            .ones()
136            .collect();
137        Self::visit_predicate(&*plan.input(), self.visit(plan.input()), eq_set)
138    }
139
140    fn visit_logical_scan(&mut self, plan: &plan_node::LogicalScan) -> Cardinality {
141        let eq_set = plan
142            .predicate()
143            .collect_input_refs(plan.table_desc().columns.len())
144            .ones()
145            .collect();
146        Self::visit_predicate(plan, plan.table_cardinality(), eq_set)
147    }
148
149    fn visit_logical_union(&mut self, plan: &plan_node::LogicalUnion) -> Cardinality {
150        let all = plan
151            .inputs()
152            .into_iter()
153            .map(|input| self.visit(input))
154            .fold(Cardinality::unknown(), std::ops::Add::add);
155
156        if plan.all() { all } else { all.min(1..) }
157    }
158
159    fn visit_logical_join(&mut self, plan: &plan_node::LogicalJoin) -> Cardinality {
160        let left = self.visit(plan.left());
161        let right = self.visit(plan.right());
162
163        match plan.join_type() {
164            JoinType::Unspecified => unreachable!(),
165
166            // For each row from one side, we match `0..=(right.hi)` rows from the other side.
167            JoinType::Inner => left.mul(right.min(0..)),
168
169            // For each row from one side, we match `1..=max(right.hi, 1)` rows from the other side,
170            // since we can at least match a `NULL` row.
171            JoinType::LeftOuter => left.mul(right.max(1).min(1..)),
172            JoinType::RightOuter => right.mul(left.max(1).min(1..)),
173
174            // For each row in the result set, it must belong to the given side.
175            JoinType::LeftSemi | JoinType::LeftAnti => left.min(0..),
176            JoinType::RightSemi | JoinType::RightAnti => right.min(0..),
177
178            // TODO: refine the cardinality of full outer join
179            JoinType::FullOuter => Cardinality::unknown(),
180
181            // For each row from one side, we match `0..=1` rows from the other side.
182            JoinType::AsofInner => left.mul(right.min(0..=1)),
183            // For each row from left side, we match exactly 1 row from the right side or a `NULL` row`.
184            JoinType::AsofLeftOuter => left,
185        }
186    }
187
188    fn visit_logical_now(&mut self, plan: &plan_node::LogicalNow) -> Cardinality {
189        if plan.max_one_row() {
190            1.into()
191        } else {
192            Cardinality::unknown()
193        }
194    }
195
196    fn visit_logical_expand(&mut self, plan: &plan_node::LogicalExpand) -> Cardinality {
197        self.visit(plan.input()) * plan.column_subsets().len()
198    }
199}
200
201#[easy_ext::ext(LogicalCardinalityExt)]
202pub impl PlanRef {
203    /// Returns `true` if the plan node is guaranteed to yield at most one row.
204    fn max_one_row(&self) -> bool {
205        CardinalityVisitor.visit(self.clone()).is_at_most(1)
206    }
207
208    /// Returns the number of rows the plan node is guaranteed to yield, if known exactly.
209    fn row_count(&self) -> Option<usize> {
210        CardinalityVisitor.visit(self.clone()).get_exact()
211    }
212}