risingwave_expr_impl/window_function/
rank.rs1use std::marker::PhantomData;
16
17use risingwave_common::types::Datum;
18use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
19use risingwave_common_estimate_size::EstimateSize;
20use risingwave_common_estimate_size::collections::EstimatedVecDeque;
21use risingwave_expr::Result;
22use risingwave_expr::window_function::{
23 StateEvictHint, StateKey, StatePos, WindowFuncCall, WindowState,
24};
25use smallvec::SmallVec;
26
27use self::private::RankFuncCount;
28
29mod private {
30 use super::*;
31
32 pub trait RankFuncCount: Default + EstimateSize {
33 fn count(&mut self, curr_key: StateKey) -> i64;
34 }
35}
36
37#[derive(Default, EstimateSize)]
38pub(super) struct RowNumber {
39 prev_rank: i64,
40}
41
42impl RankFuncCount for RowNumber {
43 fn count(&mut self, _curr_key: StateKey) -> i64 {
44 let curr_rank = self.prev_rank + 1;
45 self.prev_rank = curr_rank;
46 curr_rank
47 }
48}
49
50#[derive(EstimateSize)]
51pub(super) struct Rank {
52 prev_order_key: Option<MemcmpEncoded>,
53 prev_rank: i64,
54 prev_pos_in_peer_group: i64,
55}
56
57impl Default for Rank {
58 fn default() -> Self {
59 Self {
60 prev_order_key: None,
61 prev_rank: 0,
62 prev_pos_in_peer_group: 1, }
64 }
65}
66
67impl RankFuncCount for Rank {
68 fn count(&mut self, curr_key: StateKey) -> i64 {
69 let (curr_rank, curr_pos_in_group) = if let Some(prev_order_key) =
70 self.prev_order_key.as_ref()
71 && prev_order_key == &curr_key.order_key
72 {
73 (self.prev_rank, self.prev_pos_in_peer_group + 1)
75 } else {
76 (self.prev_rank + self.prev_pos_in_peer_group, 1)
78 };
79 self.prev_order_key = Some(curr_key.order_key);
80 self.prev_rank = curr_rank;
81 self.prev_pos_in_peer_group = curr_pos_in_group;
82 curr_rank
83 }
84}
85
86#[derive(Default, EstimateSize)]
87pub(super) struct DenseRank {
88 prev_order_key: Option<MemcmpEncoded>,
89 prev_rank: i64,
90}
91
92impl RankFuncCount for DenseRank {
93 fn count(&mut self, curr_key: StateKey) -> i64 {
94 let curr_rank = if let Some(prev_order_key) = self.prev_order_key.as_ref()
95 && prev_order_key == &curr_key.order_key
96 {
97 self.prev_rank
99 } else {
100 self.prev_rank + 1
102 };
103 self.prev_order_key = Some(curr_key.order_key);
104 self.prev_rank = curr_rank;
105 curr_rank
106 }
107}
108
109#[derive(EstimateSize)]
111pub(super) struct RankState<RF: RankFuncCount> {
112 first_key: Option<StateKey>,
114 buffer: EstimatedVecDeque<StateKey>,
116 func_state: RF,
118 _phantom: PhantomData<RF>,
119}
120
121impl<RF: RankFuncCount> RankState<RF> {
122 pub fn new(_call: &WindowFuncCall) -> Self {
123 Self {
124 first_key: None,
125 buffer: Default::default(),
126 func_state: Default::default(),
127 _phantom: PhantomData,
128 }
129 }
130
131 fn slide_inner(&mut self) -> (i64, StateEvictHint) {
132 let curr_key = self
133 .buffer
134 .pop_front()
135 .expect("should not slide forward when the current window is not ready");
136 let rank = self.func_state.count(curr_key);
137 let evict_hint = StateEvictHint::CannotEvict(
139 self.first_key
140 .clone()
141 .expect("should have appended some rows"),
142 );
143 (rank, evict_hint)
144 }
145}
146
147impl<RF: RankFuncCount> WindowState for RankState<RF> {
148 fn append(&mut self, key: StateKey, _args: SmallVec<[Datum; 2]>) {
149 if self.first_key.is_none() {
150 self.first_key = Some(key.clone());
151 }
152 self.buffer.push_back(key);
153 }
154
155 fn curr_window(&self) -> StatePos<'_> {
156 let curr_key = self.buffer.front();
157 StatePos {
158 key: curr_key,
159 is_ready: curr_key.is_some(),
160 }
161 }
162
163 fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
164 let (rank, evict_hint) = self.slide_inner();
165 Ok((Some(rank.into()), evict_hint))
166 }
167
168 fn slide_no_output(&mut self) -> Result<StateEvictHint> {
169 let (_rank, evict_hint) = self.slide_inner();
170 Ok(evict_hint)
171 }
172}
173
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_common::util::memcmp_encoding;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::aggregate::AggArgs;
    use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind};

    use super::*;

    /// Builds a state key with `order` (memcmp-encoded ascending) as the order key and the
    /// single-column row `[pk]` as the pk.
    fn create_state_key(order: i64, pk: i64) -> StateKey {
        let order_key =
            memcmp_encoding::encode_value(Some(ScalarImpl::from(order)), OrderType::ascending())
                .unwrap();
        StateKey {
            order_key,
            pk: OwnedRow::new(vec![Some(pk.into())]).into(),
        }
    }

    /// Builds an argument-less `kind` call returning `Int64` over an unbounded ROWS frame.
    fn create_call(kind: WindowFuncKind) -> WindowFuncCall {
        WindowFuncCall {
            kind,
            return_type: DataType::Int64,
            args: AggArgs::default(),
            ignore_nulls: false,
            frame: Frame::rows(
                FrameBound::UnboundedPreceding,
                FrameBound::UnboundedFollowing,
            ),
        }
    }

    /// Appends rows with order keys 1, 2, 2, 3, 3, 3, 8 (peer groups of sizes 1, 2, 3, 1).
    fn append_peer_groups<RF: RankFuncCount>(state: &mut RankState<RF>) {
        let rows: [(i64, i64); 7] = [
            (1, 100),
            (2, 103),
            (2, 102),
            (3, 106),
            (3, 105),
            (3, 104),
            (8, 108),
        ];
        for (order, pk) in rows {
            state.append(create_state_key(order, pk), SmallVec::new());
        }
    }

    /// Slides until the current window is no longer ready and collects every output datum.
    fn drain_outputs<RF: RankFuncCount>(state: &mut RankState<RF>) -> Vec<Datum> {
        let mut outputs = vec![];
        while state.curr_window().is_ready {
            outputs.push(state.slide().unwrap().0);
        }
        outputs
    }

    /// Converts expected ranks into the datums `slide` should produce.
    fn ranks(expected: &[i64]) -> Vec<Datum> {
        expected.iter().map(|&r| Some(r.into())).collect()
    }

    #[test]
    #[should_panic(expected = "should not slide forward when the current window is not ready")]
    fn test_rank_state_bad_use() {
        let mut state = RankState::<RowNumber>::new(&create_call(WindowFuncKind::RowNumber));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        _ = state.slide()
    }

    #[test]
    fn test_row_number_state() {
        let mut state = RankState::<RowNumber>::new(&create_call(WindowFuncKind::RowNumber));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);

        state.append(create_state_key(1, 100), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(1, 100));
        assert!(state.curr_window().is_ready);
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
        let StateEvictHint::CannotEvict(state_key) = evict_hint else {
            unreachable!()
        };
        assert_eq!(state_key, create_state_key(1, 100));
        assert!(!state.curr_window().is_ready);

        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 103));
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 2i64.into());
        // The pinned key is still the partition's first row, not the row just output.
        let StateEvictHint::CannotEvict(state_key) = evict_hint else {
            unreachable!()
        };
        assert_eq!(state_key, create_state_key(1, 100));
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 102));
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 3i64.into());
    }

    #[test]
    fn test_rank_state() {
        let mut state = RankState::<Rank>::new(&create_call(WindowFuncKind::Rank));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        append_peer_groups(&mut state);
        // Ranks skip ahead past every row of the preceding peer group.
        assert_eq!(drain_outputs(&mut state), ranks(&[1, 2, 2, 4, 4, 4, 7]));
    }

    #[test]
    fn test_dense_rank_state() {
        let mut state = RankState::<DenseRank>::new(&create_call(WindowFuncKind::DenseRank));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        append_peer_groups(&mut state);
        // Dense ranks advance by exactly one per distinct order key.
        assert_eq!(drain_outputs(&mut state), ranks(&[1, 2, 2, 3, 3, 3, 4]));
    }
}