risingwave_expr_impl/window_function/range_utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use smallvec::{SmallVec, smallvec};
18
19/// Calculate range (A - B), the result might be the union of two ranges when B is totally included
20/// in the A.
21pub(super) fn range_except(a: Range<usize>, b: Range<usize>) -> (Range<usize>, Range<usize>) {
22    #[allow(clippy::if_same_then_else)] // for better readability
23    if a.is_empty() {
24        (0..0, 0..0)
25    } else if b.is_empty() {
26        (a, 0..0)
27    } else if a.end <= b.start || b.end <= a.start {
28        // a: [   )
29        // b:        [   )
30        // or
31        // a:        [   )
32        // b: [   )
33        (a, 0..0)
34    } else if b.start <= a.start && a.end <= b.end {
35        // a:  [   )
36        // b: [       )
37        (0..0, 0..0)
38    } else if a.start < b.start && b.end < a.end {
39        // a: [       )
40        // b:   [   )
41        (a.start..b.start, b.end..a.end)
42    } else if a.end <= b.end {
43        // a: [   )
44        // b:   [   )
45        (a.start..b.start, 0..0)
46    } else if b.start <= a.start {
47        // a:   [   )
48        // b: [   )
49        (b.end..a.end, 0..0)
50    } else {
51        unreachable!()
52    }
53}
54
55/// Calculate the difference of two ranges A and B, return (removed ranges, added ranges).
56/// Note this is quite different from [`range_except`].
57#[allow(clippy::type_complexity)] // looks complex but it's not
58pub(super) fn range_diff(
59    a: Range<usize>,
60    b: Range<usize>,
61) -> (SmallVec<[Range<usize>; 2]>, SmallVec<[Range<usize>; 2]>) {
62    if a.start == b.start {
63        match a.end.cmp(&b.end) {
64            std::cmp::Ordering::Equal => {
65                // a: [   )
66                // b: [   )
67                (smallvec![], smallvec![])
68            }
69            std::cmp::Ordering::Less => {
70                // a: [   )
71                // b: [     )
72                (smallvec![], smallvec![a.end..b.end])
73            }
74            std::cmp::Ordering::Greater => {
75                // a: [     )
76                // b: [   )
77                (smallvec![b.end..a.end], smallvec![])
78            }
79        }
80    } else if a.end == b.end {
81        debug_assert!(a.start != b.start);
82        if a.start < b.start {
83            // a: [     )
84            // b:   [   )
85            (smallvec![a.start..b.start], smallvec![])
86        } else {
87            // a:   [   )
88            // b: [     )
89            (smallvec![], smallvec![b.start..a.start])
90        }
91    } else {
92        debug_assert!(a.start != b.start && a.end != b.end);
93        if a.end <= b.start || b.end <= a.start {
94            // a: [   )
95            // b:     [  [   )
96            // or
97            // a:       [   )
98            // b: [   ) )
99            (smallvec![a], smallvec![b])
100        } else if b.start < a.start && a.end < b.end {
101            // a:  [   )
102            // b: [       )
103            (smallvec![], smallvec![b.start..a.start, a.end..b.end])
104        } else if a.start < b.start && b.end < a.end {
105            // a: [       )
106            // b:   [   )
107            (smallvec![a.start..b.start, b.end..a.end], smallvec![])
108        } else if a.end < b.end {
109            // a: [   )
110            // b:   [   )
111            (smallvec![a.start..b.start], smallvec![a.end..b.end])
112        } else {
113            // a:   [   )
114            // b: [   )
115            (smallvec![b.end..a.end], smallvec![b.start..a.start])
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Every well-formed range `start..end` with `start <= end <= max`, empty ones included.
    fn all_ranges(max: usize) -> Vec<Range<usize>> {
        (0..=max)
            .flat_map(|start| (start..=max).map(move |end| start..end))
            .collect()
    }

    /// Naive reference implementation of `a \ b` as a set of indices.
    fn set_minus(a: &Range<usize>, b: &Range<usize>) -> HashSet<usize> {
        a.clone().filter(|i| !b.contains(i)).collect()
    }

    #[test]
    fn test_range_except() {
        fn test(a: Range<usize>, b: Range<usize>, expected: impl IntoIterator<Item = usize>) {
            let (l, r) = range_except(a, b);
            let set = l.into_iter().chain(r).collect::<HashSet<_>>();
            assert_eq!(set, expected.into_iter().collect())
        }

        test(0..0, 0..0, []);
        test(0..1, 0..1, []);
        test(0..1, 0..2, []);
        test(1..2, 0..2, []);
        test(0..2, 0..1, [1]);
        test(0..2, 1..2, [0]);
        test(0..5, 2..3, [0, 1, 3, 4]);
        test(2..5, 1..3, [3, 4]);
        test(2..5, 4..5, [2, 3]);
        // An empty `b` subtracts nothing, even when it sits inside `a`.
        test(0..5, 2..2, [0, 1, 2, 3, 4]);
        // Adjacent ranges do not overlap.
        test(0..2, 2..4, [0, 1]);
        test(3..5, 0..3, [3, 4]);
        // Shared start, shorter `b`.
        test(2..5, 2..3, [3, 4]);
    }

    #[test]
    fn test_range_except_shape() {
        // When `b` is strictly inside `a`, the left piece comes first.
        assert_eq!(range_except(0..5, 2..3), (0..2, 3..5));
        // Otherwise the whole difference is in the first slot and the second is `0..0`.
        assert_eq!(range_except(0..3, 2..5), (0..2, 0..0));
        assert_eq!(range_except(2..5, 1..3), (3..5, 0..0));
        assert_eq!(range_except(0..5, 2..2), (0..5, 0..0));
        // Empty results are normalized to `0..0`.
        assert_eq!(range_except(1..2, 0..2), (0..0, 0..0));
        assert_eq!(range_except(3..3, 0..2), (0..0, 0..0));
    }

    #[test]
    fn test_range_except_exhaustive() {
        let ranges = all_ranges(5);
        for a in &ranges {
            for b in &ranges {
                let (l, r) = range_except(a.clone(), b.clone());
                let pieces_len = l.len() + r.len();
                let got = l.chain(r).collect::<HashSet<_>>();
                assert_eq!(got, set_minus(a, b), "a = {a:?}, b = {b:?}");
                // No index may be reported twice.
                assert_eq!(
                    pieces_len,
                    got.len(),
                    "overlapping pieces: a = {a:?}, b = {b:?}"
                );
            }
        }
    }

    #[test]
    fn test_range_diff() {
        fn test(
            a: Range<usize>,
            b: Range<usize>,
            expected_removed: impl IntoIterator<Item = usize>,
            expected_added: impl IntoIterator<Item = usize>,
        ) {
            let (removed, added) = range_diff(a, b);
            let removed_set = removed.into_iter().flatten().collect::<HashSet<_>>();
            let added_set = added.into_iter().flatten().collect::<HashSet<_>>();
            let expected_removed_set = expected_removed.into_iter().collect::<HashSet<_>>();
            let expected_added_set = expected_added.into_iter().collect::<HashSet<_>>();
            assert_eq!(removed_set, expected_removed_set);
            assert_eq!(added_set, expected_added_set);
        }

        test(0..0, 0..0, [], []);
        test(0..1, 0..1, [], []);
        test(0..1, 0..2, [], [1]);
        test(0..2, 0..1, [1], []);
        test(0..2, 1..2, [0], []);
        test(1..2, 0..2, [], [0]);
        test(0..1, 1..2, [0], [1]);
        test(0..1, 2..3, [0], [2]);
        test(1..2, 0..1, [1], [0]);
        test(2..3, 0..1, [2], [0]);
        test(0..3, 1..2, [0, 2], []);
        test(1..2, 0..3, [], [0, 2]);
        test(0..3, 2..4, [0, 1], [3]);
        test(2..4, 0..3, [3], [0, 1]);
        // Empty inputs.
        test(0..5, 2..2, [0, 1, 2, 3, 4], []);
        test(2..2, 0..5, [], [0, 1, 2, 3, 4]);
        test(5..5, 0..3, [], [0, 1, 2]);
        test(0..3, 0..0, [0, 1, 2], []);
        // Shared end, `b` starts earlier.
        test(3..5, 0..5, [], [0, 1, 2]);
    }

    #[test]
    fn test_range_diff_exhaustive() {
        let ranges = all_ranges(5);
        for a in &ranges {
            for b in &ranges {
                let (removed, added) = range_diff(a.clone(), b.clone());
                let removed_len: usize = removed.iter().map(|r| r.len()).sum();
                let added_len: usize = added.iter().map(|r| r.len()).sum();
                let removed_set = removed.into_iter().flatten().collect::<HashSet<_>>();
                let added_set = added.into_iter().flatten().collect::<HashSet<_>>();
                assert_eq!(removed_set, set_minus(a, b), "removed: a = {a:?}, b = {b:?}");
                assert_eq!(added_set, set_minus(b, a), "added: a = {a:?}, b = {b:?}");
                // No index may be reported twice on either side.
                assert_eq!(
                    removed_len,
                    removed_set.len(),
                    "overlapping removed pieces: a = {a:?}, b = {b:?}"
                );
                assert_eq!(
                    added_len,
                    added_set.len(),
                    "overlapping added pieces: a = {a:?}, b = {b:?}"
                );
            }
        }
    }
}