risingwave_expr_impl/aggregate/approx_count_distinct/
updatable.rs1use super::*;
16
/// Sparse table of occurrence counts keyed by a `u8`.
///
/// Entries are `(key, count)` pairs stored in ascending key order with at most
/// one entry per key. An entry is removed as soon as its count reaches zero,
/// so every stored count is at least 1.
#[derive(Clone, Default, Debug)]
struct SparseCount {
    inner: Vec<(u8, u64)>,
}

impl SparseCount {
    /// Creates an empty table.
    fn new() -> Self {
        Self { inner: Vec::new() }
    }

    /// Returns the count recorded for `k`, or 0 when `k` has no entry.
    fn get(&self, k: u8) -> u64 {
        // `inner` is sorted by key, so a binary search replaces the linear scan.
        self.inner
            .binary_search_by_key(&k, |&(key, _)| key)
            .map_or(0, |pos| self.inner[pos].1)
    }

    /// Increments the count for `k`, creating the entry with count 1 if it
    /// does not exist yet.
    ///
    /// Returns `true` when an existing entry was incremented and `false` when
    /// a new entry had to be inserted.
    fn add(&mut self, k: u8) -> bool {
        match self.inner.binary_search_by_key(&k, |&(key, _)| key) {
            Ok(pos) => {
                self.inner[pos].1 += 1;
                true
            }
            Err(pos) => {
                // `pos` is the insertion point that keeps `inner` sorted.
                self.inner.insert(pos, (k, 1));
                false
            }
        }
    }

    /// Decrements the count for `k`, dropping the entry once it reaches zero.
    ///
    /// Returns `true` when `k` had an entry and `false` when there was nothing
    /// to subtract from (the table is left untouched in that case).
    fn subtract(&mut self, k: u8) -> bool {
        match self.inner.binary_search_by_key(&k, |&(key, _)| key) {
            Ok(pos) => {
                // Stored counts are never zero, so this cannot underflow.
                self.inner[pos].1 -= 1;
                if self.inner[pos].1 == 0 {
                    self.inner.remove(pos);
                }
                true
            }
            Err(_) => false,
        }
    }

