risingwave_expr_impl/window_function/
range_utils.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use smallvec::{SmallVec, smallvec};
18
19/// Calculate range (A - B), the result might be the union of two ranges when B is totally included
20/// in the A.
21pub(super) fn range_except(a: Range<usize>, b: Range<usize>) -> (Range<usize>, Range<usize>) {
22    if a.is_empty() {
23        (0..0, 0..0)
24    } else if b.is_empty() {
25        (a, 0..0)
26    } else if a.end <= b.start || b.end <= a.start {
27        // a: [   )
28        // b:        [   )
29        // or
30        // a:        [   )
31        // b: [   )
32        (a, 0..0)
33    } else if b.start <= a.start && a.end <= b.end {
34        // a:  [   )
35        // b: [       )
36        (0..0, 0..0)
37    } else if a.start < b.start && b.end < a.end {
38        // a: [       )
39        // b:   [   )
40        (a.start..b.start, b.end..a.end)
41    } else if a.end <= b.end {
42        // a: [   )
43        // b:   [   )
44        (a.start..b.start, 0..0)
45    } else if b.start <= a.start {
46        // a:   [   )
47        // b: [   )
48        (b.end..a.end, 0..0)
49    } else {
50        unreachable!()
51    }
52}
53
54/// Calculate the difference of two ranges A and B, return (removed ranges, added ranges).
55/// Note this is quite different from [`range_except`].
56#[expect(clippy::type_complexity)] // looks complex but it's not
57pub(super) fn range_diff(
58    a: Range<usize>,
59    b: Range<usize>,
60) -> (SmallVec<[Range<usize>; 2]>, SmallVec<[Range<usize>; 2]>) {
61    if a.start == b.start {
62        match a.end.cmp(&b.end) {
63            std::cmp::Ordering::Equal => {
64                // a: [   )
65                // b: [   )
66                (smallvec![], smallvec![])
67            }
68            std::cmp::Ordering::Less => {
69                // a: [   )
70                // b: [     )
71                (smallvec![], smallvec![a.end..b.end])
72            }
73            std::cmp::Ordering::Greater => {
74                // a: [     )
75                // b: [   )
76                (smallvec![b.end..a.end], smallvec![])
77            }
78        }
79    } else if a.end == b.end {
80        debug_assert!(a.start != b.start);
81        if a.start < b.start {
82            // a: [     )
83            // b:   [   )
84            (smallvec![a.start..b.start], smallvec![])
85        } else {
86            // a:   [   )
87            // b: [     )
88            (smallvec![], smallvec![b.start..a.start])
89        }
90    } else {
91        debug_assert!(a.start != b.start && a.end != b.end);
92        if a.end <= b.start || b.end <= a.start {
93            // a: [   )
94            // b:     [  [   )
95            // or
96            // a:       [   )
97            // b: [   ) )
98            (smallvec![a], smallvec![b])
99        } else if b.start < a.start && a.end < b.end {
100            // a:  [   )
101            // b: [       )
102            (smallvec![], smallvec![b.start..a.start, a.end..b.end])
103        } else if a.start < b.start && b.end < a.end {
104            // a: [       )
105            // b:   [   )
106            (smallvec![a.start..b.start, b.end..a.end], smallvec![])
107        } else if a.end < b.end {
108            // a: [   )
109            // b:   [   )
110            (smallvec![a.start..b.start], smallvec![a.end..b.end])
111        } else {
112            // a:   [   )
113            // b: [   )
114            (smallvec![b.end..a.end], smallvec![b.start..a.start])
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Upper bound for range endpoints in the exhaustive tests; kept small so they stay instant.
    const MAX_BOUND: usize = 6;

    /// Every range `start..end` with `start <= end <= MAX_BOUND`, empty ranges included.
    fn all_ranges() -> Vec<Range<usize>> {
        (0..=MAX_BOUND)
            .flat_map(|start| (start..=MAX_BOUND).map(move |end| start..end))
            .collect()
    }

    /// The set of points covered by `r`.
    fn points(r: &Range<usize>) -> HashSet<usize> {
        r.clone().collect()
    }

    #[test]
    fn test_range_except() {
        fn test(a: Range<usize>, b: Range<usize>, expected: impl IntoIterator<Item = usize>) {
            let (l, r) = range_except(a, b);
            let set = l.into_iter().chain(r).collect::<HashSet<_>>();
            assert_eq!(set, expected.into_iter().collect())
        }

        test(0..0, 0..0, []);
        test(0..1, 0..1, []);
        test(0..1, 0..2, []);
        test(1..2, 0..2, []);
        test(0..2, 0..1, [1]);
        test(0..2, 1..2, [0]);
        test(0..5, 2..3, [0, 1, 3, 4]);
        test(2..5, 1..3, [3, 4]);
        test(2..5, 4..5, [2, 3]);
        // `b` is empty: nothing is subtracted.
        test(0..3, 0..0, [0, 1, 2]);
        test(0..3, 1..1, [0, 1, 2]);
        // `a` and `b` are disjoint or merely adjacent: `a` is kept whole.
        test(0..2, 2..4, [0, 1]);
        test(0..2, 3..5, [0, 1]);
        test(3..5, 0..3, [3, 4]);
        test(3..5, 0..2, [3, 4]);
    }

    #[test]
    fn test_range_except_exhaustive() {
        let ranges = all_ranges();
        for a in &ranges {
            for b in &ranges {
                let (l, r) = range_except(a.clone(), b.clone());
                // The two pieces must not overlap, so their lengths add up to the result size.
                let total_len = l.len() + r.len();
                let got = l.chain(r).collect::<HashSet<_>>();
                let expected = points(a).difference(&points(b)).copied().collect::<HashSet<_>>();
                assert_eq!(got, expected, "range_except({a:?}, {b:?})");
                assert_eq!(total_len, expected.len(), "range_except({a:?}, {b:?}) overlaps");
            }
        }
    }

    #[test]
    fn test_range_diff() {
        fn test(
            a: Range<usize>,
            b: Range<usize>,
            expected_removed: impl IntoIterator<Item = usize>,
            expected_added: impl IntoIterator<Item = usize>,
        ) {
            let (removed, added) = range_diff(a, b);
            let removed_set = removed.into_iter().flatten().collect::<HashSet<_>>();
            let added_set = added.into_iter().flatten().collect::<HashSet<_>>();
            let expected_removed_set = expected_removed.into_iter().collect::<HashSet<_>>();
            let expected_added_set = expected_added.into_iter().collect::<HashSet<_>>();
            assert_eq!(removed_set, expected_removed_set);
            assert_eq!(added_set, expected_added_set);
        }

        test(0..0, 0..0, [], []);
        test(0..1, 0..1, [], []);
        test(0..1, 0..2, [], [1]);
        test(0..2, 0..1, [1], []);
        test(0..2, 1..2, [0], []);
        test(1..2, 0..2, [], [0]);
        test(0..1, 1..2, [0], [1]);
        test(0..1, 2..3, [0], [2]);
        test(1..2, 0..1, [1], [0]);
        test(2..3, 0..1, [2], [0]);
        test(0..3, 1..2, [0, 2], []);
        test(1..2, 0..3, [], [0, 2]);
        test(0..3, 2..4, [0, 1], [3]);
        test(2..4, 0..3, [3], [0, 1]);
        // Empty inputs.
        test(3..3, 0..5, [], [0, 1, 2, 3, 4]);
        test(0..5, 3..3, [0, 1, 2, 3, 4], []);
        test(2..2, 2..4, [], [2, 3]);
        test(4..4, 2..4, [], [2, 3]);
    }

    #[test]
    fn test_range_diff_exhaustive() {
        let ranges = all_ranges();
        for a in &ranges {
            for b in &ranges {
                let (removed, added) = range_diff(a.clone(), b.clone());
                let (pa, pb) = (points(a), points(b));
                let expected_removed = pa.difference(&pb).copied().collect::<HashSet<_>>();
                let expected_added = pb.difference(&pa).copied().collect::<HashSet<_>>();

                // Pieces on each side must not overlap, so their lengths add up to the set size.
                let removed_len: usize = removed.iter().map(|r| r.len()).sum();
                let added_len: usize = added.iter().map(|r| r.len()).sum();
                assert_eq!(removed_len, expected_removed.len(), "removed({a:?}, {b:?})");
                assert_eq!(added_len, expected_added.len(), "added({a:?}, {b:?})");

                let removed_set = removed.into_iter().flatten().collect::<HashSet<_>>();
                let added_set = added.into_iter().flatten().collect::<HashSet<_>>();
                assert_eq!(removed_set, expected_removed, "removed({a:?}, {b:?})");
                assert_eq!(added_set, expected_added, "added({a:?}, {b:?})");
            }
        }
    }
}