risingwave_expr_impl/aggregate/approx_count_distinct/
updatable.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::*;
16
/// Occurrence counters stored as a list of `(key, count)` pairs.
///
/// Invariants maintained by every method below:
/// * `inner` is sorted by key in strictly ascending order (each key appears at most once);
/// * no pair with a zero count is ever kept.
#[derive(Clone, Default, Debug)]
struct SparseCount {
    inner: Vec<(u8, u64)>,
}

impl SparseCount {
    /// Creates an empty counter set.
    fn new() -> Self {
        Self::default()
    }

    /// Locates `k` in the sorted pair list.
    ///
    /// Returns `Ok(pos)` when `k` is stored at `pos`, otherwise `Err(pos)` where `pos` is the
    /// insertion point that keeps the list sorted.
    fn position(&self, k: u8) -> std::result::Result<usize, usize> {
        self.inner.binary_search_by_key(&k, |&(key, _)| key)
    }

    /// Returns the count recorded for `k`, or `0` when `k` is not present.
    fn get(&self, k: u8) -> u64 {
        self.position(k).map_or(0, |pos| self.inner[pos].1)
    }

    /// Increments the count of `k`, inserting it with a count of `1` if it is missing.
    ///
    /// Returns `true` when `k` was already present and `false` when a new pair was inserted.
    fn add(&mut self, k: u8) -> bool {
        match self.position(k) {
            Ok(pos) => {
                self.inner[pos].1 += 1;
                true
            }
            Err(pos) => {
                self.inner.insert(pos, (k, 1));
                false
            }
        }
    }

    /// Decrements the count of `k` and drops the pair once its count reaches zero.
    ///
    /// Returns `true` when `k` was present and `false` (leaving the list untouched) otherwise.
    fn subtract(&mut self, k: u8) -> bool {
        match self.position(k) {
            Ok(pos) => {
                // Stored counts are always >= 1, so this cannot underflow.
                let count = &mut self.inner[pos].1;
                *count -= 1;
                if *count == 0 {
                    // delete the count
                    self.inner.remove(pos);
                }
                true
            }
            Err(_) => false,
        }
    }

