risingwave_expr_impl/aggregate/
percentile_disc.rs1use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{
22 AggCall, AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
23};
24use risingwave_expr::{Result, build_aggregate};
25
26#[build_aggregate("percentile_disc(any) -> any")]
71fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
72 let fractions = agg.direct_args[0]
73 .literal()
74 .map(|x| (*x.as_float64()).into());
75 Ok(Box::new(PercentileDisc::new(
76 fractions,
77 agg.return_type.clone(),
78 )))
79}
80
81#[derive(Clone)]
82pub struct PercentileDisc {
83 fractions: Option<f64>,
84 return_type: DataType,
85}
86
87#[derive(Debug, Default)]
88struct State(Vec<ScalarImpl>);
89
90impl EstimateSize for State {
91 fn estimated_heap_size(&self) -> usize {
92 std::mem::size_of_val(self.0.as_slice())
93 }
94}
95
96impl AggStateDyn for State {}
97
98impl PercentileDisc {
99 pub fn new(fractions: Option<f64>, return_type: DataType) -> Self {
100 Self {
101 fractions,
102 return_type,
103 }
104 }
105
106 fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) {
107 if let Some(datum) = datum_ref.to_owned_datum() {
108 state.0.push(datum);
109 }
110 }
111}
112
113#[async_trait::async_trait]
114impl AggregateFunction for PercentileDisc {
115 fn return_type(&self) -> DataType {
116 self.return_type.clone()
117 }
118
119 fn create_state(&self) -> Result<AggregateState> {
120 Ok(AggregateState::Any(Box::<State>::default()))
121 }
122
123 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
124 let state = state.downcast_mut();
125 for (_, row) in input.rows() {
126 self.add_datum(state, row.datum_at(0));
127 }
128 Ok(())
129 }
130
131 async fn update_range(
132 &self,
133 state: &mut AggregateState,
134 input: &StreamChunk,
135 range: Range<usize>,
136 ) -> Result<()> {
137 let state = state.downcast_mut();
138 for (_, row) in input.rows_in(range) {
139 self.add_datum(state, row.datum_at(0));
140 }
141 Ok(())
142 }
143
144 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
145 let state = &state.downcast_ref::<State>().0;
146 Ok(
147 if let Some(fractions) = self.fractions
148 && !state.is_empty()
149 {
150 let idx = if fractions == 0.0 {
151 0
152 } else {
153 f64::ceil(fractions * state.len() as f64) as usize - 1
154 };
155 Some(state[idx].clone())
156 } else {
157 None
158 },
159 )
160 }
161}