risingwave_expr_impl/aggregate/approx_count_distinct/mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::DefaultHasher;
16use std::fmt::Debug;
17use std::hash::{Hash, Hasher};
18use std::ops::Range;
19
20use risingwave_common::array::{Op, StreamChunk};
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{ExprError, Result, build_aggregate};
27
28use self::append_only::AppendOnlyBucket;
29use self::updatable::UpdatableBucket;
30
31mod append_only;
32mod updatable;
33
/// Number of low hash bits used to select a register.
const INDEX_BITS: u8 = 16;
/// Number of registers (`2^INDEX_BITS`).
const NUM_OF_REGISTERS: usize = 1 << INDEX_BITS;
/// Number of hash bits left after removing the index bits; the rank is computed from these.
const COUNT_BITS: u8 = 64 - INDEX_BITS;
/// Number of bits used to store one register in the serialized append-only state.
const LOG_COUNT_BITS: u8 = 6;

// The largest rank derived from a hash is `COUNT_BITS + 1` (all count bits zero, terminated by
// the sentinel bit at position `COUNT_BITS`). It must fit in `LOG_COUNT_BITS` bits, otherwise
// packing the append-only state would bleed into neighbouring registers.
const _: () = assert!(
    (COUNT_BITS as u32 + 1) < (1u32 << LOG_COUNT_BITS),
    "LOG_COUNT_BITS is too small to hold the largest possible rank"
);

