risingwave_expr_impl/aggregate/
percentile_cont.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
22use risingwave_expr::{Result, build_aggregate};
23
24/// Computes the continuous percentile, a value corresponding to the specified fraction within the
25/// ordered set of aggregated argument values. This will interpolate between adjacent input items if
26/// needed.
27///
28/// ```slt
29/// statement ok
30/// create table t(x int, y bigint, z real, w double, v varchar);
31///
32/// statement ok
33/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
34///
35/// query R
36/// select percentile_cont(0.45) within group (order by x desc) from t;
37/// ----
38/// 2.1
39///
40/// query R
41/// select percentile_cont(0.45) within group (order by y desc) from t;
42/// ----
43/// 21
44///
45/// query R
46/// select percentile_cont(0.45) within group (order by z desc) from t;
47/// ----
48/// 210
49///
50/// query R
51/// select percentile_cont(0.45) within group (order by w desc) from t;
52/// ----
53/// 2100
54///
55/// query R
56/// select percentile_cont(NULL) within group (order by w desc) from t;
57/// ----
58/// NULL
59///
60/// statement ok
61/// drop table t;
62/// ```
63#[build_aggregate("percentile_cont(float8) -> float8")]
64fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
65    let fraction = agg.direct_args[0]
66        .literal()
67        .map(|x| (*x.as_float64()).into());
68    Ok(Box::new(PercentileCont { fraction }))
69}
70
/// Aggregate function object for `percentile_cont`.
///
/// It only carries the requested fraction; the aggregated values themselves live in the
/// per-group aggregate state.
#[derive(Debug)]
pub struct PercentileCont {
    /// The requested percentile fraction, or `None` when no literal fraction was supplied, in
    /// which case the aggregate result is NULL.
    fraction: Option<f64>,
}
74
75#[derive(Debug, Default, EstimateSize)]
76struct State(Vec<f64>);
77
78impl AggStateDyn for State {}
79
80impl PercentileCont {
81    fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) {
82        if let Some(datum) = datum_ref.to_owned_datum() {
83            state.0.push((*datum.as_float64()).into());
84        }
85    }
86}
87
88#[async_trait::async_trait]
89impl AggregateFunction for PercentileCont {
90    fn return_type(&self) -> DataType {
91        DataType::Float64
92    }
93
94    fn create_state(&self) -> Result<AggregateState> {
95        Ok(AggregateState::Any(Box::<State>::default()))
96    }
97
98    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
99        let state = state.downcast_mut();
100        for (_, row) in input.rows() {
101            self.add_datum(state, row.datum_at(0));
102        }
103        Ok(())
104    }
105
106    async fn update_range(
107        &self,
108        state: &mut AggregateState,
109        input: &StreamChunk,
110        range: Range<usize>,
111    ) -> Result<()> {
112        let state = state.downcast_mut();
113        for (_, row) in input.rows_in(range) {
114            self.add_datum(state, row.datum_at(0));
115        }
116        Ok(())
117    }
118
119    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
120        let state = &state.downcast_ref::<State>().0;
121        Ok(
122            if let Some(fraction) = self.fraction
123                && !state.is_empty()
124            {
125                let rn = fraction * (state.len() - 1) as f64;
126                let crn = f64::ceil(rn);
127                let frn = f64::floor(rn);
128                let result = if crn == frn {
129                    state[crn as usize]
130                } else {
131                    (crn - rn) * state[frn as usize] + (rn - frn) * state[crn as usize]
132                };
133                Some(result.into())
134            } else {
135                None
136            },
137        )
138    }
139}