risingwave_frontend/utils/
connected_components.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, BTreeSet};

// TODO: could try to optimize using union-find algorithm
/// Incrementally labels the connected components of an undirected graph whose
/// vertices are the integers `0..vertices`.
///
/// Every component is labelled with the smallest vertex index it contains.
/// Each vertex starts in its own component labelled with its own index, and
/// when two components are joined the smaller of the two labels is kept.
/// Besides vertex membership, the labeller records which edges belong to each
/// component so they can be handed back per component by
/// [`ConnectedComponentLabeller::into_edge_sets`].
#[derive(Debug)]
pub(crate) struct ConnectedComponentLabeller {
    /// Current component label of every vertex.
    vertex_to_label: BTreeMap<usize, usize>,
    /// All vertices belonging to each live component label.
    labels_to_vertices: BTreeMap<usize, BTreeSet<usize>>,
    /// Edges, normalized as `(min, max)`, belonging to each component label.
    /// Only components that received at least one edge have an entry.
    labels_to_edges: BTreeMap<usize, BTreeSet<(usize, usize)>>,
}

impl ConnectedComponentLabeller {
    /// Creates a labeller for the vertices `0..vertices`. Each vertex initially
    /// forms its own component and no edges are recorded.
    pub(crate) fn new(vertices: usize) -> Self {
        let vertex_to_label: BTreeMap<usize, usize> = (0..vertices).map(|v| (v, v)).collect();
        let labels_to_vertices: BTreeMap<usize, BTreeSet<usize>> =
            (0..vertices).map(|v| (v, BTreeSet::from([v]))).collect();
        Self {
            vertex_to_label,
            labels_to_vertices,
            labels_to_edges: BTreeMap::new(),
        }
    }

    /// Adds the undirected edge `v1 - v2`. If the endpoints were in different
    /// components, the two components are merged.
    ///
    /// The edge is stored once, normalized as `(min(v1, v2), max(v1, v2))`.
    /// Adding the same edge again, in either direction, has no further effect.
    /// Self-loops (`v1 == v2`) are recorded as edges of the vertex's component.
    ///
    /// # Panics
    ///
    /// Panics if `v1` or `v2` is not in `0..vertices`, where `vertices` is the
    /// count given to [`ConnectedComponentLabeller::new`].
    pub(crate) fn add_edge(&mut self, v1: usize, v2: usize) {
        let v1_label = self.label_of(v1);
        let v2_label = self.label_of(v2);

        // The merged component keeps the smaller label, so every label stays
        // equal to the minimum vertex of its component.
        let new_label = v1_label.min(v2_label);
        let old_label = v1_label.max(v2_label);

        let new_edge = (v1.min(v2), v1.max(v2));
        self.labels_to_edges
            .entry(new_label)
            .or_default()
            .insert(new_edge);

        if new_label == old_label {
            // Both endpoints already belong to the same component.
            return;
        }

        // Fold the component labelled `old_label` into `new_label`.
        let old_vertices = self
            .labels_to_vertices
            .remove(&old_label)
            .expect("every live label has a vertex set");
        for &v in &old_vertices {
            self.vertex_to_label.insert(v, new_label);
        }
        self.labels_to_vertices
            .get_mut(&new_label)
            .expect("every live label has a vertex set")
            .extend(old_vertices);
        if let Some(old_edges) = self.labels_to_edges.remove(&old_label) {
            self.labels_to_edges
                .entry(new_label)
                .or_default()
                .extend(old_edges);
        }
    }

    /// Consumes the labeller and returns the edges of every component that
    /// received at least one edge, one set per component.
    ///
    /// Components are returned in ascending order of their smallest vertex.
    /// Components consisting of a single vertex with no edges are omitted.
    pub(crate) fn into_edge_sets(self) -> Vec<BTreeSet<(usize, usize)>> {
        self.labels_to_edges.into_values().collect()
    }

    /// Returns the current component label of `vertex`.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if `vertex` is out of range for this
    /// labeller.
    fn label_of(&self, vertex: usize) -> usize {
        match self.vertex_to_label.get(&vertex) {
            Some(&label) => label,
            None => panic!(
                "vertex {vertex} out of range: labeller has {} vertices",
                self.vertex_to_label.len()
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two chains plus an isolated vertex, then joins them step by step.
    /// After each join it checks that the number of live components drops.
    /// At the end it checks that the single remaining component owns every
    /// vertex and every edge.
    #[test]
    fn test_connected_component_labeller() {
        // Graph:
        // 0-1-2  3-4-5  6
        // => 0-1-2-3-4-5  6
        // => 0-1-2-3-4-5-6
        let mut labeller = ConnectedComponentLabeller::new(7);
        for (from, to) in [(0, 1), (1, 2), (3, 4), (4, 5)] {
            labeller.add_edge(from, to);
        }
        assert_eq!(labeller.labels_to_vertices.len(), 3);

        // Bridge the two chains.
        labeller.add_edge(2, 3);
        assert_eq!(labeller.labels_to_vertices.len(), 2);

        // Pull in the isolated vertex.
        labeller.add_edge(5, 6);
        assert_eq!(labeller.labels_to_vertices.len(), 1);

        let expected_vertices: BTreeSet<usize> = (0..7).collect();
        let (_, vertices) = labeller.labels_to_vertices.first_key_value().unwrap();
        assert_eq!(*vertices, expected_vertices);

        let expected_edges: BTreeSet<(usize, usize)> =
            [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
                .into_iter()
                .collect();
        let (_, edges) = labeller.labels_to_edges.first_key_value().unwrap();
        assert_eq!(*edges, expected_edges);
    }
}