// Bias-correction constant alpha_m = 0.7213 / (1 + 1.079 / m), the approximation given for
// m >= 128 registers (here m = NUM_OF_REGISTERS = 2^16). See "HyperLogLog: the analysis of a
// near-optimal cardinality estimation algorithm" by Philippe Flajolet et al.
const BIAS_CORRECTION: f64 = 0.7213 / (1. + (1.079 / NUM_OF_REGISTERS as f64));
42
43/// Count the approximate number of unique non-null values.
44#[build_aggregate("approx_count_distinct(*) -> int8", state = "int8")]
45fn build_updatable(_agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
46    Ok(Box::new(UpdatableApproxCountDistinct))
47}
48
49/// Count the approximate number of unique non-null values.
50#[build_aggregate("approx_count_distinct(*) -> int8", state = "int8[]", append_only)]
51fn build_append_only(_agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
52    Ok(Box::new(AppendOnlyApproxCountDistinct))
53}
54
55struct UpdatableApproxCountDistinct;
56
57#[async_trait::async_trait]
58impl AggregateFunction for UpdatableApproxCountDistinct {
59    fn return_type(&self) -> DataType {
60        DataType::Int64
61    }
62
63    fn create_state(&self) -> Result<AggregateState> {
64        Ok(AggregateState::Any(Box::<UpdatableRegisters>::default()))
65    }
66
67    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
68        let state = state.downcast_mut::<UpdatableRegisters>();
69        for (op, row) in input.rows() {
70            let retract = matches!(op, Op::Delete | Op::UpdateDelete);
71            if let Some(scalar) = row.datum_at(0) {
72                state.update(scalar, retract)?;
73            }
74        }
75        Ok(())
76    }
77
78    async fn update_range(
79        &self,
80        state: &mut AggregateState,
81        input: &StreamChunk,
82        range: Range<usize>,
83    ) -> Result<()> {
84        let state = state.downcast_mut::<UpdatableRegisters>();
85        for (op, row) in input.rows_in(range) {
86            let retract = matches!(op, Op::Delete | Op::UpdateDelete);
87            if let Some(scalar) = row.datum_at(0) {
88                state.update(scalar, retract)?;
89            }
90        }
91        Ok(())
92    }
93
94    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
95        let state = state.downcast_ref::<UpdatableRegisters>();
96        Ok(Some(state.calculate_result().into()))
97    }
98
99    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
100        let state = state.downcast_ref::<UpdatableRegisters>();
101        // FIXME: store state of updatable registers properly
102        Ok(Some(ScalarImpl::Int64(state.calculate_result())))
103    }
104
105    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
106        // FIXME: restore state of updatable registers properly
107        let Some(ScalarImpl::Int64(initial_count)) = datum else {
108            return Err(ExprError::InvalidState("expect int8".into()));
109        };
110        Ok(AggregateState::Any(Box::new(UpdatableRegisters {
111            initial_count,
112            ..UpdatableRegisters::default()
113        })))
114    }
115}
116
117struct AppendOnlyApproxCountDistinct;
118
119#[async_trait::async_trait]
120impl AggregateFunction for AppendOnlyApproxCountDistinct {
121    fn return_type(&self) -> DataType {
122        DataType::Int64
123    }
124
125    fn create_state(&self) -> Result<AggregateState> {
126        Ok(AggregateState::Any(Box::<AppendOnlyRegisters>::default()))
127    }
128
129    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
130        let state = state.downcast_mut::<AppendOnlyRegisters>();
131        for (op, row) in input.rows() {
132            let retract = matches!(op, Op::Delete | Op::UpdateDelete);
133            if let Some(scalar) = row.datum_at(0) {
134                state.update(scalar, retract)?;
135            }
136        }
137        Ok(())
138    }
139
140    async fn update_range(
141        &self,
142        state: &mut AggregateState,
143        input: &StreamChunk,
144        range: Range<usize>,
145    ) -> Result<()> {
146        let state = state.downcast_mut::<AppendOnlyRegisters>();
147        for (op, row) in input.rows_in(range) {
148            let retract = matches!(op, Op::Delete | Op::UpdateDelete);
149            if let Some(scalar) = row.datum_at(0) {
150                state.update(scalar, retract)?;
151            }
152        }
153        Ok(())
154    }
155
156    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
157        let state = state.downcast_ref::<AppendOnlyRegisters>();
158        Ok(Some(state.calculate_result().into()))
159    }
160
161    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
162        let reg = state.downcast_ref::<AppendOnlyRegisters>();
163
164        let buckets = &reg.registers[..];
165        let result_len = (buckets.len() * LOG_COUNT_BITS as usize - 1) / (i64::BITS as usize) + 1;
166        let mut result = vec![0u64; result_len];
167        for (i, bucket_val) in buckets.iter().enumerate() {
168            let (start_idx, begin_bit, post_end_bit) = pos_in_serialized(i);
169            result[start_idx] |= (buckets[i].0 as u64) << begin_bit;
170            if post_end_bit > i64::BITS {
171                result[start_idx + 1] |= (bucket_val.0 as u64) >> (i64::BITS - begin_bit as u32);
172            }
173        }
174        Ok(Some(ScalarImpl::List(ListValue::from_iter(
175            result.into_iter().map(|v| v as i64),
176        ))))
177    }
178
179    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
180        let scalar = datum.unwrap();
181        let list = scalar.as_list();
182        let bucket_num = list.len() * i64::BITS as usize / LOG_COUNT_BITS as usize;
183        let registers = (0..bucket_num)
184            .map(|i| {
185                let (start_idx, begin_bit, post_end_bit) = pos_in_serialized(i);
186                let val = list.get(start_idx).unwrap().unwrap().into_int64() as u64;
187                let v = if post_end_bit <= i64::BITS {
188                    val << (i64::BITS - post_end_bit) >> (i64::BITS - LOG_COUNT_BITS as u32)
189                } else {
190                    (val >> begin_bit)
191                        + (((list.get(start_idx + 1).unwrap().unwrap().into_int64() as u64)
192                            & ((1 << (post_end_bit - i64::BITS)) - 1))
193                            << (i64::BITS - begin_bit as u32))
194                };
195                AppendOnlyBucket(v as u8)
196            })
197            .collect();
198        Ok(AggregateState::Any(Box::new(AppendOnlyRegisters {
199            registers,
200            initial_count: 0,
201        })))
202    }
203}
204
205/// Approximates the count of non-null rows using a modified version of the `HyperLogLog` algorithm.
206/// Each `Bucket` stores a count of how many hash values have x trailing zeroes for all x from 1-64.
207/// This allows the algorithm to support insertion and deletion, but uses up more memory and limits
208/// the number of rows that can be counted.
209///
210/// This can count up to a total of 2^64 unduplicated rows.
211///
212/// The estimation error for `HyperLogLog` is 1.04/sqrt(num of registers). With 2^16 registers this
213/// is ~1/256, or about 0.4%. The memory usage for the default choice of parameters is about
214/// (1024 + 24) bits * 2^16 buckets, which is about 8.58 MB.
215#[derive(Debug, Clone)]
216struct Registers<B: Bucket> {
217    registers: Box<[B]>,
218    // FIXME: Currently we only store the count result (i64) as the state of updatable register.
219    // This is not correct, because the state should be the registers themselves.
220    initial_count: i64,
221}
222
223type UpdatableRegisters = Registers<UpdatableBucket>;
224type AppendOnlyRegisters = Registers<AppendOnlyBucket>;
225
226trait Bucket: Debug + Default + Clone + EstimateSize + Send + Sync + 'static {
227    /// Increments or decrements the bucket at `index` depending on the state of `retract`.
228    /// Returns an Error if `index` is invalid or if inserting will cause an overflow in the bucket.
229    fn update(&mut self, index: u8, retract: bool) -> Result<()>;
230
231    /// Gets the number of the maximum bucket which has a count greater than zero.
232    fn max(&self) -> u8;
233}
234
235impl<B: Bucket> AggStateDyn for Registers<B> {}
236
237impl<B: Bucket> Default for Registers<B> {
238    fn default() -> Self {
239        Self {
240            registers: (0..NUM_OF_REGISTERS).map(|_| B::default()).collect(),
241            initial_count: 0,
242        }
243    }
244}
245
246impl<B: Bucket> Registers<B> {
247    /// Adds the count of the datum's hash into the register, if it is greater than the existing
248    /// count at the register
249    fn update(&mut self, scalar_ref: ScalarRefImpl<'_>, retract: bool) -> Result<()> {
250        let hash = self.get_hash(scalar_ref);
251
252        let index = (hash as usize) & (NUM_OF_REGISTERS - 1); // Index is based on last few bits
253        let count = self.count_hash(hash);
254
255        self.registers[index].update(count, retract)?;
256        Ok(())
257    }
258
259    /// Calculate the hash of the `scalar` using Rust's default hasher
260    /// Perhaps a different hash like Murmur2 could be used instead for optimization?
261    fn get_hash(&self, scalar: ScalarRefImpl<'_>) -> u64 {
262        let mut hasher = DefaultHasher::new();
263        scalar.hash(&mut hasher);
264        hasher.finish()
265    }
266
267    /// Counts the number of trailing zeroes plus 1 in the non-index bits of the hash
268    fn count_hash(&self, mut hash: u64) -> u8 {
269        hash >>= INDEX_BITS; // Ignore bits used as index for the hash
270        hash |= 1 << COUNT_BITS; // To allow hash to terminate if it is all 0s
271
272        (hash.trailing_zeros() + 1) as u8
273    }
274
275    /// Calculates the bias-corrected harmonic mean of the registers to get the approximate count
276    fn calculate_result(&self) -> i64 {
277        let m = NUM_OF_REGISTERS as f64;
278        let mut mean = 0.0;
279
280        // Get harmonic mean of all the counts in results
281        for bucket in &*self.registers {
282            let count = bucket.max();
283            mean += 1.0 / ((1 << count) as f64);
284        }
285
286        let raw_estimate = BIAS_CORRECTION * m * m / mean;
287
288        // If raw_estimate is not much bigger than m and some registers have value 0, set answer to
289        // m * log(m/V) where V is the number of registers with value 0
290        let answer = if raw_estimate <= 2.5 * m {
291            let mut zero_registers: f64 = 0.0;
292            for i in &*self.registers {
293                if i.max() == 0 {
294                    zero_registers += 1.0;
295                }
296            }
297
298            if zero_registers == 0.0 {
299                raw_estimate
300            } else {
301                m * (m.log2() - (zero_registers.log2()))
302            }
303        } else {
304            raw_estimate
305        };
306
307        self.initial_count + answer as i64
308    }
309}
310
311impl<B: Bucket> From<Registers<B>> for i64 {
312    fn from(reg: Registers<B>) -> Self {
313        reg.calculate_result()
314    }
315}
316
317impl<B: Bucket> EstimateSize for Registers<B> {
318    fn estimated_heap_size(&self) -> usize {
319        self.registers.len() * std::mem::size_of::<B>()
320    }
321}
322
323fn pos_in_serialized(bucket_idx: usize) -> (usize, usize, u32) {
324    // rust compiler will optimize for us
325    let start_idx = bucket_idx * LOG_COUNT_BITS as usize / i64::BITS as usize;
326    let begin_bit = bucket_idx * LOG_COUNT_BITS as usize % i64::BITS as usize;
327    let post_end_bit = begin_bit as u32 + LOG_COUNT_BITS as u32;
328    (start_idx, begin_bit, post_end_bit)
329}
330
#[cfg(test)]
mod tests {
    use futures_util::FutureExt;
    use risingwave_common::array::{Array, DataChunk, I32Array, StreamChunk};
    use risingwave_expr::aggregate::{AggCall, build_append_only};

