risingwave_expr_impl/window_function/
rank.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16
17use risingwave_common::types::Datum;
18use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
19use risingwave_common_estimate_size::EstimateSize;
20use risingwave_common_estimate_size::collections::EstimatedVecDeque;
21use risingwave_expr::Result;
22use risingwave_expr::window_function::{
23    StateEvictHint, StateKey, StatePos, WindowFuncCall, WindowState,
24};
25use smallvec::SmallVec;
26
27use self::private::RankFuncCount;
28
29mod private {
30    use super::*;
31
32    pub trait RankFuncCount: Default + EstimateSize {
33        fn count(&mut self, curr_key: StateKey) -> i64;
34    }
35}
36
37#[derive(Default, EstimateSize)]
38pub(super) struct RowNumber {
39    prev_rank: i64,
40}
41
42impl RankFuncCount for RowNumber {
43    fn count(&mut self, _curr_key: StateKey) -> i64 {
44        let curr_rank = self.prev_rank + 1;
45        self.prev_rank = curr_rank;
46        curr_rank
47    }
48}
49
50#[derive(EstimateSize)]
51pub(super) struct Rank {
52    prev_order_key: Option<MemcmpEncoded>,
53    prev_rank: i64,
54    prev_pos_in_peer_group: i64,
55}
56
57impl Default for Rank {
58    fn default() -> Self {
59        Self {
60            prev_order_key: None,
61            prev_rank: 0,
62            prev_pos_in_peer_group: 1, // first position in the fake starting peer group
63        }
64    }
65}
66
67impl RankFuncCount for Rank {
68    fn count(&mut self, curr_key: StateKey) -> i64 {
69        let (curr_rank, curr_pos_in_group) = if let Some(prev_order_key) =
70            self.prev_order_key.as_ref()
71            && prev_order_key == &curr_key.order_key
72        {
73            // current key is in the same peer group as the previous one
74            (self.prev_rank, self.prev_pos_in_peer_group + 1)
75        } else {
76            // starting a new peer group
77            (self.prev_rank + self.prev_pos_in_peer_group, 1)
78        };
79        self.prev_order_key = Some(curr_key.order_key);
80        self.prev_rank = curr_rank;
81        self.prev_pos_in_peer_group = curr_pos_in_group;
82        curr_rank
83    }
84}
85
86#[derive(Default, EstimateSize)]
87pub(super) struct DenseRank {
88    prev_order_key: Option<MemcmpEncoded>,
89    prev_rank: i64,
90}
91
92impl RankFuncCount for DenseRank {
93    fn count(&mut self, curr_key: StateKey) -> i64 {
94        let curr_rank = if let Some(prev_order_key) = self.prev_order_key.as_ref()
95            && prev_order_key == &curr_key.order_key
96        {
97            // current key is in the same peer group as the previous one
98            self.prev_rank
99        } else {
100            // starting a new peer group
101            self.prev_rank + 1
102        };
103        self.prev_order_key = Some(curr_key.order_key);
104        self.prev_rank = curr_rank;
105        curr_rank
106    }
107}
108
109/// Generic state for rank window functions including `row_number`, `rank` and `dense_rank`.
110#[derive(EstimateSize)]
111pub(super) struct RankState<RF: RankFuncCount> {
112    /// First state key of the partition.
113    first_key: Option<StateKey>,
114    /// State keys that are waiting to be outputted.
115    buffer: EstimatedVecDeque<StateKey>,
116    /// Function-specific state.
117    func_state: RF,
118    _phantom: PhantomData<RF>,
119}
120
121impl<RF: RankFuncCount> RankState<RF> {
122    pub fn new(_call: &WindowFuncCall) -> Self {
123        Self {
124            first_key: None,
125            buffer: Default::default(),
126            func_state: Default::default(),
127            _phantom: PhantomData,
128        }
129    }
130
131    fn slide_inner(&mut self) -> (i64, StateEvictHint) {
132        let curr_key = self
133            .buffer
134            .pop_front()
135            .expect("should not slide forward when the current window is not ready");
136        let rank = self.func_state.count(curr_key);
137        // can't evict any state key in EOWC mode, because we can't recover from previous output now
138        let evict_hint = StateEvictHint::CannotEvict(
139            self.first_key
140                .clone()
141                .expect("should have appended some rows"),
142        );
143        (rank, evict_hint)
144    }
145}
146
147impl<RF: RankFuncCount> WindowState for RankState<RF> {
148    fn append(&mut self, key: StateKey, _args: SmallVec<[Datum; 2]>) {
149        if self.first_key.is_none() {
150            self.first_key = Some(key.clone());
151        }
152        self.buffer.push_back(key);
153    }
154
155    fn curr_window(&self) -> StatePos<'_> {
156        let curr_key = self.buffer.front();
157        StatePos {
158            key: curr_key,
159            is_ready: curr_key.is_some(),
160        }
161    }
162
163    fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
164        let (rank, evict_hint) = self.slide_inner();
165        Ok((Some(rank.into()), evict_hint))
166    }
167
168    fn slide_no_output(&mut self) -> Result<StateEvictHint> {
169        let (_rank, evict_hint) = self.slide_inner();
170        Ok(evict_hint)
171    }
172}
173
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_common::util::memcmp_encoding;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::aggregate::AggArgs;
    use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind};

    use super::*;

    /// Builds a state key whose order key is `order` memcmp-encoded ascending and whose primary
    /// key is the one-column row `[pk]`.
    fn create_state_key(order: i64, pk: i64) -> StateKey {
        StateKey {
            order_key: memcmp_encoding::encode_value(
                Some(ScalarImpl::from(order)),
                OrderType::ascending(),
            )
            .unwrap(),
            pk: OwnedRow::new(vec![Some(pk.into())]).into(),
        }
    }

    /// Builds an argument-less `kind` call returning `Int64` over an unbounded `ROWS` frame;
    /// this is the call shape shared by every test below.
    fn create_call(kind: WindowFuncKind) -> WindowFuncCall {
        WindowFuncCall {
            kind,
            return_type: DataType::Int64,
            args: AggArgs::default(),
            ignore_nulls: false,
            frame: Frame::rows(
                FrameBound::UnboundedPreceding,
                FrameBound::UnboundedFollowing,
            ),
        }
    }

    /// Asserts that `hint` is `CannotEvict` carrying exactly `expected`.
    fn assert_cannot_evict(hint: StateEvictHint, expected: StateKey) {
        match hint {
            StateEvictHint::CannotEvict(state_key) => assert_eq!(state_key, expected),
            _ => unreachable!(),
        }
    }

    /// Wraps each value as a non-null `Int64` datum, for comparison with slide outputs.
    fn int64_datums(values: &[i64]) -> Vec<Datum> {
        values.iter().map(|&v| Some(ScalarImpl::from(v))).collect()
    }

    /// Appends keys with order values 1, 2, 2, 3, 3, 3, 8 to `state`, then slides until the
    /// window is no longer ready and returns every output in order.
    fn append_fixture_and_drain<RF: RankFuncCount>(state: &mut RankState<RF>) -> Vec<Datum> {
        let rows = [
            (1, 100),
            (2, 103),
            (2, 102),
            (3, 106),
            (3, 105),
            (3, 104),
            (8, 108),
        ];
        for (order, pk) in rows {
            state.append(create_state_key(order, pk), SmallVec::new());
        }
        let mut outputs = vec![];
        while state.curr_window().is_ready {
            outputs.push(state.slide().unwrap().0);
        }
        outputs
    }

    #[test]
    #[should_panic(expected = "should not slide forward when the current window is not ready")]
    fn test_rank_state_bad_use() {
        let mut state = RankState::<RowNumber>::new(&create_call(WindowFuncKind::RowNumber));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        _ = state.slide();
    }

    #[test]
    fn test_row_number_state() {
        let mut state = RankState::<RowNumber>::new(&create_call(WindowFuncKind::RowNumber));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);

        state.append(create_state_key(1, 100), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(1, 100));
        assert!(state.curr_window().is_ready);
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
        assert_cannot_evict(evict_hint, create_state_key(1, 100));
        assert!(!state.curr_window().is_ready);

        // Peers do not share a row number: equal order keys still get 2 and 3.
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 103));
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 2i64.into());
        // The hint keeps pointing at the partition's first key.
        assert_cannot_evict(evict_hint, create_state_key(1, 100));
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 102));
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 3i64.into());
    }

    #[test]
    fn test_rank_state() {
        let mut state = RankState::<Rank>::new(&create_call(WindowFuncKind::Rank));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        // Peers share a rank; the next group skips ahead by the previous group's size.
        assert_eq!(
            append_fixture_and_drain(&mut state),
            int64_datums(&[1, 2, 2, 4, 4, 4, 7])
        );
    }

    #[test]
    fn test_dense_rank_state() {
        let mut state = RankState::<DenseRank>::new(&create_call(WindowFuncKind::DenseRank));
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        // Peers share a rank; each new group gets the next integer, with no gaps.
        assert_eq!(
            append_fixture_and_drain(&mut state),
            int64_datums(&[1, 2, 2, 3, 3, 3, 4])
        );
    }
}