risingwave_expr_impl/aggregate/
mode.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use risingwave_common::array::*;
18use risingwave_common::row::Row;
19use risingwave_common::types::*;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_expr::aggregate::{
22    AggCall, AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
23};
24use risingwave_expr::{Result, build_aggregate};
25
26#[build_aggregate("mode(any) -> any")]
27fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
28    Ok(Box::new(Mode {
29        return_type: agg.return_type.clone(),
30    }))
31}
32
33/// Computes the mode, the most frequent value of the aggregated argument (arbitrarily choosing the
34/// first one if there are multiple equally-frequent values). The aggregated argument must be of a
35/// sortable type.
36///
37/// ```slt
38/// query I
39/// select mode() within group (order by unnest) from unnest(array[1]);
40/// ----
41/// 1
42///
43/// query I
44/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4]);
45/// ----
46/// 4
47///
48/// query R
49/// select mode() within group (order by unnest) from unnest(array[0.1,0.2,0.2,0.4,0.4,0.3,0.3,0.4]);
50/// ----
51/// 0.4
52///
53/// query R
54/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4,3]);
55/// ----
56/// 3
57///
58/// query T
59/// select mode() within group (order by unnest) from unnest(array['1','2','2','3','3','4','4','4','3']);
60/// ----
61/// 3
62///
63/// query I
64/// select mode() within group (order by unnest) from unnest(array[]::int[]);
65/// ----
66/// NULL
67/// ```
68struct Mode {
69    return_type: DataType,
70}
71
72#[derive(Debug, Clone, EstimateSize, Default)]
73struct State {
74    cur_mode: Datum,
75    cur_mode_freq: usize,
76    cur_item: Datum,
77    cur_item_freq: usize,
78}
79
80impl AggStateDyn for State {}
81
82impl State {
83    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
84        let datum = datum_ref.to_owned_datum();
85        if datum.is_some() && self.cur_item == datum {
86            self.cur_item_freq += 1;
87        } else if datum.is_some() {
88            self.cur_item = datum;
89            self.cur_item_freq = 1;
90        }
91        if self.cur_item_freq > self.cur_mode_freq {
92            self.cur_mode.clone_from(&self.cur_item);
93            self.cur_mode_freq = self.cur_item_freq;
94        }
95    }
96}
97
98#[async_trait::async_trait]
99impl AggregateFunction for Mode {
100    fn return_type(&self) -> DataType {
101        self.return_type.clone()
102    }
103
104    fn create_state(&self) -> Result<AggregateState> {
105        Ok(AggregateState::Any(Box::<State>::default()))
106    }
107
108    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
109        let state = state.downcast_mut::<State>();
110        for (_, row) in input.rows() {
111            state.add_datum(row.datum_at(0));
112        }
113        Ok(())
114    }
115
116    async fn update_range(
117        &self,
118        state: &mut AggregateState,
119        input: &StreamChunk,
120        range: Range<usize>,
121    ) -> Result<()> {
122        let state = state.downcast_mut::<State>();
123        for (_, row) in input.rows_in(range) {
124            state.add_datum(row.datum_at(0));
125        }
126        Ok(())
127    }
128
129    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
130        let state = state.downcast_ref::<State>();
131        Ok(state.cur_mode.clone())
132    }
133}