risingwave_expr_impl/aggregate/
bit_or.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::BitOr;
17
18use risingwave_common::array::I64Array;
19use risingwave_common::types::{ListRef, ListValue};
20use risingwave_expr::aggregate;
21
22use super::bit_and::Bits;
23
24/// Computes the bitwise OR of all non-null input values.
25///
26/// # Example
27///
28/// ```slt
29/// statement ok
30/// create table t (a int2, b int4, c int8);
31///
32/// query III
33/// select bit_or(a), bit_or(b), bit_or(c) from t;
34/// ----
35/// NULL NULL NULL
36///
37/// statement ok
38/// insert into t values
39///    (1, 1, 1),
40///    (2, 2, 2),
41///    (null, null, null);
42///
43/// query III
44/// select bit_or(a), bit_or(b), bit_or(c) from t;
45/// ----
46/// 3 3 3
47///
48/// statement ok
49/// drop table t;
50/// ```
51#[aggregate("bit_or(*int) -> auto")]
52fn bit_or_append_only<T>(state: T, input: T) -> T
53where
54    T: BitOr<Output = T>,
55{
56    state.bitor(input)
57}
58
/// Bitwise OR aggregate over non-null integer inputs that also supports retraction.
///
/// Unlike `bit_or_append_only`, this variant can process deleted input rows. That makes it
/// suitable for materialized views whose inputs may be removed, as the example below shows.
///
/// # Example
///
/// ```slt
/// statement ok
/// create table t (a int2, b int4, c int8);
///
/// statement ok
/// create materialized view mv as
/// select bit_or(a) a, bit_or(b) b, bit_or(c) c from t;
///
/// query III
/// select * from mv;
/// ----
/// NULL NULL NULL
///
/// statement ok
/// insert into t values
///    (6, 6, 6),
///    (3, 3, 3),
///    (null, null, null);
///
/// query III
/// select * from mv;
/// ----
/// 7 7 7
///
/// statement ok
/// delete from t where a = 3;
///
/// query III
/// select * from mv;
/// ----
/// 6 6 6
///
/// statement ok
/// drop materialized view mv;
///
/// statement ok
/// drop table t;
/// ```
#[derive(Clone, Debug, Default)]
struct BitOrUpdatable<T> {
    // Zero-sized marker that ties the aggregate to its integer type `T`; no value is stored.
    _phantom: PhantomData<T>,
}
105
106#[aggregate("bit_or(int2) -> int2", state = "int8[]", generic = "i16")]
107#[aggregate("bit_or(int4) -> int4", state = "int8[]", generic = "i32")]
108#[aggregate("bit_or(int8) -> int8", state = "int8[]", generic = "i64")]
109impl<T: Bits> BitOrUpdatable<T> {
110    // state is the number of 1s for each bit.
111
112    fn create_state(&self) -> ListValue {
113        ListValue::new(I64Array::from_iter(std::iter::repeat_n(0, T::BITS)).into())
114    }
115
116    fn accumulate(&self, mut state: ListValue, input: T) -> ListValue {
117        let counts = state.as_i64_mut_slice().expect("invalid state");
118        for (i, count) in counts.iter_mut().enumerate() {
119            if input.get_bit(i) {
120                *count += 1;
121            }
122        }
123        state
124    }
125
126    fn retract(&self, mut state: ListValue, input: T) -> ListValue {
127        let counts = state.as_i64_mut_slice().expect("invalid state");
128        for (i, count) in counts.iter_mut().enumerate() {
129            if input.get_bit(i) {
130                *count -= 1;
131            }
132        }
133        state
134    }
135
136    fn finalize(&self, state: ListRef<'_>) -> T {
137        let counts = state.as_i64_slice().expect("invalid state");
138        let mut result = T::default();
139        for (i, count) in counts.iter().enumerate() {
140            if *count != 0 {
141                result.set_bit(i);
142            }
143        }
144        result
145    }
146}