risingwave_expr_impl/aggregate/
approx_percentile.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::mem::size_of;
17use std::ops::Range;
18
19use bytes::{Buf, Bytes};
20use risingwave_common::array::*;
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{Result, build_aggregate};
27
28/// TODO(kwannoel): for single phase agg, we can actually support `UDDSketch`.
29/// For two phase agg, we still use `DDSketch`.
30/// Then we also need to store the `relative_error` of the sketch, so we can report it
31/// in an internal table, if it changes.
32#[build_aggregate("approx_percentile(float8) -> float8", state = "bytea")]
33fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
34    let quantile = agg.direct_args[0]
35        .literal()
36        .map(|x| (*x.as_float64()).into())
37        .unwrap();
38    let relative_error: f64 = agg.direct_args[1]
39        .literal()
40        .map(|x| (*x.as_float64()).into())
41        .unwrap();
42    if relative_error <= 0.0 || relative_error >= 1.0 {
43        bail!(
44            "relative_error must be in the range (0, 1), got {}",
45            relative_error
46        )
47    }
48    let base = (1.0 + relative_error) / (1.0 - relative_error);
49    Ok(Box::new(ApproxPercentile { quantile, base }))
50}
51
/// Configuration of one `approx_percentile` aggregate call, constructed by `build`.
///
/// The mutable sketch itself lives in `State`; this struct only carries the parameters
/// needed to update and query it.
pub struct ApproxPercentile {
    /// Requested quantile, taken from the first direct argument in `build`.
    quantile: f64,
    /// Logarithmic bucket base, computed in `build` as
    /// `(1 + relative_error) / (1 - relative_error)`.
    base: f64,
}
56
/// Type of the total value counter (`State::count`).
type BucketCount = u64;
/// Key of a bucket in `State::pos_buckets` / `State::neg_buckets`.
type BucketId = i32;
/// Number of values recorded in one bucket, or in `State::zeros`.
type Count = u64;
60
/// Sketch state of the aggregate: value counts grouped into logarithmic buckets, with
/// negative values, zeros and the remaining values tracked separately.
#[derive(Debug, Default)]
struct State {
    /// Number of non-NULL values currently tracked (incremented on insert, decremented on
    /// delete).
    count: BucketCount,
    /// Counts of non-negative, non-zero values, keyed by `ceil(log_base(value))`.
    pos_buckets: BTreeMap<BucketId, Count>,
    /// Number of tracked values equal to zero.
    zeros: Count,
    /// Counts of negative values, keyed by `ceil(log_base(-value))`.
    neg_buckets: BTreeMap<BucketId, Count>,
}
68
69impl EstimateSize for State {
70    fn estimated_heap_size(&self) -> usize {
71        let count_size = size_of::<BucketCount>();
72        let pos_buckets_size =
73            self.pos_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
74        let zero_bucket_size = size_of::<Count>();
75        let neg_buckets_size =
76            self.neg_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
77        count_size + pos_buckets_size + zero_bucket_size + neg_buckets_size
78    }
79}
80
// Empty impl: no items of `AggStateDyn` are overridden for `State`.
// NOTE(review): presumably this is what allows `State` to be boxed into
// `AggregateState::Any` (see `create_state`) — confirm against the trait definition.
impl AggStateDyn for State {}
82
83impl ApproxPercentile {
84    fn add_datum(&self, state: &mut State, op: Op, datum: DatumRef<'_>) {
85        if let Some(value) = datum {
86            let prim_value = value.into_float64().into_inner();
87            let (non_neg, abs_value) = if prim_value < 0.0 {
88                (false, -prim_value)
89            } else {
90                (true, prim_value)
91            };
92            let bucket_id = abs_value.log(self.base).ceil() as BucketId;
93            match op {
94                Op::Delete | Op::UpdateDelete => {
95                    if abs_value == 0.0 {
96                        state.zeros -= 1;
97                    } else if non_neg {
98                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
99                        *count -= 1;
100                        if *count == 0 {
101                            state.pos_buckets.remove(&bucket_id);
102                        }
103                    } else {
104                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
105                        *count -= 1;
106                        if *count == 0 {
107                            state.neg_buckets.remove(&bucket_id);
108                        }
109                    }
110                    state.count -= 1;
111                }
112                Op::Insert | Op::UpdateInsert => {
113                    if abs_value == 0.0 {
114                        state.zeros += 1;
115                    } else if non_neg {
116                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
117                        *count += 1;
118                    } else {
119                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
120                        *count += 1;
121                    }
122                    state.count += 1;
123                }
124            }
125        };
126    }
127}
128
129#[async_trait::async_trait]
130impl AggregateFunction for ApproxPercentile {
131    fn return_type(&self) -> DataType {
132        DataType::Float64
133    }
134
135    fn create_state(&self) -> Result<AggregateState> {
136        Ok(AggregateState::Any(Box::<State>::default()))
137    }
138
139    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
140        let state: &mut State = state.downcast_mut();
141        for (op, row) in input.rows() {
142            let datum = row.datum_at(0);
143            self.add_datum(state, op, datum);
144        }
145        Ok(())
146    }
147
148    async fn update_range(
149        &self,
150        state: &mut AggregateState,
151        input: &StreamChunk,
152        range: Range<usize>,
153    ) -> Result<()> {
154        let state = state.downcast_mut();
155        for (op, row) in input.rows_in(range) {
156            self.add_datum(state, op, row.datum_at(0));
157        }
158        Ok(())
159    }
160
161    // TODO(kwannoel): Instead of iterating over all buckets, we can maintain the
162    // approximate quantile bucket on the fly.
163    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
164        let state = state.downcast_ref::<State>();
165        let quantile_count =
166            ((state.count.saturating_sub(1)) as f64 * self.quantile).floor() as u64;
167        let mut acc_count = 0;
168        for (bucket_id, count) in state.neg_buckets.iter().rev() {
169            acc_count += count;
170            if acc_count > quantile_count {
171                // approx value = -2 * y^i / (y + 1)
172                let approx_percentile = -2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
173                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
174                return Ok(Datum::from(approx_percentile));
175            }
176        }
177        acc_count += state.zeros;
178        if acc_count > quantile_count {
179            return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
180        }
181        for (bucket_id, count) in &state.pos_buckets {
182            acc_count += count;
183            if acc_count > quantile_count {
184                // approx value = 2 * y^i / (y + 1)
185                let approx_percentile = 2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
186                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
187                return Ok(Datum::from(approx_percentile));
188            }
189        }
190        return Ok(None);
191    }
192
193    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
194        let state = state.downcast_ref::<State>();
195        let mut encoded_state = Vec::with_capacity(state.estimated_heap_size());
196        encoded_state.extend_from_slice(&state.count.to_be_bytes());
197        encoded_state.extend_from_slice(&state.zeros.to_be_bytes());
198        let neg_buckets_size = state.neg_buckets.len() as u64;
199        encoded_state.extend_from_slice(&neg_buckets_size.to_be_bytes());
200        for (bucket_id, count) in &state.neg_buckets {
201            encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
202            encoded_state.extend_from_slice(&count.to_be_bytes());
203        }
204        for (bucket_id, count) in &state.pos_buckets {
205            encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
206            encoded_state.extend_from_slice(&count.to_be_bytes());
207        }
208        let encoded_scalar = ScalarImpl::Bytea(encoded_state.into());
209        Ok(Datum::from(encoded_scalar))
210    }
211
212    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
213        let mut state = State::default();
214        let Some(scalar_state) = datum else {
215            return Ok(AggregateState::Any(Box::new(state)));
216        };
217        let encoded_state: Box<[u8]> = scalar_state.into_bytea();
218        let mut buf = Bytes::from(encoded_state);
219        state.count = buf.get_u64();
220        state.zeros = buf.get_u64();
221        let neg_buckets_size = buf.get_u64();
222        for _ in 0..neg_buckets_size {
223            let bucket_id = buf.get_i32();
224            let count = buf.get_u64();
225            state.neg_buckets.insert(bucket_id, count);
226        }
227        while !buf.is_empty() {
228            let bucket_id = buf.get_i32();
229            let count = buf.get_u64();
230            state.pos_buckets.insert(bucket_id, count);
231        }
232        Ok(AggregateState::Any(Box::new(state)))
233    }
234}