risingwave_expr_impl/aggregate/
mode.rs1use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{
22 AggCall, AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
23};
24use risingwave_expr::{Result, build_aggregate};
25
26#[build_aggregate("mode(any) -> any")]
27fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
28 Ok(Box::new(Mode {
29 return_type: agg.return_type.clone(),
30 }))
31}
32
33struct Mode {
69 return_type: DataType,
70}
71
72#[derive(Debug, Clone, EstimateSize, Default)]
73struct State {
74 cur_mode: Datum,
75 cur_mode_freq: usize,
76 cur_item: Datum,
77 cur_item_freq: usize,
78}
79
80impl AggStateDyn for State {}
81
82impl State {
83 fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
84 let datum = datum_ref.to_owned_datum();
85 if datum.is_some() && self.cur_item == datum {
86 self.cur_item_freq += 1;
87 } else if datum.is_some() {
88 self.cur_item = datum;
89 self.cur_item_freq = 1;
90 }
91 if self.cur_item_freq > self.cur_mode_freq {
92 self.cur_mode.clone_from(&self.cur_item);
93 self.cur_mode_freq = self.cur_item_freq;
94 }
95 }
96}
97
98#[async_trait::async_trait]
99impl AggregateFunction for Mode {
100 fn return_type(&self) -> DataType {
101 self.return_type.clone()
102 }
103
104 fn create_state(&self) -> Result<AggregateState> {
105 Ok(AggregateState::Any(Box::<State>::default()))
106 }
107
108 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
109 let state = state.downcast_mut::<State>();
110 for (_, row) in input.rows() {
111 state.add_datum(row.datum_at(0));
112 }
113 Ok(())
114 }
115
116 async fn update_range(
117 &self,
118 state: &mut AggregateState,
119 input: &StreamChunk,
120 range: Range<usize>,
121 ) -> Result<()> {
122 let state = state.downcast_mut::<State>();
123 for (_, row) in input.rows_in(range) {
124 state.add_datum(row.datum_at(0));
125 }
126 Ok(())
127 }
128
129 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
130 let state = state.downcast_ref::<State>();
131 Ok(state.cur_mode.clone())
132 }
133}