risingwave_expr_impl/aggregate/
percentile_disc.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{
22    AggCall, AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
23};
24use risingwave_expr::{Result, build_aggregate};
25
26/// Computes the discrete percentile, the first value within the ordered set of aggregated argument
27/// values whose position in the ordering equals or exceeds the specified fraction. The aggregated
28/// argument must be of a sortable type.
29///
30/// ```slt
31/// statement ok
32/// create table t(x int, y bigint, z real, w double, v varchar);
33///
34/// statement ok
35/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
36///
37/// query R
38/// select percentile_disc(0) within group (order by x) from t;
39/// ----
40/// 1
41///
42/// query R
43/// select percentile_disc(0.33) within group (order by y) from t;
44/// ----
45/// 10
46///
47/// query R
48/// select percentile_disc(0.34) within group (order by z) from t;
49/// ----
50/// 200
51///
52/// query R
53/// select percentile_disc(0.67) within group (order by w) from t
54/// ----
55/// 3000
56///
57/// query R
58/// select percentile_disc(1) within group (order by v) from t;
59/// ----
60/// 30000
61///
62/// query R
63/// select percentile_disc(NULL) within group (order by w) from t;
64/// ----
65/// NULL
66///
67/// statement ok
68/// drop table t;
69/// ```
70#[build_aggregate("percentile_disc(any) -> any")]
71fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
72    let fractions = agg.direct_args[0]
73        .literal()
74        .map(|x| (*x.as_float64()).into());
75    Ok(Box::new(PercentileDisc::new(
76        fractions,
77        agg.return_type.clone(),
78    )))
79}
80
81#[derive(Clone)]
82pub struct PercentileDisc {
83    fractions: Option<f64>,
84    return_type: DataType,
85}
86
87#[derive(Debug, Default)]
88struct State(Vec<ScalarImpl>);
89
90impl EstimateSize for State {
91    fn estimated_heap_size(&self) -> usize {
92        std::mem::size_of_val(self.0.as_slice())
93    }
94}
95
96impl AggStateDyn for State {}
97
98impl PercentileDisc {
99    pub fn new(fractions: Option<f64>, return_type: DataType) -> Self {
100        Self {
101            fractions,
102            return_type,
103        }
104    }
105
106    fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) {
107        if let Some(datum) = datum_ref.to_owned_datum() {
108            state.0.push(datum);
109        }
110    }
111}
112
113#[async_trait::async_trait]
114impl AggregateFunction for PercentileDisc {
115    fn return_type(&self) -> DataType {
116        self.return_type.clone()
117    }
118
119    fn create_state(&self) -> Result<AggregateState> {
120        Ok(AggregateState::Any(Box::<State>::default()))
121    }
122
123    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
124        let state = state.downcast_mut();
125        for (_, row) in input.rows() {
126            self.add_datum(state, row.datum_at(0));
127        }
128        Ok(())
129    }
130
131    async fn update_range(
132        &self,
133        state: &mut AggregateState,
134        input: &StreamChunk,
135        range: Range<usize>,
136    ) -> Result<()> {
137        let state = state.downcast_mut();
138        for (_, row) in input.rows_in(range) {
139            self.add_datum(state, row.datum_at(0));
140        }
141        Ok(())
142    }
143
144    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
145        let state = &state.downcast_ref::<State>().0;
146        Ok(
147            if let Some(fractions) = self.fractions
148                && !state.is_empty()
149            {
150                let idx = if fractions == 0.0 {
151                    0
152                } else {
153                    f64::ceil(fractions * state.len() as f64) as usize - 1
154                };
155                Some(state[idx].clone())
156            } else {
157                None
158            },
159        )
160    }
161}