risingwave_expr_impl/window_function/
rank.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeSet;
16use std::marker::PhantomData;
17
18use risingwave_common::types::Datum;
19use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_common_estimate_size::collections::EstimatedVecDeque;
22use risingwave_expr::window_function::{
23    StateEvictHint, StateKey, StatePos, WindowFuncCall, WindowState, WindowStateSnapshot,
24};
25use risingwave_expr::{ExprError, Result};
26use risingwave_pb::window_function::window_state_snapshot::FunctionState;
27use risingwave_pb::window_function::{DenseRankState, RankState as RankStateProto, RowNumberState};
28use smallvec::SmallVec;
29
30use self::private::RankFuncCount;
31
32mod private {
33    use super::*;
34
35    pub trait RankFuncCount: Default + EstimateSize {
36        /// Count and return the rank for the given key, updating internal state.
37        fn count(&mut self, curr_key: StateKey) -> i64;
38
39        /// Convert the function state to the proto oneof variant for persistence.
40        fn to_proto_state(&self) -> FunctionState;
41
42        /// Restore the function state from the proto oneof variant.
43        /// Returns an error if the variant does not match this function type.
44        fn from_proto_state(state: FunctionState) -> Result<Self>
45        where
46            Self: Sized;
47    }
48}
49
50#[derive(Default, EstimateSize)]
51pub(super) struct RowNumber {
52    prev_rank: i64,
53}
54
55impl RankFuncCount for RowNumber {
56    fn count(&mut self, _curr_key: StateKey) -> i64 {
57        let curr_rank = self.prev_rank + 1;
58        self.prev_rank = curr_rank;
59        curr_rank
60    }
61
62    fn to_proto_state(&self) -> FunctionState {
63        FunctionState::RowNumberState(RowNumberState {
64            prev_rank: self.prev_rank,
65        })
66    }
67
68    fn from_proto_state(state: FunctionState) -> Result<Self> {
69        match state {
70            FunctionState::RowNumberState(s) => Ok(Self {
71                prev_rank: s.prev_rank,
72            }),
73            other => Err(ExprError::Internal(anyhow::anyhow!(
74                "expected RowNumberState, got {other:?}"
75            ))),
76        }
77    }
78}
79
80#[derive(EstimateSize)]
81pub(super) struct Rank {
82    prev_order_key: Option<MemcmpEncoded>,
83    prev_rank: i64,
84    // 1-based position of the previously output row within its peer group.
85    // Used to advance the rank by the peer-group size when a new group starts.
86    prev_pos_in_peer_group: i64,
87}
88
89impl Default for Rank {
90    fn default() -> Self {
91        Self {
92            prev_order_key: None,
93            prev_rank: 0,
94            prev_pos_in_peer_group: 1, // first position in the fake starting peer group
95        }
96    }
97}
98
99impl RankFuncCount for Rank {
100    fn count(&mut self, curr_key: StateKey) -> i64 {
101        let (curr_rank, curr_pos_in_group) = if let Some(prev_order_key) =
102            self.prev_order_key.as_ref()
103            && prev_order_key == &curr_key.order_key
104        {
105            // current key is in the same peer group as the previous one
106            (self.prev_rank, self.prev_pos_in_peer_group + 1)
107        } else {
108            // starting a new peer group
109            (self.prev_rank + self.prev_pos_in_peer_group, 1)
110        };
111        self.prev_order_key = Some(curr_key.order_key);
112        self.prev_rank = curr_rank;
113        self.prev_pos_in_peer_group = curr_pos_in_group;
114        curr_rank
115    }
116
117    fn to_proto_state(&self) -> FunctionState {
118        FunctionState::RankState(RankStateProto {
119            prev_order_key: self.prev_order_key.as_ref().map(|k| k.to_vec()),
120            prev_rank: self.prev_rank,
121            prev_pos_in_peer_group: self.prev_pos_in_peer_group,
122        })
123    }
124
125    fn from_proto_state(state: FunctionState) -> Result<Self> {
126        match state {
127            FunctionState::RankState(s) => Ok(Self {
128                prev_order_key: s.prev_order_key.map(Into::into),
129                prev_rank: s.prev_rank,
130                prev_pos_in_peer_group: s.prev_pos_in_peer_group,
131            }),
132            other => Err(ExprError::Internal(anyhow::anyhow!(
133                "expected RankState, got {other:?}"
134            ))),
135        }
136    }
137}
138
139#[derive(Default, EstimateSize)]
140pub(super) struct DenseRank {
141    prev_order_key: Option<MemcmpEncoded>,
142    prev_rank: i64,
143}
144
145impl RankFuncCount for DenseRank {
146    fn count(&mut self, curr_key: StateKey) -> i64 {
147        let curr_rank = if let Some(prev_order_key) = self.prev_order_key.as_ref()
148            && prev_order_key == &curr_key.order_key
149        {
150            // current key is in the same peer group as the previous one
151            self.prev_rank
152        } else {
153            // starting a new peer group
154            self.prev_rank + 1
155        };
156        self.prev_order_key = Some(curr_key.order_key);
157        self.prev_rank = curr_rank;
158        curr_rank
159    }
160
161    fn to_proto_state(&self) -> FunctionState {
162        FunctionState::DenseRankState(DenseRankState {
163            prev_order_key: self.prev_order_key.as_ref().map(|k| k.to_vec()),
164            prev_rank: self.prev_rank,
165        })
166    }
167
168    fn from_proto_state(state: FunctionState) -> Result<Self> {
169        match state {
170            FunctionState::DenseRankState(s) => Ok(Self {
171                prev_order_key: s.prev_order_key.map(Into::into),
172                prev_rank: s.prev_rank,
173            }),
174            other => Err(ExprError::Internal(anyhow::anyhow!(
175                "expected DenseRankState, got {other:?}"
176            ))),
177        }
178    }
179}
180
181/// Generic state for rank window functions including `row_number`, `rank` and `dense_rank`.
182#[derive(EstimateSize)]
183pub(super) struct RankState<RF: RankFuncCount> {
184    /// First state key of the partition.
185    first_key: Option<StateKey>,
186    /// State keys that are waiting to be outputted.
187    buffer: EstimatedVecDeque<StateKey>,
188    /// Function-specific state.
189    func_state: RF,
190    /// Whether persistence is enabled for this state.
191    persistence_enabled: bool,
192    /// The key of the last output row (for snapshot).
193    last_output_key: Option<StateKey>,
194    /// During recovery, skip updating `func_state` for rows up to and including this key.
195    recover_skip_until: Option<StateKey>,
196    _phantom: PhantomData<RF>,
197}
198
199impl<RF: RankFuncCount> RankState<RF> {
200    pub fn new(_call: &WindowFuncCall) -> Self {
201        Self {
202            first_key: None,
203            buffer: Default::default(),
204            func_state: Default::default(),
205            persistence_enabled: false,
206            last_output_key: None,
207            recover_skip_until: None,
208            _phantom: PhantomData,
209        }
210    }
211
212    fn slide_inner(&mut self) -> (i64, StateEvictHint) {
213        let curr_key = self
214            .buffer
215            .pop_front()
216            .expect("should not slide forward when the current window is not ready");
217        let rank = self.func_state.count(curr_key.clone());
218
219        // Track last output key for persistence
220        self.last_output_key = Some(curr_key.clone());
221
222        let evict_hint = if self.persistence_enabled {
223            // When persistence is enabled, we can evict the just-output key
224            let mut evict_set = BTreeSet::new();
225            evict_set.insert(curr_key);
226            StateEvictHint::CanEvict(evict_set)
227        } else {
228            // Can't evict any state key in EOWC mode without persistence,
229            // because we can't recover from previous output
230            StateEvictHint::CannotEvict(
231                self.first_key
232                    .clone()
233                    .expect("should have appended some rows"),
234            )
235        };
236        (rank, evict_hint)
237    }
238
239    fn slide_no_output_inner(&mut self) -> StateEvictHint {
240        let curr_key = self
241            .buffer
242            .pop_front()
243            .expect("should not slide forward when the current window is not ready");
244
245        // Check if we should skip counting (during recovery)
246        let should_skip = self
247            .recover_skip_until
248            .as_ref()
249            .is_some_and(|skip_key| &curr_key <= skip_key);
250
251        if !should_skip {
252            // Normal counting
253            self.func_state.count(curr_key.clone());
254        }
255
256        // Clear recover_skip_until when we've processed the target key
257        if self
258            .recover_skip_until
259            .as_ref()
260            .is_some_and(|skip_key| &curr_key == skip_key)
261        {
262            self.recover_skip_until = None;
263        }
264
265        // Track last output key for persistence
266        self.last_output_key = Some(curr_key.clone());
267
268        if self.persistence_enabled {
269            let mut evict_set = BTreeSet::new();
270            evict_set.insert(curr_key);
271            StateEvictHint::CanEvict(evict_set)
272        } else {
273            StateEvictHint::CannotEvict(
274                self.first_key
275                    .clone()
276                    .expect("should have appended some rows"),
277            )
278        }
279    }
280}
281
282impl<RF: RankFuncCount> WindowState for RankState<RF> {
283    fn append(&mut self, key: StateKey, _args: SmallVec<[Datum; 2]>) {
284        if self.first_key.is_none() {
285            self.first_key = Some(key.clone());
286        }
287        self.buffer.push_back(key);
288    }
289
290    fn curr_window(&self) -> StatePos<'_> {
291        let curr_key = self.buffer.front();
292        StatePos {
293            key: curr_key,
294            is_ready: curr_key.is_some(),
295        }
296    }
297
298    fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
299        let (rank, evict_hint) = self.slide_inner();
300        Ok((Some(rank.into()), evict_hint))
301    }
302
303    fn slide_no_output(&mut self) -> Result<StateEvictHint> {
304        let evict_hint = self.slide_no_output_inner();
305        Ok(evict_hint)
306    }
307
308    fn enable_persistence(&mut self) {
309        self.persistence_enabled = true;
310    }
311
312    fn snapshot(&self) -> Option<WindowStateSnapshot> {
313        if !self.persistence_enabled {
314            return None;
315        }
316        Some(WindowStateSnapshot {
317            last_output_key: self.last_output_key.clone(),
318            function_state: self.func_state.to_proto_state(),
319        })
320    }
321
322    fn restore(&mut self, snapshot: WindowStateSnapshot) -> Result<()> {
323        self.func_state = RF::from_proto_state(snapshot.function_state)?;
324        // Set recover_skip_until so that slide_no_output skips counting for rows
325        // that have already been counted before the snapshot was taken.
326        self.recover_skip_until = snapshot.last_output_key;
327        Ok(())
328    }
329}
330
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_common::util::memcmp_encoding;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::aggregate::AggArgs;
    use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind};

    use super::*;

    /// Builds a state key whose order key is the ascending memcmp encoding of `order`
    /// and whose primary key is the single column `pk`.
    fn create_state_key(order: i64, pk: i64) -> StateKey {
        StateKey {
            order_key: memcmp_encoding::encode_value(
                Some(ScalarImpl::from(order)),
                OrderType::ascending(),
            )
            .unwrap(),
            pk: OwnedRow::new(vec![Some(pk.into())]).into(),
        }
    }

    /// Builds a call for `kind` with the whole-partition `ROWS` frame used by every test.
    fn create_call(kind: WindowFuncKind) -> WindowFuncCall {
        WindowFuncCall {
            kind,
            return_type: DataType::Int64,
            args: AggArgs::default(),
            ignore_nulls: false,
            frame: Frame::rows(
                FrameBound::UnboundedPreceding,
                FrameBound::UnboundedFollowing,
            ),
        }
    }

    /// Slides until the buffer is empty and collects every output datum.
    fn drain_outputs<RF: RankFuncCount>(state: &mut RankState<RF>) -> Vec<Datum> {
        let mut outputs = vec![];
        while state.curr_window().is_ready {
            outputs.push(state.slide().unwrap().0);
        }
        outputs
    }

    #[test]
    #[should_panic(expected = "should not slide forward when the current window is not ready")]
    fn test_rank_state_bad_use() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        _ = state.slide()
    }

    #[test]
    fn test_row_number_state() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(1, 100), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(1, 100));
        assert!(state.curr_window().is_ready);
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
        match evict_hint {
            StateEvictHint::CannotEvict(state_key) => {
                assert_eq!(state_key, create_state_key(1, 100));
            }
            _ => unreachable!(),
        }
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 103));
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 2i64.into());
        match evict_hint {
            StateEvictHint::CannotEvict(state_key) => {
                assert_eq!(state_key, create_state_key(1, 100));
            }
            _ => unreachable!(),
        }
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 102));
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 3i64.into());
    }

    #[test]
    fn test_rank_state() {
        let call = create_call(WindowFuncKind::Rank);
        let mut state = RankState::<Rank>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        state.append(create_state_key(3, 106), SmallVec::new());
        state.append(create_state_key(3, 105), SmallVec::new());
        state.append(create_state_key(3, 104), SmallVec::new());
        state.append(create_state_key(8, 108), SmallVec::new());

        let outputs = drain_outputs(&mut state);

        assert_eq!(
            outputs,
            vec![
                Some(1i64.into()),
                Some(2i64.into()),
                Some(2i64.into()),
                Some(4i64.into()),
                Some(4i64.into()),
                Some(4i64.into()),
                Some(7i64.into())
            ]
        );
    }

    #[test]
    fn test_dense_rank_state() {
        let call = create_call(WindowFuncKind::DenseRank);
        let mut state = RankState::<DenseRank>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        state.append(create_state_key(3, 106), SmallVec::new());
        state.append(create_state_key(3, 105), SmallVec::new());
        state.append(create_state_key(3, 104), SmallVec::new());
        state.append(create_state_key(8, 108), SmallVec::new());

        let outputs = drain_outputs(&mut state);

        assert_eq!(
            outputs,
            vec![
                Some(1i64.into()),
                Some(2i64.into()),
                Some(2i64.into()),
                Some(3i64.into()),
                Some(3i64.into()),
                Some(3i64.into()),
                Some(4i64.into())
            ]
        );
    }

    #[test]
    fn test_row_number_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        // Process some rows
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(3, 102), SmallVec::new());

        // Output first two rows
        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into());

        // Take snapshot
        let snapshot = state.snapshot().unwrap();
        assert!(snapshot.last_output_key.is_some());
        assert_eq!(
            snapshot.last_output_key.as_ref().unwrap(),
            &create_state_key(2, 101)
        );

        // Create new state and restore
        let mut new_state = RankState::<RowNumber>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Continue from where we left off (row 3)
        new_state.append(create_state_key(3, 102), SmallVec::new());
        new_state.append(create_state_key(4, 103), SmallVec::new());

        // Output should continue from rank 3
        let (output3, _) = new_state.slide().unwrap();
        assert_eq!(output3.unwrap(), 3i64.into());
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_rank_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::Rank);
        let mut state = RankState::<Rank>::new(&call);
        state.enable_persistence();

        // Add rows with ties
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new()); // tie

        // Output rows
        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into()); // first in tie group
        let (output3, _) = state.slide().unwrap();
        assert_eq!(output3.unwrap(), 2i64.into()); // second in tie group

        // Take snapshot
        let snapshot = state.snapshot().unwrap();

        // Create new state and restore
        let mut new_state = RankState::<Rank>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Add more rows
        new_state.append(create_state_key(3, 103), SmallVec::new());

        // Output should be rank 4 (since 2 items tied at rank 2)
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_dense_rank_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::DenseRank);
        let mut state = RankState::<DenseRank>::new(&call);
        state.enable_persistence();

        // Add rows with ties
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new()); // tie

        // Output rows
        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into());
        let (output3, _) = state.slide().unwrap();
        assert_eq!(output3.unwrap(), 2i64.into()); // same rank due to tie

        // Take snapshot
        let snapshot = state.snapshot().unwrap();

        // Create new state and restore
        let mut new_state = RankState::<DenseRank>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Add more rows
        new_state.append(create_state_key(3, 103), SmallVec::new());

        // Output should be rank 3 (dense rank increments by 1)
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 3i64.into());
    }

    #[test]
    fn test_recovery_skip_logic() {
        // Test that slide_no_output correctly skips counting for recovered rows
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        // Process initial rows
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(3, 102), SmallVec::new());

        let (_, _) = state.slide().unwrap();
        let (_, _) = state.slide().unwrap();

        // Take snapshot after row 2
        let snapshot = state.snapshot().unwrap();
        assert_eq!(
            snapshot.last_output_key.as_ref().unwrap(),
            &create_state_key(2, 101)
        );

        // Create new state and restore
        let mut new_state = RankState::<RowNumber>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Simulate recovery: replay rows from state table
        // First two rows should be skipped in slide_no_output
        new_state.append(create_state_key(1, 100), SmallVec::new());
        new_state.append(create_state_key(2, 101), SmallVec::new());
        new_state.append(create_state_key(3, 102), SmallVec::new());
        new_state.append(create_state_key(4, 103), SmallVec::new());

        // Use slide_no_output for recovery (skips counting for rows <= snapshot key)
        let _ = new_state.slide_no_output().unwrap();
        let _ = new_state.slide_no_output().unwrap();

        // Now output the remaining rows - should continue correctly
        let (output3, _) = new_state.slide().unwrap();
        assert_eq!(output3.unwrap(), 3i64.into());
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_eviction_hint_with_persistence() {
        let call = create_call(WindowFuncKind::RowNumber);

        // Test without persistence - should return CannotEvict
        let mut state_no_persist = RankState::<RowNumber>::new(&call);
        state_no_persist.append(create_state_key(1, 100), SmallVec::new());
        let (_, evict_hint) = state_no_persist.slide().unwrap();
        match evict_hint {
            StateEvictHint::CannotEvict(_) => {}
            StateEvictHint::CanEvict(_) => panic!("expected CannotEvict without persistence"),
        }

        // Test with persistence - should return CanEvict
        let mut state_persist = RankState::<RowNumber>::new(&call);
        state_persist.enable_persistence();
        state_persist.append(create_state_key(1, 100), SmallVec::new());
        let (_, evict_hint) = state_persist.slide().unwrap();
        match evict_hint {
            StateEvictHint::CanEvict(keys) => {
                assert_eq!(keys.len(), 1);
                assert!(keys.contains(&create_state_key(1, 100)));
            }
            StateEvictHint::CannotEvict(_) => panic!("expected CanEvict with persistence"),
        }
    }

    #[test]
    fn test_snapshot_returns_none_without_persistence() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        // Don't enable persistence
        state.append(create_state_key(1, 100), SmallVec::new());
        let (_, _) = state.slide().unwrap();

        // Snapshot should return None
        assert!(state.snapshot().is_none());
    }

    #[test]
    fn test_restore_from_empty_starts_fresh() {
        // Test that restoring with default state works correctly
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        // Create a snapshot with default func_state (prev_rank = 0)
        let snapshot = WindowStateSnapshot {
            last_output_key: None,
            function_state: RowNumber::default().to_proto_state(),
        };

        state.restore(snapshot).unwrap();

        // Add rows and verify output starts from 1
        state.append(create_state_key(1, 100), SmallVec::new());
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
    }

    #[test]
    fn test_wrong_function_state_type_is_rejected() {
        // Restoring a RowNumber state with a RankState payload should fail fast.
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        // Snapshot contains a RankState variant instead of RowNumberState.
        let snapshot = WindowStateSnapshot {
            last_output_key: None,
            function_state: Rank::default().to_proto_state(),
        };

        assert!(state.restore(snapshot).is_err());
    }

    #[test]
    fn test_every_function_rejects_foreign_state() {
        // Each counter must only accept its own proto variant.
        assert!(RowNumber::from_proto_state(DenseRank::default().to_proto_state()).is_err());
        assert!(Rank::from_proto_state(RowNumber::default().to_proto_state()).is_err());
        assert!(DenseRank::from_proto_state(Rank::default().to_proto_state()).is_err());
    }
}