risingwave_frontend/utils/
connected_components.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet};
16
// TODO: could try to optimize using union-find algorithm
/// Incrementally labels the connected components of an undirected graph whose
/// vertices are `0..n`.
///
/// Every component is identified by a label that is always the smallest vertex
/// index in that component: each vertex starts in a singleton component labelled
/// with its own index, and when two components merge the smaller label is kept.
#[derive(Debug)]
pub(crate) struct ConnectedComponentLabeller {
    /// Label of the component each vertex currently belongs to.
    vertex_to_label: BTreeMap<usize, usize>,
    /// Vertices of every live component, keyed by label.
    labels_to_vertices: BTreeMap<usize, BTreeSet<usize>>,
    /// Edges, normalized as `(smaller, larger)`, of every component that has at
    /// least one edge, keyed by label.
    labels_to_edges: BTreeMap<usize, BTreeSet<(usize, usize)>>,
}

impl ConnectedComponentLabeller {
    /// Creates a labeller for `vertices` vertices numbered `0..vertices`, each in
    /// its own singleton component and with no edges.
    pub(crate) fn new(vertices: usize) -> Self {
        Self {
            vertex_to_label: (0..vertices).map(|v| (v, v)).collect(),
            labels_to_vertices: (0..vertices).map(|v| (v, BTreeSet::from([v]))).collect(),
            labels_to_edges: BTreeMap::new(),
        }
    }

    /// Adds the undirected edge `v1 -- v2`, merging the components of its two
    /// endpoints if they differ.
    ///
    /// The edge is stored as `(min(v1, v2), max(v1, v2))`, so adding the same
    /// edge again (in either direction) has no further effect. A self-loop
    /// (`v1 == v2`) is recorded as an edge of that vertex's component.
    ///
    /// # Panics
    ///
    /// Panics if `v1` or `v2` is not less than the vertex count passed to
    /// [`Self::new`].
    pub(crate) fn add_edge(&mut self, v1: usize, v2: usize) {
        let label1 = self.label_of(v1);
        let label2 = self.label_of(v2);

        // Keep the smaller label so a component's label stays equal to its
        // smallest vertex. If both endpoints are already in the same component,
        // `kept == absorbed` and nothing needs to be merged.
        let (kept, absorbed) = if label1 <= label2 {
            (label1, label2)
        } else {
            (label2, label1)
        };
        if kept != absorbed {
            self.merge(kept, absorbed);
        }

        let edge = if v1 <= v2 { (v1, v2) } else { (v2, v1) };
        self.labels_to_edges.entry(kept).or_default().insert(edge);
    }

    /// Consumes the labeller and returns the edge set of every component that
    /// has at least one edge, ordered by each component's smallest vertex.
    ///
    /// Vertices with no incident edge do not appear in the result.
    pub(crate) fn into_edge_sets(self) -> Vec<BTreeSet<(usize, usize)>> {
        self.labels_to_edges.into_values().collect()
    }

    /// Returns the label of the component currently containing `v`, panicking
    /// with a descriptive message if `v` is not a vertex of this labeller.
    fn label_of(&self, v: usize) -> usize {
        *self.vertex_to_label.get(&v).unwrap_or_else(|| {
            panic!(
                "vertex {v} is out of range: labeller has {} vertices",
                self.vertex_to_label.len()
            )
        })
    }

    /// Moves every vertex and edge of component `absorbed` into component
    /// `kept`, relabels the moved vertices, and deletes `absorbed`.
    fn merge(&mut self, kept: usize, absorbed: usize) {
        let mut moved_vertices = self
            .labels_to_vertices
            .remove(&absorbed)
            .expect("every live label has a vertex set");
        for &v in &moved_vertices {
            self.vertex_to_label.insert(v, kept);
        }
        self.labels_to_vertices
            .get_mut(&kept)
            .expect("every live label has a vertex set")
            .append(&mut moved_vertices);

        // Components without edges have no entry in `labels_to_edges`.
        if let Some(mut moved_edges) = self.labels_to_edges.remove(&absorbed) {
            self.labels_to_edges
                .entry(kept)
                .or_default()
                .append(&mut moved_edges);
        }
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connected_component_labeller() {
        // Graph:
        // 0-1-2  3-4-5  6
        // => 0-1-2-3-4-5  6
        // => 0-1-2-3-4-5-6
        let mut labeller = ConnectedComponentLabeller::new(7);
        labeller.add_edge(0, 1);
        labeller.add_edge(1, 2);

        labeller.add_edge(3, 4);
        labeller.add_edge(4, 5);

        assert_eq!(labeller.labels_to_vertices.len(), 3);

        labeller.add_edge(2, 3);

        assert_eq!(labeller.labels_to_vertices.len(), 2);

        labeller.add_edge(5, 6);

        assert_eq!(labeller.labels_to_vertices.len(), 1);
        assert_eq!(
            *labeller.labels_to_vertices.iter().next().unwrap().1,
            (0..7).collect::<BTreeSet<_>>()
        );
        assert_eq!(
            *labeller.labels_to_edges.iter().next().unwrap().1,
            vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
                .into_iter()
                .collect::<BTreeSet<_>>()
        );
    }

    #[test]
    fn test_redundant_edge_within_component() {
        // Graph: triangle 0-1-2; the third edge joins two vertices that are
        // already connected, so it adds an edge but no new merge.
        let mut labeller = ConnectedComponentLabeller::new(3);
        labeller.add_edge(0, 1);
        labeller.add_edge(1, 2);
        labeller.add_edge(2, 0);

        assert_eq!(labeller.labels_to_vertices.len(), 1);
        assert_eq!(
            labeller.into_edge_sets(),
            vec![BTreeSet::from([(0, 1), (0, 2), (1, 2)])]
        );
    }

    #[test]
    fn test_into_edge_sets() {
        // Graph:
        // 0-1  2  3-4  5 (self-loop)
        // Vertex 2 has no edges and must not appear in the output; edges are
        // normalized to (smaller, larger) and deduplicated; components are
        // ordered by their smallest vertex.
        let mut labeller = ConnectedComponentLabeller::new(6);
        labeller.add_edge(1, 0);
        labeller.add_edge(3, 4);
        labeller.add_edge(4, 3);
        labeller.add_edge(5, 5);

        assert_eq!(
            labeller.into_edge_sets(),
            vec![
                BTreeSet::from([(0, 1)]),
                BTreeSet::from([(3, 4)]),
                BTreeSet::from([(5, 5)]),
            ]
        );
    }
}