risingwave_expr_impl/aggregate/approx_count_distinct/
mod.rs1use std::collections::hash_map::DefaultHasher;
16use std::fmt::Debug;
17use std::hash::{Hash, Hasher};
18use std::ops::Range;
19
20use risingwave_common::array::{Op, StreamChunk};
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{ExprError, Result, build_aggregate};
27
28use self::append_only::AppendOnlyBucket;
29use self::updatable::UpdatableBucket;
30
31mod append_only;
32mod updatable;
33
/// Number of low hash bits used to select a register.
const INDEX_BITS: u8 = 16;
/// Number of registers (`m` in the HyperLogLog paper), i.e. `2^INDEX_BITS`.
const NUM_OF_REGISTERS: usize = 1 << INDEX_BITS;
/// Number of hash bits left for computing a rank once the register index is removed.
const COUNT_BITS: u8 = 64 - INDEX_BITS;
/// Width in bits of one register in the append-only, bit-packed state encoding.
const LOG_COUNT_BITS: u8 = 6;

// The largest rank a register can hold is `COUNT_BITS + 1` (see `count_hash`). It must fit
// in `LOG_COUNT_BITS` bits, otherwise the bit-packed state encoding would silently bleed
// into neighbouring registers. Fail the build instead if the constants drift apart.
const _: () = assert!((COUNT_BITS as u32) + 1 < (1u32 << LOG_COUNT_BITS));

/// Bias-correction constant `alpha_m = 0.7213 / (1 + 1.079 / m)` from the HyperLogLog paper.
const BIAS_CORRECTION: f64 = 0.7213 / (1. + (1.079 / NUM_OF_REGISTERS as f64));
42
43#[build_aggregate("approx_count_distinct(*) -> int8", state = "int8")]
45fn build_updatable(_agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
46 Ok(Box::new(UpdatableApproxCountDistinct))
47}
48
49#[build_aggregate("approx_count_distinct(*) -> int8", state = "int8[]", append_only)]
51fn build_append_only(_agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
52 Ok(Box::new(AppendOnlyApproxCountDistinct))
53}
54
55struct UpdatableApproxCountDistinct;
56
57#[async_trait::async_trait]
58impl AggregateFunction for UpdatableApproxCountDistinct {
59 fn return_type(&self) -> DataType {
60 DataType::Int64
61 }
62
63 fn create_state(&self) -> Result<AggregateState> {
64 Ok(AggregateState::Any(Box::<UpdatableRegisters>::default()))
65 }
66
67 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
68 let state = state.downcast_mut::<UpdatableRegisters>();
69 for (op, row) in input.rows() {
70 let retract = matches!(op, Op::Delete | Op::UpdateDelete);
71 if let Some(scalar) = row.datum_at(0) {
72 state.update(scalar, retract)?;
73 }
74 }
75 Ok(())
76 }
77
78 async fn update_range(
79 &self,
80 state: &mut AggregateState,
81 input: &StreamChunk,
82 range: Range<usize>,
83 ) -> Result<()> {
84 let state = state.downcast_mut::<UpdatableRegisters>();
85 for (op, row) in input.rows_in(range) {
86 let retract = matches!(op, Op::Delete | Op::UpdateDelete);
87 if let Some(scalar) = row.datum_at(0) {
88 state.update(scalar, retract)?;
89 }
90 }
91 Ok(())
92 }
93
94 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
95 let state = state.downcast_ref::<UpdatableRegisters>();
96 Ok(Some(state.calculate_result().into()))
97 }
98
99 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
100 let state = state.downcast_ref::<UpdatableRegisters>();
101 Ok(Some(ScalarImpl::Int64(state.calculate_result())))
103 }
104
105 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
106 let Some(ScalarImpl::Int64(initial_count)) = datum else {
108 return Err(ExprError::InvalidState("expect int8".into()));
109 };
110 Ok(AggregateState::Any(Box::new(UpdatableRegisters {
111 initial_count,
112 ..UpdatableRegisters::default()
113 })))
114 }
115}
116
117struct AppendOnlyApproxCountDistinct;
118
119#[async_trait::async_trait]
120impl AggregateFunction for AppendOnlyApproxCountDistinct {
121 fn return_type(&self) -> DataType {
122 DataType::Int64
123 }
124
125 fn create_state(&self) -> Result<AggregateState> {
126 Ok(AggregateState::Any(Box::<AppendOnlyRegisters>::default()))
127 }
128
129 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
130 let state = state.downcast_mut::<AppendOnlyRegisters>();
131 for (op, row) in input.rows() {
132 let retract = matches!(op, Op::Delete | Op::UpdateDelete);
133 if let Some(scalar) = row.datum_at(0) {
134 state.update(scalar, retract)?;
135 }
136 }
137 Ok(())
138 }
139
140 async fn update_range(
141 &self,
142 state: &mut AggregateState,
143 input: &StreamChunk,
144 range: Range<usize>,
145 ) -> Result<()> {
146 let state = state.downcast_mut::<AppendOnlyRegisters>();
147 for (op, row) in input.rows_in(range) {
148 let retract = matches!(op, Op::Delete | Op::UpdateDelete);
149 if let Some(scalar) = row.datum_at(0) {
150 state.update(scalar, retract)?;
151 }
152 }
153 Ok(())
154 }
155
156 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
157 let state = state.downcast_ref::<AppendOnlyRegisters>();
158 Ok(Some(state.calculate_result().into()))
159 }
160
161 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
162 let reg = state.downcast_ref::<AppendOnlyRegisters>();
163
164 let buckets = ®.registers[..];
165 let result_len = (buckets.len() * LOG_COUNT_BITS as usize - 1) / (i64::BITS as usize) + 1;
166 let mut result = vec![0u64; result_len];
167 for (i, bucket_val) in buckets.iter().enumerate() {
168 let (start_idx, begin_bit, post_end_bit) = pos_in_serialized(i);
169 result[start_idx] |= (buckets[i].0 as u64) << begin_bit;
170 if post_end_bit > i64::BITS {
171 result[start_idx + 1] |= (bucket_val.0 as u64) >> (i64::BITS - begin_bit as u32);
172 }
173 }
174 Ok(Some(ScalarImpl::List(ListValue::from_iter(
175 result.into_iter().map(|v| v as i64),
176 ))))
177 }
178
179 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
180 let scalar = datum.unwrap();
181 let list = scalar.as_list();
182 let bucket_num = list.len() * i64::BITS as usize / LOG_COUNT_BITS as usize;
183 let registers = (0..bucket_num)
184 .map(|i| {
185 let (start_idx, begin_bit, post_end_bit) = pos_in_serialized(i);
186 let val = list.get(start_idx).unwrap().unwrap().into_int64() as u64;
187 let v = if post_end_bit <= i64::BITS {
188 val << (i64::BITS - post_end_bit) >> (i64::BITS - LOG_COUNT_BITS as u32)
189 } else {
190 (val >> begin_bit)
191 + (((list.get(start_idx + 1).unwrap().unwrap().into_int64() as u64)
192 & ((1 << (post_end_bit - i64::BITS)) - 1))
193 << (i64::BITS - begin_bit as u32))
194 };
195 AppendOnlyBucket(v as u8)
196 })
197 .collect();
198 Ok(AggregateState::Any(Box::new(AppendOnlyRegisters {
199 registers,
200 initial_count: 0,
201 })))
202 }
203}
204
/// HyperLogLog sketch: one bucket per register plus a baseline count.
///
/// `calculate_result` returns `initial_count` plus the estimate derived from `registers`.
/// The updatable variant's `decode_state` seeds `initial_count` from the persisted
/// estimate; every other constructor visible here sets it to 0.
#[derive(Debug, Clone)]
struct Registers<B: Bucket> {
    // One bucket per register; `Default` creates `NUM_OF_REGISTERS` of them.
    registers: Box<[B]>,
    // Count added on top of the register-based estimate.
    initial_count: i64,
}

