risingwave_expr_impl/aggregate/
approx_percentile.rs1use std::collections::BTreeMap;
16use std::mem::size_of;
17use std::ops::Range;
18
19use bytes::{Buf, Bytes};
20use risingwave_common::array::*;
21use risingwave_common::bail;
22use risingwave_common::row::Row;
23use risingwave_common::types::*;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
26use risingwave_expr::{Result, build_aggregate};
27
28#[build_aggregate("approx_percentile(float8) -> float8", state = "bytea")]
33fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
34 let quantile = agg.direct_args[0]
35 .literal()
36 .map(|x| (*x.as_float64()).into())
37 .unwrap();
38 let relative_error: f64 = agg.direct_args[1]
39 .literal()
40 .map(|x| (*x.as_float64()).into())
41 .unwrap();
42 if relative_error <= 0.0 || relative_error >= 1.0 {
43 bail!(
44 "relative_error must be in the range (0, 1), got {}",
45 relative_error
46 )
47 }
48 let base = (1.0 + relative_error) / (1.0 - relative_error);
49 Ok(Box::new(ApproxPercentile { quantile, base }))
50}
51
/// Configuration of one `approx_percentile` aggregate call.
///
/// The mutable sketch itself lives in a separate per-group state; this struct only
/// carries the parameters derived from the call's direct arguments.
// NOTE(review): both fields are read elsewhere in this file, so this `allow` looks
// stale — confirm before removing it.
#[allow(dead_code)]
pub struct ApproxPercentile {
    /// Quantile to estimate, taken from the first direct argument.
    quantile: f64,
    /// Bucket growth factor, `(1 + relative_error) / (1 - relative_error)`.
    base: f64,
}
57
/// Counter type used for the total number of accumulated values.
type BucketCount = u64;
/// Signed bucket index: the ceiling of the logarithm of a value's magnitude.
type BucketId = i32;
/// Number of values recorded in a single bucket.
type Count = u64;
61
62#[derive(Debug, Default)]
63struct State {
64 count: BucketCount,
65 pos_buckets: BTreeMap<BucketId, Count>,
66 zeros: Count,
67 neg_buckets: BTreeMap<BucketId, Count>,
68}
69
70impl EstimateSize for State {
71 fn estimated_heap_size(&self) -> usize {
72 let count_size = size_of::<BucketCount>();
73 let pos_buckets_size =
74 self.pos_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
75 let zero_bucket_size = size_of::<Count>();
76 let neg_buckets_size =
77 self.neg_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
78 count_size + pos_buckets_size + zero_bucket_size + neg_buckets_size
79 }
80}
81
82impl AggStateDyn for State {}
83
84impl ApproxPercentile {
85 fn add_datum(&self, state: &mut State, op: Op, datum: DatumRef<'_>) {
86 if let Some(value) = datum {
87 let prim_value = value.into_float64().into_inner();
88 let (non_neg, abs_value) = if prim_value < 0.0 {
89 (false, -prim_value)
90 } else {
91 (true, prim_value)
92 };
93 let bucket_id = abs_value.log(self.base).ceil() as BucketId;
94 match op {
95 Op::Delete | Op::UpdateDelete => {
96 if abs_value == 0.0 {
97 state.zeros -= 1;
98 } else if non_neg {
99 let count = state.pos_buckets.entry(bucket_id).or_insert(0);
100 *count -= 1;
101 if *count == 0 {
102 state.pos_buckets.remove(&bucket_id);
103 }
104 } else {
105 let count = state.neg_buckets.entry(bucket_id).or_insert(0);
106 *count -= 1;
107 if *count == 0 {
108 state.neg_buckets.remove(&bucket_id);
109 }
110 }
111 state.count -= 1;
112 }
113 Op::Insert | Op::UpdateInsert => {
114 if abs_value == 0.0 {
115 state.zeros += 1;
116 } else if non_neg {
117 let count = state.pos_buckets.entry(bucket_id).or_insert(0);
118 *count += 1;
119 } else {
120 let count = state.neg_buckets.entry(bucket_id).or_insert(0);
121 *count += 1;
122 }
123 state.count += 1;
124 }
125 }
126 };
127 }
128}
129
130#[async_trait::async_trait]
131impl AggregateFunction for ApproxPercentile {
132 fn return_type(&self) -> DataType {
133 DataType::Float64
134 }
135
136 fn create_state(&self) -> Result<AggregateState> {
137 Ok(AggregateState::Any(Box::<State>::default()))
138 }
139
140 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
141 let state: &mut State = state.downcast_mut();
142 for (op, row) in input.rows() {
143 let datum = row.datum_at(0);
144 self.add_datum(state, op, datum);
145 }
146 Ok(())
147 }
148
149 async fn update_range(
150 &self,
151 state: &mut AggregateState,
152 input: &StreamChunk,
153 range: Range<usize>,
154 ) -> Result<()> {
155 let state = state.downcast_mut();
156 for (op, row) in input.rows_in(range) {
157 self.add_datum(state, op, row.datum_at(0));
158 }
159 Ok(())
160 }
161
162 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
165 let state = state.downcast_ref::<State>();
166 let quantile_count =
167 ((state.count.saturating_sub(1)) as f64 * self.quantile).floor() as u64;
168 let mut acc_count = 0;
169 for (bucket_id, count) in state.neg_buckets.iter().rev() {
170 acc_count += count;
171 if acc_count > quantile_count {
172 let approx_percentile = -2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
174 let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
175 return Ok(Datum::from(approx_percentile));
176 }
177 }
178 acc_count += state.zeros;
179 if acc_count > quantile_count {
180 return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
181 }
182 for (bucket_id, count) in &state.pos_buckets {
183 acc_count += count;
184 if acc_count > quantile_count {
185 let approx_percentile = 2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
187 let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
188 return Ok(Datum::from(approx_percentile));
189 }
190 }
191 return Ok(None);
192 }
193
194 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
195 let state = state.downcast_ref::<State>();
196 let mut encoded_state = Vec::with_capacity(state.estimated_heap_size());
197 encoded_state.extend_from_slice(&state.count.to_be_bytes());
198 encoded_state.extend_from_slice(&state.zeros.to_be_bytes());
199 let neg_buckets_size = state.neg_buckets.len() as u64;
200 encoded_state.extend_from_slice(&neg_buckets_size.to_be_bytes());
201 for (bucket_id, count) in &state.neg_buckets {
202 encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
203 encoded_state.extend_from_slice(&count.to_be_bytes());
204 }
205 for (bucket_id, count) in &state.pos_buckets {
206 encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
207 encoded_state.extend_from_slice(&count.to_be_bytes());
208 }
209 let encoded_scalar = ScalarImpl::Bytea(encoded_state.into());
210 Ok(Datum::from(encoded_scalar))
211 }
212
213 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
214 let mut state = State::default();
215 let Some(scalar_state) = datum else {
216 return Ok(AggregateState::Any(Box::new(state)));
217 };
218 let encoded_state: Box<[u8]> = scalar_state.into_bytea();
219 let mut buf = Bytes::from(encoded_state);
220 state.count = buf.get_u64();
221 state.zeros = buf.get_u64();
222 let neg_buckets_size = buf.get_u64();
223 for _ in 0..neg_buckets_size {
224 let bucket_id = buf.get_i32();
225 let count = buf.get_u64();
226 state.neg_buckets.insert(bucket_id, count);
227 }
228 while !buf.is_empty() {
229 let bucket_id = buf.get_i32();
230 let count = buf.get_u64();
231 state.pos_buckets.insert(bucket_id, count);
232 }
233 Ok(AggregateState::Any(Box::new(state)))
234 }
235}