risingwave_expr_impl/aggregate/
approx_percentile.rs1use std::collections::BTreeMap;
16use std::mem::size_of;
17use std::ops::Range;
18
19use bytes::{Buf, Bytes};
20use risingwave_common::array::*;
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{Result, build_aggregate};
27
28#[build_aggregate("approx_percentile(float8) -> float8", state = "bytea")]
33fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
34 let quantile = agg.direct_args[0]
35 .literal()
36 .map(|x| (*x.as_float64()).into())
37 .unwrap();
38 let relative_error: f64 = agg.direct_args[1]
39 .literal()
40 .map(|x| (*x.as_float64()).into())
41 .unwrap();
42 if relative_error <= 0.0 || relative_error >= 1.0 {
43 bail!(
44 "relative_error must be in the range (0, 1), got {}",
45 relative_error
46 )
47 }
48 let base = (1.0 + relative_error) / (1.0 - relative_error);
49 Ok(Box::new(ApproxPercentile { quantile, base }))
50}
51
/// Aggregate function object for `approx_percentile`.
///
/// It only carries the per-call configuration; the sketch itself is kept in the aggregate state.
#[derive(Debug)]
pub struct ApproxPercentile {
    /// Quantile used to pick the target rank, computed as `floor((count - 1) * quantile)`.
    quantile: f64,
    /// Logarithm base that maps a value's magnitude to a bucket id; constructed as
    /// `(1 + relative_error) / (1 - relative_error)`.
    base: f64,
}
56
/// Type of the total number of values tracked by a sketch (`State::count`).
type BucketCount = u64;
/// Bucket index, computed as `ceil(log_base(|value|))`.
type BucketId = i32;
/// Number of values tallied in a single bucket (or in the zero tally).
type Count = u64;

