risingwave_expr_impl/aggregate/
percentile_cont.rs1use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
22use risingwave_expr::{Result, build_aggregate};
23
24#[build_aggregate("percentile_cont(float8) -> float8")]
64fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
65 let fraction = agg.direct_args[0]
66 .literal()
67 .map(|x| (*x.as_float64()).into());
68 Ok(Box::new(PercentileCont { fraction }))
69}
70
/// The `percentile_cont(fraction) -> float8` aggregate function.
#[derive(Debug)]
pub struct PercentileCont {
    /// Requested percentile, read from the first direct argument when the
    /// aggregate is built. `None` makes the aggregate produce NULL.
    fraction: Option<f64>,
}
74
75#[derive(Debug, Default, EstimateSize)]
76struct State(Vec<f64>);
77
78impl AggStateDyn for State {}
79
80impl PercentileCont {
81 fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) {
82 if let Some(datum) = datum_ref.to_owned_datum() {
83 state.0.push((*datum.as_float64()).into());
84 }
85 }
86}
87
88#[async_trait::async_trait]
89impl AggregateFunction for PercentileCont {
90 fn return_type(&self) -> DataType {
91 DataType::Float64
92 }
93
94 fn create_state(&self) -> Result<AggregateState> {
95 Ok(AggregateState::Any(Box::<State>::default()))
96 }
97
98 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
99 let state = state.downcast_mut();
100 for (_, row) in input.rows() {
101 self.add_datum(state, row.datum_at(0));
102 }
103 Ok(())
104 }
105
106 async fn update_range(
107 &self,
108 state: &mut AggregateState,
109 input: &StreamChunk,
110 range: Range<usize>,
111 ) -> Result<()> {
112 let state = state.downcast_mut();
113 for (_, row) in input.rows_in(range) {
114 self.add_datum(state, row.datum_at(0));
115 }
116 Ok(())
117 }
118
119 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
120 let state = &state.downcast_ref::<State>().0;
121 Ok(
122 if let Some(fraction) = self.fraction
123 && !state.is_empty()
124 {
125 let rn = fraction * (state.len() - 1) as f64;
126 let crn = f64::ceil(rn);
127 let frn = f64::floor(rn);
128 let result = if crn == frn {
129 state[crn as usize]
130 } else {
131 (crn - rn) * state[frn as usize] + (rn - frn) * state[crn as usize]
132 };
133 Some(result.into())
134 } else {
135 None
136 },
137 )
138 }
139}