risingwave_batch_executors/executor/aggregation/
distinct.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::ops::Range;
17
18use risingwave_common::array::StreamChunk;
19use risingwave_common::bitmap::BitmapBuilder;
20use risingwave_common::row::{OwnedRow, Row};
21use risingwave_common::types::{DataType, Datum};
22use risingwave_common_estimate_size::EstimateSize;
23use risingwave_expr::Result;
24use risingwave_expr::aggregate::{
25    AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
26};
27
/// `Distinct` is a wrapper around another aggregate function that lets each
/// distinct input row reach the wrapped function only once.
///
/// Rows that were already seen (tracked per aggregation state) are hidden from
/// the chunk before it is forwarded to the inner function.
pub struct Distinct {
    /// The wrapped aggregate function that receives the de-duplicated rows.
    inner: BoxedAggregateFunction,
}
32
/// The intermediate state for distinct aggregation.
///
/// Pairs the wrapped function's own state with the set of rows that have
/// already been forwarded to it.
#[derive(Debug)]
struct State {
    /// Inner aggregate function state.
    inner: AggregateState,
    /// The set of distinct rows.
    exists: HashSet<OwnedRow>, // TODO: optimize for small rows
    /// Running sum of `estimated_heap_size()` of every row stored in `exists`.
    /// It is increased whenever a new row is inserted, so the size estimate
    /// can read it directly instead of iterating over the set.
    exists_estimated_heap_size: usize,
}
42
43impl EstimateSize for State {
44    fn estimated_heap_size(&self) -> usize {
45        self.inner.estimated_size()
46            + self.exists.capacity() * std::mem::size_of::<OwnedRow>()
47            + self.exists_estimated_heap_size
48    }
49}
50
// Empty impl: no items of `AggStateDyn` are overridden here. `create_state`
// boxes a `State` into `AggregateState::Any`.
// NOTE(review): presumably this impl is what permits that boxing — confirm
// against the `AggStateDyn` definition.
impl AggStateDyn for State {}
52
53impl Distinct {
54    pub fn new(inner: BoxedAggregateFunction) -> Self {
55        Self { inner }
56    }
57}
58
59#[async_trait::async_trait]
60impl AggregateFunction for Distinct {
61    fn return_type(&self) -> DataType {
62        self.inner.return_type()
63    }
64
65    fn create_state(&self) -> Result<AggregateState> {
66        Ok(AggregateState::Any(Box::new(State {
67            inner: self.inner.create_state()?,
68            exists: HashSet::new(),
69            exists_estimated_heap_size: 0,
70        })))
71    }
72
73    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
74        self.update_range(state, input, 0..input.capacity()).await
75    }
76
77    async fn update_range(
78        &self,
79        state: &mut AggregateState,
80        input: &StreamChunk,
81        range: Range<usize>,
82    ) -> Result<()> {
83        let state = state.downcast_mut::<State>();
84
85        let mut bitmap_builder = BitmapBuilder::with_capacity(input.capacity());
86        bitmap_builder.append_bitmap(input.data_chunk().visibility());
87        for row_id in range.clone() {
88            let (row_ref, vis) = input.data_chunk().row_at(row_id);
89            let row = row_ref.to_owned_row();
90            let row_size = row.estimated_heap_size();
91            let b = vis && state.exists.insert(row);
92            if b {
93                state.exists_estimated_heap_size += row_size;
94            }
95            bitmap_builder.set(row_id, b);
96        }
97        let input = input.clone_with_vis(bitmap_builder.finish());
98        self.inner
99            .update_range(&mut state.inner, &input, range)
100            .await
101    }
102
103    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
104        let state = state.downcast_ref::<State>();
105        self.inner.get_result(&state.inner).await
106    }
107}
108
#[cfg(test)]
mod tests {
    use futures_util::FutureExt;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::{Datum, Decimal};
    use risingwave_expr::aggregate::AggCall;

    use super::super::build;

    /// The value 1 appears twice but is summed once: 1 + 3 = 4.
    #[test]
    fn distinct_sum_int32() {
        let chunk = StreamChunk::from_pretty(
            " i
            + 1
            + 1
            + 3",
        );
        let want: Datum = Some(4i64.into());
        test_agg("(sum:int8 $0:int4 distinct)", chunk, want);
    }

    /// Same data as the int32 case, with an int8 input and a decimal result.
    #[test]
    fn distinct_sum_int64() {
        let chunk = StreamChunk::from_pretty(
            " I
            + 1
            + 1
            + 3",
        );
        let want: Datum = Some(Decimal::from(4).into());
        test_agg("(sum:decimal $0:int8 distinct)", chunk, want);
    }

    /// Minimum over three different floats is the smallest one.
    #[test]
    fn distinct_min_float32() {
        let chunk = StreamChunk::from_pretty(
            " f
            + 1.0
            + 2.0
            + 3.0",
        );
        let want: Datum = Some(1.0f32.into());
        test_agg("(min:float4 $0:float4 distinct)", chunk, want);
    }

    /// Minimum over two strings is expected to be "aa".
    #[test]
    fn distinct_min_char() {
        let chunk = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        let want: Datum = Some("aa".into());
        test_agg("(min:varchar $0:varchar distinct)", chunk, want);
    }

    /// Maximum over two strings is expected to be "b".
    #[test]
    fn distinct_max_char() {
        let chunk = StreamChunk::from_pretty(
            " T
            + b
            + aa",
        );
        let want: Datum = Some("b".into());
        test_agg("(max:varchar $0:varchar distinct)", chunk, want);
    }

    /// Distinct count over duplicates, over an empty chunk, and over a single
    /// `.` row (presumably a NULL value, which is expected not to be counted).
    #[test]
    fn distinct_count_int32() {
        let with_duplicates = StreamChunk::from_pretty(
            " i
            + 1
            + 1
            + 3",
        );
        let want: Datum = Some(2i64.into());
        test_agg("(count:int8 $0:int4 distinct)", with_duplicates, want);

        let empty = StreamChunk::from_pretty("i");
        let want: Datum = Some(0i64.into());
        test_agg("(count:int8 $0:int4 distinct)", empty, want);

        let single_dot_row = StreamChunk::from_pretty(
            " i
            + .",
        );
        let want: Datum = Some(0i64.into());
        test_agg("(count:int8 $0:int4 distinct)", single_dot_row, want);
    }

    /// Builds the aggregate described by `pretty`, feeds it `input` in one
    /// `update` call, and asserts that the result equals `expected`.
    ///
    /// Both futures are required to complete on their first poll.
    fn test_agg(pretty: &str, input: StreamChunk, expected: Datum) {
        let call = AggCall::from_pretty(pretty);
        let aggregator = build(&call).unwrap();
        let mut state = aggregator.create_state().unwrap();

        let update_outcome = aggregator.update(&mut state, &input).now_or_never();
        update_outcome.unwrap().unwrap();

        let result = aggregator.get_result(&state).now_or_never();
        assert_eq!(result.unwrap().unwrap(), expected);
    }
}