    /// Feeds disjoint ranges of distinct `int4` values into a fresh append-only
    /// `approx_count_distinct` state and checks that each estimate is in the right ballpark.
    #[test]
    fn test() {
        let agg = build_append_only(&AggCall::from_pretty(
            "(approx_count_distinct:int8 $0:int4)",
        ))
        .unwrap();

        for range in [0..20000, 20000..30000, 30000..35000] {
            let actual = range.len();
            let column = I32Array::from_iter(range.clone()).into_ref();
            let chunk = StreamChunk::from(DataChunk::new(vec![column], actual));

            let mut state = agg.create_state().unwrap();
            agg.update(&mut state, &chunk)
                .now_or_never()
                .unwrap()
                .unwrap();
            let estimate = agg
                .get_result(&state)
                .now_or_never()
                .unwrap()
                .unwrap()
                .unwrap()
                .into_int64() as usize;

            // FIXME: the error is too large? A ±10% bound would match HyperLogLog's expected
            // accuracy:
            // assert!((actual as f32 * 0.9..actual as f32 * 1.1).contains(&(count as f32)));
            let expected_range = actual as f32 * 0.5..actual as f32 * 1.5;
            assert!(
                expected_range.contains(&(estimate as f32)),
                "approximate count {estimate} not in {expected_range:?}"
            );
        }
    }
}