/// Registers for the retractable variant, whose updates may carry `retract = true`.
type UpdatableRegisters = Registers<UpdatableBucket>;
/// Registers for the append-only variant, persisted bit-packed by its `encode_state`.
type AppendOnlyRegisters = Registers<AppendOnlyBucket>;

/// Storage for a single HyperLogLog register.
trait Bucket: Debug + Default + Clone + EstimateSize + Send + Sync + 'static {
    /// Records (or, when `retract` is true, removes) one hashed value.
    ///
    /// Despite its name, `index` is the rank produced by `Registers::count_hash`, not a
    /// register index. NOTE(review): presumably the append-only bucket returns an error on
    /// `retract = true` — confirm in `append_only`.
    fn update(&mut self, index: u8, retract: bool) -> Result<()>;

    /// The register's current value `M[j]`; `calculate_result` treats 0 as an empty register.
    fn max(&self) -> u8;
}

impl<B: Bucket> AggStateDyn for Registers<B> {}
236
237impl<B: Bucket> Default for Registers<B> {
238 fn default() -> Self {
239 Self {
240 registers: (0..NUM_OF_REGISTERS).map(|_| B::default()).collect(),
241 initial_count: 0,
242 }
243 }
244}
245
246impl<B: Bucket> Registers<B> {
247 fn update(&mut self, scalar_ref: ScalarRefImpl<'_>, retract: bool) -> Result<()> {
250 let hash = self.get_hash(scalar_ref);
251
252 let index = (hash as usize) & (NUM_OF_REGISTERS - 1); let count = self.count_hash(hash);
254
255 self.registers[index].update(count, retract)?;
256 Ok(())
257 }
258
259 fn get_hash(&self, scalar: ScalarRefImpl<'_>) -> u64 {
262 let mut hasher = DefaultHasher::new();
263 scalar.hash(&mut hasher);
264 hasher.finish()
265 }
266
267 fn count_hash(&self, mut hash: u64) -> u8 {
269 hash >>= INDEX_BITS; hash |= 1 << COUNT_BITS; (hash.trailing_zeros() + 1) as u8
273 }
274
275 fn calculate_result(&self) -> i64 {
277 let m = NUM_OF_REGISTERS as f64;
278 let mut mean = 0.0;
279
280 for bucket in &*self.registers {
282 let count = bucket.max();
283 mean += 1.0 / ((1 << count) as f64);
284 }
285
286 let raw_estimate = BIAS_CORRECTION * m * m / mean;
287
288 let answer = if raw_estimate <= 2.5 * m {
291 let mut zero_registers: f64 = 0.0;
292 for i in &*self.registers {
293 if i.max() == 0 {
294 zero_registers += 1.0;
295 }
296 }
297
298 if zero_registers == 0.0 {
299 raw_estimate
300 } else {
301 m * (m.log2() - (zero_registers.log2()))
302 }
303 } else {
304 raw_estimate
305 };
306
307 self.initial_count + answer as i64
308 }
309}
310
311impl<B: Bucket> From<Registers<B>> for i64 {
312 fn from(reg: Registers<B>) -> Self {
313 reg.calculate_result()
314 }
315}
316
317impl<B: Bucket> EstimateSize for Registers<B> {
318 fn estimated_heap_size(&self) -> usize {
319 self.registers.len() * std::mem::size_of::<B>()
320 }
321}
322
323fn pos_in_serialized(bucket_idx: usize) -> (usize, usize, u32) {
324 let start_idx = bucket_idx * LOG_COUNT_BITS as usize / i64::BITS as usize;
326 let begin_bit = bucket_idx * LOG_COUNT_BITS as usize % i64::BITS as usize;
327 let post_end_bit = begin_bit as u32 + LOG_COUNT_BITS as u32;
328 (start_idx, begin_bit, post_end_bit)
329}
330
#[cfg(test)]
mod tests {
    use futures_util::FutureExt;
    use risingwave_common::array::{Array, DataChunk, I32Array, StreamChunk};
    use risingwave_expr::aggregate::{AggCall, build_append_only};