/// Sketch state of `approx_percentile`.
///
/// Non-zero values are tallied per logarithmic bucket of their magnitude, with separate maps
/// for positive and negative values; zeros get a dedicated tally.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct State {
    /// Total number of values currently in the sketch.
    count: BucketCount,
    /// Tallies of positive values, keyed by bucket id.
    pos_buckets: BTreeMap<BucketId, Count>,
    /// Tally of values equal to zero.
    zeros: Count,
    /// Tallies of negative values, keyed by the bucket id of their absolute value.
    neg_buckets: BTreeMap<BucketId, Count>,
}
68
69impl EstimateSize for State {
70 fn estimated_heap_size(&self) -> usize {
71 let count_size = size_of::<BucketCount>();
72 let pos_buckets_size =
73 self.pos_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
74 let zero_bucket_size = size_of::<Count>();
75 let neg_buckets_size =
76 self.neg_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
77 count_size + pos_buckets_size + zero_bucket_size + neg_buckets_size
78 }
79}
80
81impl AggStateDyn for State {}
82
83impl ApproxPercentile {
84 fn add_datum(&self, state: &mut State, op: Op, datum: DatumRef<'_>) {
85 if let Some(value) = datum {
86 let prim_value = value.into_float64().into_inner();
87 let (non_neg, abs_value) = if prim_value < 0.0 {
88 (false, -prim_value)
89 } else {
90 (true, prim_value)
91 };
92 let bucket_id = abs_value.log(self.base).ceil() as BucketId;
93 match op {
94 Op::Delete | Op::UpdateDelete => {
95 if abs_value == 0.0 {
96 state.zeros -= 1;
97 } else if non_neg {
98 let count = state.pos_buckets.entry(bucket_id).or_insert(0);
99 *count -= 1;
100 if *count == 0 {
101 state.pos_buckets.remove(&bucket_id);
102 }
103 } else {
104 let count = state.neg_buckets.entry(bucket_id).or_insert(0);
105 *count -= 1;
106 if *count == 0 {
107 state.neg_buckets.remove(&bucket_id);
108 }
109 }
110 state.count -= 1;
111 }
112 Op::Insert | Op::UpdateInsert => {
113 if abs_value == 0.0 {
114 state.zeros += 1;
115 } else if non_neg {
116 let count = state.pos_buckets.entry(bucket_id).or_insert(0);
117 *count += 1;
118 } else {
119 let count = state.neg_buckets.entry(bucket_id).or_insert(0);
120 *count += 1;
121 }
122 state.count += 1;
123 }
124 }
125 };
126 }
127}
128
129#[async_trait::async_trait]
130impl AggregateFunction for ApproxPercentile {
131 fn return_type(&self) -> DataType {
132 DataType::Float64
133 }
134
135 fn create_state(&self) -> Result<AggregateState> {
136 Ok(AggregateState::Any(Box::<State>::default()))
137 }
138
139 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
140 let state: &mut State = state.downcast_mut();
141 for (op, row) in input.rows() {
142 let datum = row.datum_at(0);
143 self.add_datum(state, op, datum);
144 }
145 Ok(())
146 }
147
148 async fn update_range(
149 &self,
150 state: &mut AggregateState,
151 input: &StreamChunk,
152 range: Range<usize>,
153 ) -> Result<()> {
154 let state = state.downcast_mut();
155 for (op, row) in input.rows_in(range) {
156 self.add_datum(state, op, row.datum_at(0));
157 }
158 Ok(())
159 }
160
161 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
164 let state = state.downcast_ref::<State>();
165 let quantile_count =
166 ((state.count.saturating_sub(1)) as f64 * self.quantile).floor() as u64;
167 let mut acc_count = 0;
168 for (bucket_id, count) in state.neg_buckets.iter().rev() {
169 acc_count += count;
170 if acc_count > quantile_count {
171 let approx_percentile = -2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
173 let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
174 return Ok(Datum::from(approx_percentile));
175 }
176 }
177 acc_count += state.zeros;
178 if acc_count > quantile_count {
179 return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
180 }
181 for (bucket_id, count) in &state.pos_buckets {
182 acc_count += count;
183 if acc_count > quantile_count {
184 let approx_percentile = 2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
186 let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
187 return Ok(Datum::from(approx_percentile));
188 }
189 }
190 return Ok(None);
191 }
192
193 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
194 let state = state.downcast_ref::<State>();
195 let mut encoded_state = Vec::with_capacity(state.estimated_heap_size());
196 encoded_state.extend_from_slice(&state.count.to_be_bytes());
197 encoded_state.extend_from_slice(&state.zeros.to_be_bytes());
198 let neg_buckets_size = state.neg_buckets.len() as u64;
199 encoded_state.extend_from_slice(&neg_buckets_size.to_be_bytes());
200 for (bucket_id, count) in &state.neg_buckets {
201 encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
202 encoded_state.extend_from_slice(&count.to_be_bytes());
203 }
204 for (bucket_id, count) in &state.pos_buckets {
205 encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
206 encoded_state.extend_from_slice(&count.to_be_bytes());
207 }
208 let encoded_scalar = ScalarImpl::Bytea(encoded_state.into());
209 Ok(Datum::from(encoded_scalar))
210 }
211
212 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
213 let mut state = State::default();
214 let Some(scalar_state) = datum else {
215 return Ok(AggregateState::Any(Box::new(state)));
216 };
217 let encoded_state: Box<[u8]> = scalar_state.into_bytea();
218 let mut buf = Bytes::from(encoded_state);
219 state.count = buf.get_u64();
220 state.zeros = buf.get_u64();
221 let neg_buckets_size = buf.get_u64();
222 for _ in 0..neg_buckets_size {
223 let bucket_id = buf.get_i32();
224 let count = buf.get_u64();
225 state.neg_buckets.insert(bucket_id, count);
226 }
227 while !buf.is_empty() {
228 let bucket_id = buf.get_i32();
229 let count = buf.get_u64();
230 state.pos_buckets.insert(bucket_id, count);
231 }
232 Ok(AggregateState::Any(Box::new(state)))
233 }
234}