1use std::collections::BTreeSet;
16use std::marker::PhantomData;
17
18use risingwave_common::types::Datum;
19use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
20use risingwave_common_estimate_size::EstimateSize;
21use risingwave_common_estimate_size::collections::EstimatedVecDeque;
22use risingwave_expr::window_function::{
23 StateEvictHint, StateKey, StatePos, WindowFuncCall, WindowState, WindowStateSnapshot,
24};
25use risingwave_expr::{ExprError, Result};
26use risingwave_pb::window_function::window_state_snapshot::FunctionState;
27use risingwave_pb::window_function::{DenseRankState, RankState as RankStateProto, RowNumberState};
28use smallvec::SmallVec;
29
30use self::private::RankFuncCount;
31
32mod private {
33 use super::*;
34
35 pub trait RankFuncCount: Default + EstimateSize {
36 fn count(&mut self, curr_key: StateKey) -> i64;
38
39 fn to_proto_state(&self) -> FunctionState;
41
42 fn from_proto_state(state: FunctionState) -> Result<Self>
45 where
46 Self: Sized;
47 }
48}
49
50#[derive(Default, EstimateSize)]
51pub(super) struct RowNumber {
52 prev_rank: i64,
53}
54
55impl RankFuncCount for RowNumber {
56 fn count(&mut self, _curr_key: StateKey) -> i64 {
57 let curr_rank = self.prev_rank + 1;
58 self.prev_rank = curr_rank;
59 curr_rank
60 }
61
62 fn to_proto_state(&self) -> FunctionState {
63 FunctionState::RowNumberState(RowNumberState {
64 prev_rank: self.prev_rank,
65 })
66 }
67
68 fn from_proto_state(state: FunctionState) -> Result<Self> {
69 match state {
70 FunctionState::RowNumberState(s) => Ok(Self {
71 prev_rank: s.prev_rank,
72 }),
73 other => Err(ExprError::Internal(anyhow::anyhow!(
74 "expected RowNumberState, got {other:?}"
75 ))),
76 }
77 }
78}
79
80#[derive(EstimateSize)]
81pub(super) struct Rank {
82 prev_order_key: Option<MemcmpEncoded>,
83 prev_rank: i64,
84 prev_pos_in_peer_group: i64,
87}
88
89impl Default for Rank {
90 fn default() -> Self {
91 Self {
92 prev_order_key: None,
93 prev_rank: 0,
94 prev_pos_in_peer_group: 1, }
96 }
97}
98
99impl RankFuncCount for Rank {
100 fn count(&mut self, curr_key: StateKey) -> i64 {
101 let (curr_rank, curr_pos_in_group) = if let Some(prev_order_key) =
102 self.prev_order_key.as_ref()
103 && prev_order_key == &curr_key.order_key
104 {
105 (self.prev_rank, self.prev_pos_in_peer_group + 1)
107 } else {
108 (self.prev_rank + self.prev_pos_in_peer_group, 1)
110 };
111 self.prev_order_key = Some(curr_key.order_key);
112 self.prev_rank = curr_rank;
113 self.prev_pos_in_peer_group = curr_pos_in_group;
114 curr_rank
115 }
116
117 fn to_proto_state(&self) -> FunctionState {
118 FunctionState::RankState(RankStateProto {
119 prev_order_key: self.prev_order_key.as_ref().map(|k| k.to_vec()),
120 prev_rank: self.prev_rank,
121 prev_pos_in_peer_group: self.prev_pos_in_peer_group,
122 })
123 }
124
125 fn from_proto_state(state: FunctionState) -> Result<Self> {
126 match state {
127 FunctionState::RankState(s) => Ok(Self {
128 prev_order_key: s.prev_order_key.map(Into::into),
129 prev_rank: s.prev_rank,
130 prev_pos_in_peer_group: s.prev_pos_in_peer_group,
131 }),
132 other => Err(ExprError::Internal(anyhow::anyhow!(
133 "expected RankState, got {other:?}"
134 ))),
135 }
136 }
137}
138
139#[derive(Default, EstimateSize)]
140pub(super) struct DenseRank {
141 prev_order_key: Option<MemcmpEncoded>,
142 prev_rank: i64,
143}
144
145impl RankFuncCount for DenseRank {
146 fn count(&mut self, curr_key: StateKey) -> i64 {
147 let curr_rank = if let Some(prev_order_key) = self.prev_order_key.as_ref()
148 && prev_order_key == &curr_key.order_key
149 {
150 self.prev_rank
152 } else {
153 self.prev_rank + 1
155 };
156 self.prev_order_key = Some(curr_key.order_key);
157 self.prev_rank = curr_rank;
158 curr_rank
159 }
160
161 fn to_proto_state(&self) -> FunctionState {
162 FunctionState::DenseRankState(DenseRankState {
163 prev_order_key: self.prev_order_key.as_ref().map(|k| k.to_vec()),
164 prev_rank: self.prev_rank,
165 })
166 }
167
168 fn from_proto_state(state: FunctionState) -> Result<Self> {
169 match state {
170 FunctionState::DenseRankState(s) => Ok(Self {
171 prev_order_key: s.prev_order_key.map(Into::into),
172 prev_rank: s.prev_rank,
173 }),
174 other => Err(ExprError::Internal(anyhow::anyhow!(
175 "expected DenseRankState, got {other:?}"
176 ))),
177 }
178 }
179}
180
181#[derive(EstimateSize)]
183pub(super) struct RankState<RF: RankFuncCount> {
184 first_key: Option<StateKey>,
186 buffer: EstimatedVecDeque<StateKey>,
188 func_state: RF,
190 persistence_enabled: bool,
192 last_output_key: Option<StateKey>,
194 recover_skip_until: Option<StateKey>,
196 _phantom: PhantomData<RF>,
197}
198
199impl<RF: RankFuncCount> RankState<RF> {
200 pub fn new(_call: &WindowFuncCall) -> Self {
201 Self {
202 first_key: None,
203 buffer: Default::default(),
204 func_state: Default::default(),
205 persistence_enabled: false,
206 last_output_key: None,
207 recover_skip_until: None,
208 _phantom: PhantomData,
209 }
210 }
211
212 fn slide_inner(&mut self) -> (i64, StateEvictHint) {
213 let curr_key = self
214 .buffer
215 .pop_front()
216 .expect("should not slide forward when the current window is not ready");
217 let rank = self.func_state.count(curr_key.clone());
218
219 self.last_output_key = Some(curr_key.clone());
221
222 let evict_hint = if self.persistence_enabled {
223 let mut evict_set = BTreeSet::new();
225 evict_set.insert(curr_key);
226 StateEvictHint::CanEvict(evict_set)
227 } else {
228 StateEvictHint::CannotEvict(
231 self.first_key
232 .clone()
233 .expect("should have appended some rows"),
234 )
235 };
236 (rank, evict_hint)
237 }
238
239 fn slide_no_output_inner(&mut self) -> StateEvictHint {
240 let curr_key = self
241 .buffer
242 .pop_front()
243 .expect("should not slide forward when the current window is not ready");
244
245 let should_skip = self
247 .recover_skip_until
248 .as_ref()
249 .is_some_and(|skip_key| &curr_key <= skip_key);
250
251 if !should_skip {
252 self.func_state.count(curr_key.clone());
254 }
255
256 if self
258 .recover_skip_until
259 .as_ref()
260 .is_some_and(|skip_key| &curr_key == skip_key)
261 {
262 self.recover_skip_until = None;
263 }
264
265 self.last_output_key = Some(curr_key.clone());
267
268 if self.persistence_enabled {
269 let mut evict_set = BTreeSet::new();
270 evict_set.insert(curr_key);
271 StateEvictHint::CanEvict(evict_set)
272 } else {
273 StateEvictHint::CannotEvict(
274 self.first_key
275 .clone()
276 .expect("should have appended some rows"),
277 )
278 }
279 }
280}
281
282impl<RF: RankFuncCount> WindowState for RankState<RF> {
283 fn append(&mut self, key: StateKey, _args: SmallVec<[Datum; 2]>) {
284 if self.first_key.is_none() {
285 self.first_key = Some(key.clone());
286 }
287 self.buffer.push_back(key);
288 }
289
290 fn curr_window(&self) -> StatePos<'_> {
291 let curr_key = self.buffer.front();
292 StatePos {
293 key: curr_key,
294 is_ready: curr_key.is_some(),
295 }
296 }
297
298 fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
299 let (rank, evict_hint) = self.slide_inner();
300 Ok((Some(rank.into()), evict_hint))
301 }
302
303 fn slide_no_output(&mut self) -> Result<StateEvictHint> {
304 let evict_hint = self.slide_no_output_inner();
305 Ok(evict_hint)
306 }
307
308 fn enable_persistence(&mut self) {
309 self.persistence_enabled = true;
310 }
311
312 fn snapshot(&self) -> Option<WindowStateSnapshot> {
313 if !self.persistence_enabled {
314 return None;
315 }
316 Some(WindowStateSnapshot {
317 last_output_key: self.last_output_key.clone(),
318 function_state: self.func_state.to_proto_state(),
319 })
320 }
321
322 fn restore(&mut self, snapshot: WindowStateSnapshot) -> Result<()> {
323 self.func_state = RF::from_proto_state(snapshot.function_state)?;
324 self.recover_skip_until = snapshot.last_output_key;
327 Ok(())
328 }
329}
330
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_common::util::memcmp_encoding;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::aggregate::AggArgs;
    use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind};

    use super::*;

    /// Builds a state key from an ascending-encoded `order` value and a
    /// single-column primary key.
    fn create_state_key(order: i64, pk: i64) -> StateKey {
        StateKey {
            order_key: memcmp_encoding::encode_value(
                Some(ScalarImpl::from(order)),
                OrderType::ascending(),
            )
            .unwrap(),
            pk: OwnedRow::new(vec![Some(pk.into())]).into(),
        }
    }

    /// Builds an argument-less window function call of the given kind over an
    /// unbounded rows frame.
    fn create_call(kind: WindowFuncKind) -> WindowFuncCall {
        WindowFuncCall {
            kind,
            return_type: DataType::Int64,
            args: AggArgs::default(),
            ignore_nulls: false,
            frame: Frame::rows(
                FrameBound::UnboundedPreceding,
                FrameBound::UnboundedFollowing,
            ),
        }
    }

    /// Slides until the window is no longer ready and collects every output.
    fn drain_outputs<RF: RankFuncCount>(state: &mut RankState<RF>) -> Vec<Datum> {
        let mut outputs = vec![];
        while state.curr_window().is_ready {
            outputs.push(state.slide().unwrap().0);
        }
        outputs
    }

    #[test]
    #[should_panic(expected = "should not slide forward when the current window is not ready")]
    fn test_rank_state_bad_use() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        // Sliding with nothing appended must panic.
        _ = state.slide();
    }

    #[test]
    fn test_row_number_state() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);

        state.append(create_state_key(1, 100), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(1, 100));
        assert!(state.curr_window().is_ready);
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
        match evict_hint {
            StateEvictHint::CannotEvict(state_key) => {
                assert_eq!(state_key, create_state_key(1, 100));
            }
            _ => unreachable!(),
        }
        assert!(!state.curr_window().is_ready);

        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 103));
        let (output, evict_hint) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 2i64.into());
        match evict_hint {
            // Without persistence the hint keeps pointing at the first key.
            StateEvictHint::CannotEvict(state_key) => {
                assert_eq!(state_key, create_state_key(1, 100));
            }
            _ => unreachable!(),
        }
        assert_eq!(state.curr_window().key.unwrap(), &create_state_key(2, 102));
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 3i64.into());
    }

    #[test]
    fn test_rank_state() {
        let call = create_call(WindowFuncKind::Rank);
        let mut state = RankState::<Rank>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        state.append(create_state_key(3, 106), SmallVec::new());
        state.append(create_state_key(3, 105), SmallVec::new());
        state.append(create_state_key(3, 104), SmallVec::new());
        state.append(create_state_key(8, 108), SmallVec::new());

        let outputs = drain_outputs(&mut state);

        assert_eq!(
            outputs,
            vec![
                Some(1i64.into()),
                Some(2i64.into()),
                Some(2i64.into()),
                Some(4i64.into()),
                Some(4i64.into()),
                Some(4i64.into()),
                Some(7i64.into())
            ]
        );
    }

    #[test]
    fn test_dense_rank_state() {
        let call = create_call(WindowFuncKind::DenseRank);
        let mut state = RankState::<DenseRank>::new(&call);
        assert!(state.curr_window().key.is_none());
        assert!(!state.curr_window().is_ready);
        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 103), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());
        state.append(create_state_key(3, 106), SmallVec::new());
        state.append(create_state_key(3, 105), SmallVec::new());
        state.append(create_state_key(3, 104), SmallVec::new());
        state.append(create_state_key(8, 108), SmallVec::new());

        let outputs = drain_outputs(&mut state);

        assert_eq!(
            outputs,
            vec![
                Some(1i64.into()),
                Some(2i64.into()),
                Some(2i64.into()),
                Some(3i64.into()),
                Some(3i64.into()),
                Some(3i64.into()),
                Some(4i64.into())
            ]
        );
    }

    #[test]
    fn test_row_number_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(3, 102), SmallVec::new());

        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into());

        let snapshot = state.snapshot().unwrap();
        assert!(snapshot.last_output_key.is_some());
        assert_eq!(
            snapshot.last_output_key.as_ref().unwrap(),
            &create_state_key(2, 101)
        );

        let mut new_state = RankState::<RowNumber>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Only rows after the snapshot point are fed to the restored state.
        new_state.append(create_state_key(3, 102), SmallVec::new());
        new_state.append(create_state_key(4, 103), SmallVec::new());

        let (output3, _) = new_state.slide().unwrap();
        assert_eq!(output3.unwrap(), 3i64.into());
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_rank_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::Rank);
        let mut state = RankState::<Rank>::new(&call);
        state.enable_persistence();

        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());

        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into());
        let (output3, _) = state.slide().unwrap();
        assert_eq!(output3.unwrap(), 2i64.into());

        let snapshot = state.snapshot().unwrap();

        let mut new_state = RankState::<Rank>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        new_state.append(create_state_key(3, 103), SmallVec::new());

        // The restored peer-group size (2) carries over: 2 + 2 = 4.
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_dense_rank_snapshot_restore_roundtrip() {
        let call = create_call(WindowFuncKind::DenseRank);
        let mut state = RankState::<DenseRank>::new(&call);
        state.enable_persistence();

        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(2, 102), SmallVec::new());

        let (output1, _) = state.slide().unwrap();
        assert_eq!(output1.unwrap(), 1i64.into());
        let (output2, _) = state.slide().unwrap();
        assert_eq!(output2.unwrap(), 2i64.into());
        let (output3, _) = state.slide().unwrap();
        assert_eq!(output3.unwrap(), 2i64.into());

        let snapshot = state.snapshot().unwrap();

        let mut new_state = RankState::<DenseRank>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        new_state.append(create_state_key(3, 103), SmallVec::new());

        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 3i64.into());
    }

    #[test]
    fn test_recovery_skip_logic() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        state.append(create_state_key(1, 100), SmallVec::new());
        state.append(create_state_key(2, 101), SmallVec::new());
        state.append(create_state_key(3, 102), SmallVec::new());

        let (_, _) = state.slide().unwrap();
        let (_, _) = state.slide().unwrap();

        let snapshot = state.snapshot().unwrap();
        assert_eq!(
            snapshot.last_output_key.as_ref().unwrap(),
            &create_state_key(2, 101)
        );

        let mut new_state = RankState::<RowNumber>::new(&call);
        new_state.enable_persistence();
        new_state.restore(snapshot).unwrap();

        // Replay every row, including the two already covered by the snapshot.
        new_state.append(create_state_key(1, 100), SmallVec::new());
        new_state.append(create_state_key(2, 101), SmallVec::new());
        new_state.append(create_state_key(3, 102), SmallVec::new());
        new_state.append(create_state_key(4, 103), SmallVec::new());

        // The replayed rows must not be counted a second time.
        let _ = new_state.slide_no_output().unwrap();
        let _ = new_state.slide_no_output().unwrap();

        let (output3, _) = new_state.slide().unwrap();
        assert_eq!(output3.unwrap(), 3i64.into());
        let (output4, _) = new_state.slide().unwrap();
        assert_eq!(output4.unwrap(), 4i64.into());
    }

    #[test]
    fn test_eviction_hint_with_persistence() {
        let call = create_call(WindowFuncKind::RowNumber);

        let mut state_no_persist = RankState::<RowNumber>::new(&call);
        state_no_persist.append(create_state_key(1, 100), SmallVec::new());
        let (_, evict_hint) = state_no_persist.slide().unwrap();
        match evict_hint {
            StateEvictHint::CannotEvict(_) => {}
            StateEvictHint::CanEvict(_) => panic!("expected CannotEvict without persistence"),
        }

        let mut state_persist = RankState::<RowNumber>::new(&call);
        state_persist.enable_persistence();
        state_persist.append(create_state_key(1, 100), SmallVec::new());
        let (_, evict_hint) = state_persist.slide().unwrap();
        match evict_hint {
            StateEvictHint::CanEvict(keys) => {
                assert_eq!(keys.len(), 1);
                assert!(keys.contains(&create_state_key(1, 100)));
            }
            StateEvictHint::CannotEvict(_) => panic!("expected CanEvict with persistence"),
        }
    }

    #[test]
    fn test_snapshot_returns_none_without_persistence() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.append(create_state_key(1, 100), SmallVec::new());
        let (_, _) = state.slide().unwrap();

        assert!(state.snapshot().is_none());
    }

    #[test]
    fn test_restore_from_empty_starts_fresh() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        let snapshot = WindowStateSnapshot {
            last_output_key: None,
            function_state: RowNumber::default().to_proto_state(),
        };

        state.restore(snapshot).unwrap();

        state.append(create_state_key(1, 100), SmallVec::new());
        let (output, _) = state.slide().unwrap();
        assert_eq!(output.unwrap(), 1i64.into());
    }

    #[test]
    fn test_wrong_function_state_type_is_rejected() {
        let call = create_call(WindowFuncKind::RowNumber);
        let mut state = RankState::<RowNumber>::new(&call);
        state.enable_persistence();

        // A `RankState` payload must not be accepted by a row-number state.
        let snapshot = WindowStateSnapshot {
            last_output_key: None,
            function_state: Rank::default().to_proto_state(),
        };

        assert!(state.restore(snapshot).is_err());
    }
}