risingwave_expr_impl/aggregate/
approx_percentile.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::mem::size_of;
17use std::ops::Range;
18
19use bytes::{Buf, Bytes};
20use risingwave_common::array::*;
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{Result, build_aggregate};
27
28/// TODO(kwannoel): for single phase agg, we can actually support `UDDSketch`.
29/// For two phase agg, we still use `DDSketch`.
30/// Then we also need to store the `relative_error` of the sketch, so we can report it
31/// in an internal table, if it changes.
32#[build_aggregate("approx_percentile(float8) -> float8", state = "bytea")]
33fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
34    let quantile = agg.direct_args[0]
35        .literal()
36        .map(|x| (*x.as_float64()).into())
37        .unwrap();
38    let relative_error: f64 = agg.direct_args[1]
39        .literal()
40        .map(|x| (*x.as_float64()).into())
41        .unwrap();
42    if relative_error <= 0.0 || relative_error >= 1.0 {
43        bail!(
44            "relative_error must be in the range (0, 1), got {}",
45            relative_error
46        )
47    }
48    let base = (1.0 + relative_error) / (1.0 - relative_error);
49    Ok(Box::new(ApproxPercentile { quantile, base }))
50}
51
/// Aggregate function that estimates a quantile with a logarithmically bucketed sketch.
pub struct ApproxPercentile {
    /// Quantile to report; it scales the target rank computed in `get_result`.
    quantile: f64,
    /// Base of the logarithmic buckets, derived in `build` as
    /// `(1 + relative_error) / (1 - relative_error)`.
    base: f64,
}
57
// Type of `State::count`, the running total of values recorded in the sketch.
// NOTE(review): despite the name this is not a number of buckets.
type BucketCount = u64;
// Key of a bucket: `ceil(log_base(|value|))` cast to an integer (see `add_datum`).
type BucketId = i32;
// Number of values recorded in a single bucket; also the type of the zero counter.
type Count = u64;
61
/// Sketch state of `approx_percentile`: per-bucket counters for negative, zero and
/// positive inputs, plus the running total.
#[derive(Debug, Default)]
struct State {
    /// Total number of values currently recorded.
    count: BucketCount,
    /// Counters for values that are not negative and not zero, keyed by bucket id.
    pos_buckets: BTreeMap<BucketId, Count>,
    /// Number of recorded values that compare equal to zero.
    zeros: Count,
    /// Counters for negative values, keyed by the bucket id of their absolute value.
    neg_buckets: BTreeMap<BucketId, Count>,
}
69
70impl EstimateSize for State {
71    fn estimated_heap_size(&self) -> usize {
72        let count_size = size_of::<BucketCount>();
73        let pos_buckets_size =
74            self.pos_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
75        let zero_bucket_size = size_of::<Count>();
76        let neg_buckets_size =
77            self.neg_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
78        count_size + pos_buckets_size + zero_bucket_size + neg_buckets_size
79    }
80}
81
// No items are overridden here: `State` relies entirely on what `AggStateDyn` provides.
// NOTE(review): presumably this impl is what lets `State` be boxed into
// `AggregateState::Any` (see `create_state`) — confirm against the trait definition.
impl AggStateDyn for State {}
83
84impl ApproxPercentile {
85    fn add_datum(&self, state: &mut State, op: Op, datum: DatumRef<'_>) {
86        if let Some(value) = datum {
87            let prim_value = value.into_float64().into_inner();
88            let (non_neg, abs_value) = if prim_value < 0.0 {
89                (false, -prim_value)
90            } else {
91                (true, prim_value)
92            };
93            let bucket_id = abs_value.log(self.base).ceil() as BucketId;
94            match op {
95                Op::Delete | Op::UpdateDelete => {
96                    if abs_value == 0.0 {
97                        state.zeros -= 1;
98                    } else if non_neg {
99                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
100                        *count -= 1;
101                        if *count == 0 {
102                            state.pos_buckets.remove(&bucket_id);
103                        }
104                    } else {
105                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
106                        *count -= 1;
107                        if *count == 0 {
108                            state.neg_buckets.remove(&bucket_id);
109                        }
110                    }
111                    state.count -= 1;
112                }
113                Op::Insert | Op::UpdateInsert => {
114                    if abs_value == 0.0 {
115                        state.zeros += 1;
116                    } else if non_neg {
117                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
118                        *count += 1;
119                    } else {
120                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
121                        *count += 1;
122                    }
123                    state.count += 1;
124                }
125            }
126        };
127    }
128}
129
130#[async_trait::async_trait]
131impl AggregateFunction for ApproxPercentile {
132    fn return_type(&self) -> DataType {
133        DataType::Float64
134    }
135
136    fn create_state(&self) -> Result<AggregateState> {
137        Ok(AggregateState::Any(Box::<State>::default()))
138    }
139
140    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
141        let state: &mut State = state.downcast_mut();
142        for (op, row) in input.rows() {
143            let datum = row.datum_at(0);
144            self.add_datum(state, op, datum);
145        }
146        Ok(())
147    }
148
149    async fn update_range(
150        &self,
151        state: &mut AggregateState,
152        input: &StreamChunk,
153        range: Range<usize>,
154    ) -> Result<()> {
155        let state = state.downcast_mut();
156        for (op, row) in input.rows_in(range) {
157            self.add_datum(state, op, row.datum_at(0));
158        }
159        Ok(())
160    }
161
162    // TODO(kwannoel): Instead of iterating over all buckets, we can maintain the
163    // approximate quantile bucket on the fly.
164    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
165        let state = state.downcast_ref::<State>();
166        let quantile_count =
167            ((state.count.saturating_sub(1)) as f64 * self.quantile).floor() as u64;
168        let mut acc_count = 0;
169        for (bucket_id, count) in state.neg_buckets.iter().rev() {
170            acc_count += count;
171            if acc_count > quantile_count {
172                // approx value = -2 * y^i / (y + 1)
173                let approx_percentile = -2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
174                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
175                return Ok(Datum::from(approx_percentile));
176            }
177        }
178        acc_count += state.zeros;
179        if acc_count > quantile_count {
180            return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
181        }
182        for (bucket_id, count) in &state.pos_buckets {
183            acc_count += count;
184            if acc_count > quantile_count {
185                // approx value = 2 * y^i / (y + 1)
186                let approx_percentile = 2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
187                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
188                return Ok(Datum::from(approx_percentile));
189            }
190        }
191        return Ok(None);
192    }
193
194    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
195        let state = state.downcast_ref::<State>();
196        let mut encoded_state = Vec::with_capacity(state.estimated_heap_size());
197        encoded_state.extend_from_slice(&state.count.to_be_bytes());
198        encoded_state.extend_from_slice(&state.zeros.to_be_bytes());
199        let neg_buckets_size = state.neg_buckets.len() as u64;
200        encoded_state.extend_from_slice(&neg_buckets_size.to_be_bytes());
201        for (bucket_id, count) in &state.neg_buckets {
202            encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
203            encoded_state.extend_from_slice(&count.to_be_bytes());
204        }
205        for (bucket_id, count) in &state.pos_buckets {
206            encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
207            encoded_state.extend_from_slice(&count.to_be_bytes());
208        }
209        let encoded_scalar = ScalarImpl::Bytea(encoded_state.into());
210        Ok(Datum::from(encoded_scalar))
211    }
212
213    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
214        let mut state = State::default();
215        let Some(scalar_state) = datum else {
216            return Ok(AggregateState::Any(Box::new(state)));
217        };
218        let encoded_state: Box<[u8]> = scalar_state.into_bytea();
219        let mut buf = Bytes::from(encoded_state);
220        state.count = buf.get_u64();
221        state.zeros = buf.get_u64();
222        let neg_buckets_size = buf.get_u64();
223        for _ in 0..neg_buckets_size {
224            let bucket_id = buf.get_i32();
225            let count = buf.get_u64();
226            state.neg_buckets.insert(bucket_id, count);
227        }
228        while !buf.is_empty() {
229            let bucket_id = buf.get_i32();
230            let count = buf.get_u64();
231            state.pos_buckets.insert(bucket_id, count);
232        }
233        Ok(AggregateState::Any(Box::new(state)))
234    }
235}