    /// Returns `true` when no key is stored.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the largest stored key.
    ///
    /// # Panics
    ///
    /// Panics if the counter set is empty.
    fn last_key(&self) -> u8 {
        self.inner
            .last()
            .expect("last_key called on an empty SparseCount")
            .0
    }
}
83
84impl EstimateSize for SparseCount {
85    fn estimated_heap_size(&self) -> usize {
86        self.inner.capacity() * std::mem::size_of::<(u8, u64)>()
87    }
88}
89
/// Per-index occurrence counters that support both insertion and retraction.
///
/// Valid indices are `1..=64`. Indices up to `DENSE_BITS` are kept in the fixed-size
/// `dense_counts` array; larger indices are kept in `sparse_counts`.
#[derive(Clone, Debug, EstimateSize)]
pub(super) struct UpdatableBucket<const DENSE_BITS: usize = 16> {
    /// `dense_counts[i]` holds the count for index `i + 1`.
    dense_counts: [u64; DENSE_BITS],
    /// Counts for the indices that are above `DENSE_BITS`.
    sparse_counts: SparseCount,
}
95
96impl<const DENSE_BITS: usize> UpdatableBucket<DENSE_BITS> {
97    fn get_bucket(&self, index: u8) -> Result<u64> {
98        if index > 64 || index == 0 {
99            bail!("HyperLogLog: Invalid bucket index");
100        }
101
102        if index > DENSE_BITS as u8 {
103            Ok(self.sparse_counts.get(index))
104        } else {
105            Ok(self.dense_counts[index as usize - 1])
106        }
107    }
108}
109
110impl<const DENSE_BITS: usize> Default for UpdatableBucket<DENSE_BITS> {
111    fn default() -> Self {
112        Self {
113            dense_counts: [0u64; DENSE_BITS],
114            sparse_counts: SparseCount::new(),
115        }
116    }
117}
118
119impl<const DENSE_BITS: usize> Bucket for UpdatableBucket<DENSE_BITS> {
120    fn update(&mut self, index: u8, retract: bool) -> Result<()> {
121        if index > 64 || index == 0 {
122            bail!("HyperLogLog: Invalid bucket index");
123        }
124
125        let count = self.get_bucket(index)?;
126
127        if !retract {
128            if index > DENSE_BITS as u8 {
129                self.sparse_counts.add(index);
130            } else if index >= 1 {
131                if count == u64::MAX {
132                    bail!(
133                        "HyperLogLog: Count exceeds maximum bucket value.\
134                        Your data stream may have too many repeated values or too large a\
135                        cardinality for approx_count_distinct to handle (max: 2^64 - 1)"
136                    );
137                }
138                self.dense_counts[index as usize - 1] = count + 1;
139            }
140        } else {
141            // We don't have to worry about the user deleting nonexistent elements, so the counts
142            // can never go below 0.
143            if index > DENSE_BITS as u8 {
144                self.sparse_counts.subtract(index);
145            } else if index >= 1 {
146                self.dense_counts[index as usize - 1] = count - 1;
147            }
148        }
149
150        Ok(())
151    }
152
153    fn max(&self) -> u8 {
154        if !self.sparse_counts.is_empty() {
155            return self.sparse_counts.last_key();
156        }
157        for i in (0..DENSE_BITS).rev() {
158            if self.dense_counts[i] > 0 {
159                return i as u8 + 1;
160            }
161        }
162        0
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the insert/retract round trip with every index in the sparse store
    /// (`DENSE_BITS = 0`) and with indices `1..=4` in the dense array (`DENSE_BITS = 4`).
    #[test]
    fn test_streaming_approx_count_distinct_insert_and_delete() {
        // sparse case
        test_streaming_approx_count_distinct_insert_and_delete_inner::<0>();
        // dense case
        test_streaming_approx_count_distinct_insert_and_delete_inner::<4>();
    }

    /// Inserts three distinct values, inserts one of them a second time, then retracts
    /// everything and checks that the estimate returns to zero.
    fn test_streaming_approx_count_distinct_insert_and_delete_inner<const DENSE_BITS: usize>() {
        let mut agg = Registers::<UpdatableBucket<DENSE_BITS>>::default();
        assert_eq!(agg.calculate_result(), 0);

        agg.update(1.into(), false).unwrap();
        agg.update(2.into(), false).unwrap();
        agg.update(3.into(), false).unwrap();
        // The result is an estimate: for these three distinct inputs it reports 4.
        assert_eq!(agg.calculate_result(), 4);

        // Re-inserting an existing value must not change the estimate.
        agg.update(3.into(), false).unwrap();
        assert_eq!(agg.calculate_result(), 4);

        // Retract every insertion made above (the value 3 was inserted twice).
        agg.update(3.into(), true).unwrap();
        agg.update(3.into(), true).unwrap();
        agg.update(1.into(), true).unwrap();
        agg.update(2.into(), true).unwrap();
        assert_eq!(agg.calculate_result(), 0);
    }

    /// Runs the bucket get/update checks against both the sparse and the dense storage.
    #[test]
    fn test_register_bucket_get_and_update() {
        // sparse case
        test_register_bucket_get_and_update_inner::<0>();
        // dense case
        test_register_bucket_get_and_update_inner::<4>();
    }

    /// In this test case, we use `1_000_000` distinct values to ensure there are enough samples.
    /// Theoretically, we need at least 2.5 * m samples. (m is the registers size which is equal to
    /// 2^`INDEX_BITS`, by default `INDEX_BITS` is 16) The error can be estimated as 1.04 /
    /// sqrt(m) which is approximately equal to 0.004, so we use 0.01 to make sure we can bound the
    /// error.
    #[test]
    fn test_error_ratio() {
        let mut agg = Registers::<UpdatableBucket<16>>::default();
        assert_eq!(agg.calculate_result(), 0);
        let actual_ndv = 1000000;
        for i in 0..1000000 {
            // Each value is inserted three times; duplicates should not affect the estimate.
            for _ in 0..3 {
                agg.update(i.into(), false).unwrap();
            }
        }

        let estimation = agg.calculate_result();
        let error_ratio = ((estimation - actual_ndv) as f64 / actual_ndv as f64).abs();
        assert!(error_ratio < 0.01);
    }

    /// Exercises `update` and `get_bucket` directly on a single bucket.
    fn test_register_bucket_get_and_update_inner<const DENSE_BITS: usize>() {
        let mut rb = UpdatableBucket::<DENSE_BITS>::default();

        // `i % 2 + 1` alternates between indices 1 and 2, so each receives 10 insertions.
        for i in 0..20 {
            rb.update(i % 2 + 1, false).unwrap();
        }
        assert_eq!(rb.get_bucket(1).unwrap(), 10);
        assert_eq!(rb.get_bucket(2).unwrap(), 10);

        // A retraction only affects the index it targets.
        rb.update(1, true).unwrap();
        assert_eq!(rb.get_bucket(1).unwrap(), 9);
        assert_eq!(rb.get_bucket(2).unwrap(), 10);

        // 64 is the largest valid index; with both tested `DENSE_BITS` values it is stored in
        // the sparse store.
        rb.update(64, false).unwrap();
        assert_eq!(rb.get_bucket(64).unwrap(), 1);
    }

    /// Indices outside `1..=64` are rejected by both `get_bucket` and `update`.
    #[test]
    fn test_register_bucket_invalid_register() {
        let mut rb = UpdatableBucket::<0>::default();

        assert!(rb.get_bucket(0).is_err());
        assert!(rb.get_bucket(65).is_err());
        assert!(rb.update(0, false).is_err());
        assert!(rb.update(65, false).is_err());
    }
}