risingwave_frontend/utils/
connected_components.rs1use std::collections::{BTreeMap, BTreeSet};
16
/// Incrementally groups the vertices `0..n` of an undirected graph into
/// connected components as edges are added, remembering which edges belong
/// to each component.
///
/// Every component is labelled with the smallest vertex id it contains:
/// each vertex starts in its own component labelled with its own id, and a
/// merge keeps the smaller of the two labels.
#[derive(Debug)]
pub(crate) struct ConnectedComponentLabeller {
    /// Current component label of every vertex.
    vertex_to_label: BTreeMap<usize, usize>,
    /// Vertices of each live component, keyed by label.
    labels_to_vertices: BTreeMap<usize, BTreeSet<usize>>,
    /// Edges of each component that has received at least one edge, keyed by
    /// label. Every edge is stored normalized as `(min, max)`.
    labels_to_edges: BTreeMap<usize, BTreeSet<(usize, usize)>>,
}

impl ConnectedComponentLabeller {
    /// Creates a labeller for the vertices `0..vertices`, each initially in its
    /// own singleton component (labelled with its own id) and with no edges.
    pub(crate) fn new(vertices: usize) -> Self {
        let vertex_to_label = (0..vertices).map(|v| (v, v)).collect();
        let labels_to_vertices = (0..vertices).map(|v| (v, BTreeSet::from([v]))).collect();
        Self {
            vertex_to_label,
            labels_to_vertices,
            labels_to_edges: BTreeMap::new(),
        }
    }

    /// Returns the current component label of vertex `v`.
    ///
    /// # Panics
    ///
    /// Panics if `v` was not one of the vertices passed to [`Self::new`].
    fn label_of(&self, v: usize) -> usize {
        match self.vertex_to_label.get(&v) {
            Some(&label) => label,
            None => panic!(
                "vertex {} is out of range 0..{}",
                v,
                self.vertex_to_label.len()
            ),
        }
    }

    /// Records the undirected edge `v1 -- v2` and merges the components of its
    /// two endpoints if they differ.
    ///
    /// The edge is stored as `(min(v1, v2), max(v1, v2))`. A self-loop
    /// (`v1 == v2`) is recorded as an edge but merges nothing.
    ///
    /// # Panics
    ///
    /// Panics if either endpoint is not one of the vertices passed to
    /// [`Self::new`].
    pub(crate) fn add_edge(&mut self, v1: usize, v2: usize) {
        let v1_label = self.label_of(v1);
        let v2_label = self.label_of(v2);

        // The merged component keeps the smaller label, which keeps every
        // label equal to the smallest vertex id of its component.
        let (new_label, old_label) = if v1_label < v2_label {
            (v1_label, v2_label)
        } else {
            (v2_label, v1_label)
        };

        let new_edge = if v1 < v2 { (v1, v2) } else { (v2, v1) };
        self.labels_to_edges
            .entry(new_label)
            .or_default()
            .insert(new_edge);

        if v1_label == v2_label {
            // Both endpoints already share a component; only the edge is new.
            return;
        }

        let old_vertices = self
            .labels_to_vertices
            .remove(&old_label)
            .expect("every live label has a vertex set");
        let new_vertices = self
            .labels_to_vertices
            .get_mut(&new_label)
            .expect("every live label has a vertex set");
        // Move the absorbed component's vertices over and relabel them in one pass.
        for v in old_vertices {
            self.vertex_to_label.insert(v, new_label);
            new_vertices.insert(v);
        }

        // The absorbed component may have no edges yet (e.g. a lone vertex).
        if let Some(mut old_edges) = self.labels_to_edges.remove(&old_label) {
            self.labels_to_edges
                .entry(new_label)
                .or_default()
                .append(&mut old_edges);
        }
    }

    /// Consumes the labeller and returns the edge set of every component that
    /// received at least one edge, ordered by component label (i.e. by the
    /// smallest vertex of each component). Components without edges are omitted.
    pub(crate) fn into_edge_sets(self) -> Vec<BTreeSet<(usize, usize)>> {
        self.labels_to_edges.into_values().collect()
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two separate paths, bridges them, then attaches the last
    /// isolated vertex, checking the component count after each stage and the
    /// final vertex and edge sets of the single remaining component.
    #[test]
    fn test_connected_component_labeller() {
        let mut labeller = ConnectedComponentLabeller::new(7);

        // Paths 0-1-2 and 3-4-5; vertex 6 stays on its own.
        for (a, b) in [(0, 1), (1, 2), (3, 4), (4, 5)] {
            labeller.add_edge(a, b);
        }
        assert_eq!(labeller.labels_to_vertices.len(), 3);

        // Bridge the two paths.
        labeller.add_edge(2, 3);
        assert_eq!(labeller.labels_to_vertices.len(), 2);

        // Attach the remaining vertex.
        labeller.add_edge(5, 6);
        assert_eq!(labeller.labels_to_vertices.len(), 1);

        let all_vertices: BTreeSet<usize> = (0..7).collect();
        let (_, vertices) = labeller.labels_to_vertices.iter().next().unwrap();
        assert_eq!(vertices, &all_vertices);

        let path_edges: BTreeSet<(usize, usize)> = (0..6).map(|i| (i, i + 1)).collect();
        let (_, edges) = labeller.labels_to_edges.iter().next().unwrap();
        assert_eq!(edges, &path_edges);
    }
}