    /// Returns `true` when no key has a recorded count.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the largest key that currently has an entry.
    ///
    /// # Panics
    ///
    /// Panics if the table is empty.
    fn last_key(&self) -> u8 {
        self.inner
            .last()
            .expect("SparseCount::last_key called on an empty table")
            .0
    }
}
83
84impl EstimateSize for SparseCount {
85 fn estimated_heap_size(&self) -> usize {
86 self.inner.capacity() * std::mem::size_of::<(u8, u64)>()
87 }
88}
89
90#[derive(Clone, Debug, EstimateSize)]
91pub(super) struct UpdatableBucket<const DENSE_BITS: usize = 16> {
92 dense_counts: [u64; DENSE_BITS],
93 sparse_counts: SparseCount,
94}
95
96impl<const DENSE_BITS: usize> UpdatableBucket<DENSE_BITS> {
97 fn get_bucket(&self, index: u8) -> Result<u64> {
98 if index > 64 || index == 0 {
99 bail!("HyperLogLog: Invalid bucket index");
100 }
101
102 if index > DENSE_BITS as u8 {
103 Ok(self.sparse_counts.get(index))
104 } else {
105 Ok(self.dense_counts[index as usize - 1])
106 }
107 }
108}
109
110impl<const DENSE_BITS: usize> Default for UpdatableBucket<DENSE_BITS> {
111 fn default() -> Self {
112 Self {
113 dense_counts: [0u64; DENSE_BITS],
114 sparse_counts: SparseCount::new(),
115 }
116 }
117}
118
119impl<const DENSE_BITS: usize> Bucket for UpdatableBucket<DENSE_BITS> {
120 fn update(&mut self, index: u8, retract: bool) -> Result<()> {
121 if index > 64 || index == 0 {
122 bail!("HyperLogLog: Invalid bucket index");
123 }
124
125 let count = self.get_bucket(index)?;
126
127 if !retract {
128 if index > DENSE_BITS as u8 {
129 self.sparse_counts.add(index);
130 } else if index >= 1 {
131 if count == u64::MAX {
132 bail!(
133 "HyperLogLog: Count exceeds maximum bucket value.\
134 Your data stream may have too many repeated values or too large a\
135 cardinality for approx_count_distinct to handle (max: 2^64 - 1)"
136 );
137 }
138 self.dense_counts[index as usize - 1] = count + 1;
139 }
140 } else {
141 if index > DENSE_BITS as u8 {
144 self.sparse_counts.subtract(index);
145 } else if index >= 1 {
146 self.dense_counts[index as usize - 1] = count - 1;
147 }
148 }
149
150 Ok(())
151 }
152
153 fn max(&self) -> u8 {
154 if !self.sparse_counts.is_empty() {
155 return self.sparse_counts.last_key();
156 }
157 for i in (0..DENSE_BITS).rev() {
158 if self.dense_counts[i] > 0 {
159 return i as u8 + 1;
160 }
161 }
162 0
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the insert/retract scenario with `DENSE_BITS = 0` (every index is
    /// sparse) and with `DENSE_BITS = 4`.
    #[test]
    fn test_streaming_approx_count_distinct_insert_and_delete() {
        test_streaming_approx_count_distinct_insert_and_delete_inner::<0>();
        test_streaming_approx_count_distinct_insert_and_delete_inner::<4>();
    }

    /// Inserts three values (one of them twice), then retracts every
    /// insertion and expects the estimate to return to zero.
    fn test_streaming_approx_count_distinct_insert_and_delete_inner<const DENSE_BITS: usize>() {
        let mut registers = Registers::<UpdatableBucket<DENSE_BITS>>::default();
        assert_eq!(registers.calculate_result(), 0);

        registers.update(1.into(), false).unwrap();
        registers.update(2.into(), false).unwrap();
        registers.update(3.into(), false).unwrap();
        assert_eq!(registers.calculate_result(), 4);

        // A duplicate insertion must not change the estimate.
        registers.update(3.into(), false).unwrap();
        assert_eq!(registers.calculate_result(), 4);

        registers.update(3.into(), true).unwrap();
        registers.update(3.into(), true).unwrap();
        registers.update(1.into(), true).unwrap();
        registers.update(2.into(), true).unwrap();
        assert_eq!(registers.calculate_result(), 0);
    }

    /// Runs the bucket read/update scenario with `DENSE_BITS = 0` and with
    /// `DENSE_BITS = 4`.
    #[test]
    fn test_register_bucket_get_and_update() {
        test_register_bucket_get_and_update_inner::<0>();
        test_register_bucket_get_and_update_inner::<4>();
    }

    /// Checks counts after alternating insertions into indices 1 and 2, a
    /// single retraction, and one insertion at the highest valid index.
    fn test_register_bucket_get_and_update_inner<const DENSE_BITS: usize>() {
        let mut bucket = UpdatableBucket::<DENSE_BITS>::default();

        for i in 0..20 {
            bucket.update(i % 2 + 1, false).unwrap();
        }
        assert_eq!(bucket.get_bucket(1).unwrap(), 10);
        assert_eq!(bucket.get_bucket(2).unwrap(), 10);

        bucket.update(1, true).unwrap();
        assert_eq!(bucket.get_bucket(1).unwrap(), 9);
        assert_eq!(bucket.get_bucket(2).unwrap(), 10);

        bucket.update(64, false).unwrap();
        assert_eq!(bucket.get_bucket(64).unwrap(), 1);
    }

    /// Inserts one million distinct values three times each and requires the
    /// estimate to be within 1% of the true distinct count.
    #[test]
    fn test_error_ratio() {
        let mut registers = Registers::<UpdatableBucket<16>>::default();
        assert_eq!(registers.calculate_result(), 0);

        let actual_ndv = 1_000_000;
        for i in 0..1_000_000 {
            for _ in 0..3 {
                registers.update(i.into(), false).unwrap();
            }
        }

        let estimation = registers.calculate_result();
        let deviation = (estimation - actual_ndv) as f64;
        let error_ratio = (deviation / actual_ndv as f64).abs();
        assert!(error_ratio < 0.01);
    }

    /// Indices 0 and 65 lie outside the accepted range and must be rejected by
    /// both the read and the update path.
    #[test]
    fn test_register_bucket_invalid_register() {
        let mut bucket = UpdatableBucket::<0>::default();

        assert!(bucket.get_bucket(0).is_err());
        assert!(bucket.get_bucket(65).is_err());
        assert!(bucket.update(0, false).is_err());
        assert!(bucket.update(65, false).is_err());
    }
}