    /// Feeds three independent ranges of distinct `int4` values into fresh append-only
    /// states and checks that each estimate lies within +/-50% of the exact count.
    #[test]
    fn test() {
        let approx_count_distinct = build_append_only(&AggCall::from_pretty(
            "(approx_count_distinct:int8 $0:int4)",
        ))
        .unwrap();

        for range in [0..20000, 20000..30000, 30000..35000] {
            let col = I32Array::from_iter(range.clone()).into_ref();
            let input = StreamChunk::from(DataChunk::new(vec![col], range.len()));
            let mut state = approx_count_distinct.create_state().unwrap();
            approx_count_distinct
                .update(&mut state, &input)
                .now_or_never()
                .unwrap()
                .unwrap();
            let count = approx_count_distinct
                .get_result(&state)
                .now_or_never()
                .unwrap()
                .unwrap()
                .unwrap()
                .into_int64() as usize;
            let actual = range.len();
            let expected_range = actual as f32 * 0.5..actual as f32 * 1.5;
            if !expected_range.contains(&(count as f32)) {
                panic!("approximate count {} not in {:?}", count, expected_range);
            }
        }
    }

    /// Encodes a populated append-only state into its bit-packed `int8[]` form, decodes it
    /// again, and checks that the decoded state yields exactly the same estimate.
    #[test]
    fn test_encode_decode_round_trip() {
        let approx_count_distinct = build_append_only(&AggCall::from_pretty(
            "(approx_count_distinct:int8 $0:int4)",
        ))
        .unwrap();

        let range = 0..5000;
        let col = I32Array::from_iter(range.clone()).into_ref();
        let input = StreamChunk::from(DataChunk::new(vec![col], range.len()));
        let mut state = approx_count_distinct.create_state().unwrap();
        approx_count_distinct
            .update(&mut state, &input)
            .now_or_never()
            .unwrap()
            .unwrap();

        let encoded = approx_count_distinct.encode_state(&state).unwrap();
        let decoded = approx_count_distinct.decode_state(encoded).unwrap();

        let before = approx_count_distinct
            .get_result(&state)
            .now_or_never()
            .unwrap()
            .unwrap();
        let after = approx_count_distinct
            .get_result(&decoded)
            .now_or_never()
            .unwrap()
            .unwrap();
        assert!(before.is_some());
        assert_eq!(before, after);
    }
}