risingwave_batch_executors/executor/aggregation/
distinct.rs1use std::collections::HashSet;
16use std::ops::Range;
17
18use risingwave_common::array::StreamChunk;
19use risingwave_common::bitmap::BitmapBuilder;
20use risingwave_common::row::{OwnedRow, Row};
21use risingwave_common::types::{DataType, Datum};
22use risingwave_common_estimate_size::EstimateSize;
23use risingwave_expr::Result;
24use risingwave_expr::aggregate::{
25 AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
26};
27
28pub struct Distinct {
30 inner: BoxedAggregateFunction,
31}
32
33#[derive(Debug)]
35struct State {
36 inner: AggregateState,
38 exists: HashSet<OwnedRow>, exists_estimated_heap_size: usize,
41}
42
43impl EstimateSize for State {
44 fn estimated_heap_size(&self) -> usize {
45 self.inner.estimated_size()
46 + self.exists.capacity() * std::mem::size_of::<OwnedRow>()
47 + self.exists_estimated_heap_size
48 }
49}
50
51impl AggStateDyn for State {}
52
53impl Distinct {
54 pub fn new(inner: BoxedAggregateFunction) -> Self {
55 Self { inner }
56 }
57}
58
59#[async_trait::async_trait]
60impl AggregateFunction for Distinct {
61 fn return_type(&self) -> DataType {
62 self.inner.return_type()
63 }
64
65 fn create_state(&self) -> Result<AggregateState> {
66 Ok(AggregateState::Any(Box::new(State {
67 inner: self.inner.create_state()?,
68 exists: HashSet::new(),
69 exists_estimated_heap_size: 0,
70 })))
71 }
72
73 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
74 self.update_range(state, input, 0..input.capacity()).await
75 }
76
77 async fn update_range(
78 &self,
79 state: &mut AggregateState,
80 input: &StreamChunk,
81 range: Range<usize>,
82 ) -> Result<()> {
83 let state = state.downcast_mut::<State>();
84
85 let mut bitmap_builder = BitmapBuilder::with_capacity(input.capacity());
86 bitmap_builder.append_bitmap(input.data_chunk().visibility());
87 for row_id in range.clone() {
88 let (row_ref, vis) = input.data_chunk().row_at(row_id);
89 let row = row_ref.to_owned_row();
90 let row_size = row.estimated_heap_size();
91 let b = vis && state.exists.insert(row);
92 if b {
93 state.exists_estimated_heap_size += row_size;
94 }
95 bitmap_builder.set(row_id, b);
96 }
97 let input = input.clone_with_vis(bitmap_builder.finish());
98 self.inner
99 .update_range(&mut state.inner, &input, range)
100 .await
101 }
102
103 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
104 let state = state.downcast_ref::<State>();
105 self.inner.get_result(&state.inner).await
106 }
107}
108
#[cfg(test)]
mod tests {
    use futures_util::FutureExt;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::types::{Datum, Decimal};
    use risingwave_expr::aggregate::AggCall;

    use super::super::build;

    /// Builds the aggregate described by `pretty`, feeds it `input` in a single
    /// `update` call, and asserts that the result equals `expected`.
    fn test_agg(pretty: &str, input: StreamChunk, expected: Datum) {
        let call = AggCall::from_pretty(pretty);
        let agg = build(&call).unwrap();
        let mut state = agg.create_state().unwrap();

        let updated = agg.update(&mut state, &input).now_or_never();
        updated.unwrap().unwrap();

        let actual = agg.get_result(&state).now_or_never().unwrap().unwrap();
        assert_eq!(actual, expected);
    }

    // The duplicated `1` is summed once: 1 + 3 = 4.
    #[test]
    fn distinct_sum_int32() {
        test_agg(
            "(sum:int8 $0:int4 distinct)",
            StreamChunk::from_pretty(
                " i
                + 1
                + 1
                + 3",
            ),
            Some(4i64.into()),
        );
    }

    // Same as above for an int8 input column with a decimal result.
    #[test]
    fn distinct_sum_int64() {
        test_agg(
            "(sum:decimal $0:int8 distinct)",
            StreamChunk::from_pretty(
                " I
                + 1
                + 1
                + 3",
            ),
            Some(Decimal::from(4).into()),
        );
    }

    #[test]
    fn distinct_min_float32() {
        test_agg(
            "(min:float4 $0:float4 distinct)",
            StreamChunk::from_pretty(
                " f
                + 1.0
                + 2.0
                + 3.0",
            ),
            Some(1.0f32.into()),
        );
    }

    #[test]
    fn distinct_min_char() {
        test_agg(
            "(min:varchar $0:varchar distinct)",
            StreamChunk::from_pretty(
                " T
                + b
                + aa",
            ),
            Some("aa".into()),
        );
    }

    #[test]
    fn distinct_max_char() {
        test_agg(
            "(max:varchar $0:varchar distinct)",
            StreamChunk::from_pretty(
                " T
                + b
                + aa",
            ),
            Some("b".into()),
        );
    }

    #[test]
    fn distinct_count_int32() {
        // Two distinct values among three rows.
        test_agg(
            "(count:int8 $0:int4 distinct)",
            StreamChunk::from_pretty(
                " i
                + 1
                + 1
                + 3",
            ),
            Some(2i64.into()),
        );

        // A chunk with a header but no rows counts zero.
        test_agg(
            "(count:int8 $0:int4 distinct)",
            StreamChunk::from_pretty("i"),
            Some(0i64.into()),
        );

        // A single `.` row also counts zero.
        test_agg(
            "(count:int8 $0:int4 distinct)",
            StreamChunk::from_pretty(
                " i
                + .",
            ),
            Some(0i64.into()),
        );
    }
}