risingwave_expr_impl/aggregate/bit_or.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::BitOr;
17
18use risingwave_common::array::I64Array;
19use risingwave_common::types::{ListRef, ListValue};
20use risingwave_expr::aggregate;
21
22use super::bit_and::Bits;
23
24/// Computes the bitwise OR of all non-null input values.
25///
26/// # Example
27///
28/// ```slt
29/// statement ok
30/// create table t (a int2, b int4, c int8);
31///
32/// query III
33/// select bit_or(a), bit_or(b), bit_or(c) from t;
34/// ----
35/// NULL NULL NULL
36///
37/// statement ok
38/// insert into t values
39/// (1, 1, 1),
40/// (2, 2, 2),
41/// (null, null, null);
42///
43/// query III
44/// select bit_or(a), bit_or(b), bit_or(c) from t;
45/// ----
46/// 3 3 3
47///
48/// statement ok
49/// drop table t;
50/// ```
51#[aggregate("bit_or(*int) -> auto")]
fn bit_or_append_only<T>(state: T, input: T) -> T
where
    T: BitOr<Output = T>,
{
    // Merge the next non-null input into the running accumulator. The `|`
    // operator desugars to `BitOr::bitor`, which is the only bound on `T`.
    state | input
}
58
/// Computes the bitwise OR of all non-null input values.
///
/// This is the retractable counterpart of `bit_or_append_only`. It does not keep
/// the OR-ed value directly. Instead, its `int8[]` state stores, for every bit
/// position of the input type, how many accumulated inputs have that bit set.
/// Retracting an input decrements those counters, and a bit appears in the result
/// whenever its counter is non-zero. This lets the result shrink again after
/// deletes (see the example below, where the view drops from 7 to 6).
///
/// # Example
///
/// ```slt
/// statement ok
/// create table t (a int2, b int4, c int8);
///
/// statement ok
/// create materialized view mv as
/// select bit_or(a) a, bit_or(b) b, bit_or(c) c from t;
///
/// query III
/// select * from mv;
/// ----
/// NULL NULL NULL
///
/// statement ok
/// insert into t values
/// (6, 6, 6),
/// (3, 3, 3),
/// (null, null, null);
///
/// query III
/// select * from mv;
/// ----
/// 7 7 7
///
/// statement ok
/// delete from t where a = 3;
///
/// query III
/// select * from mv;
/// ----
/// 6 6 6
///
/// statement ok
/// drop materialized view mv;
///
/// statement ok
/// drop table t;
/// ```
#[derive(Debug, Default, Clone)]
struct BitOrUpdatable<T> {
    // Zero-sized marker that ties the aggregate to its input type `T`; no value
    // of `T` is stored here, since all running data lives in the list state.
    _phantom: PhantomData<T>,
}
105
106#[aggregate("bit_or(int2) -> int2", state = "int8[]", generic = "i16")]
107#[aggregate("bit_or(int4) -> int4", state = "int8[]", generic = "i32")]
108#[aggregate("bit_or(int8) -> int8", state = "int8[]", generic = "i64")]
109impl<T: Bits> BitOrUpdatable<T> {
110 // state is the number of 1s for each bit.
111
112 fn create_state(&self) -> ListValue {
113 ListValue::new(I64Array::from_iter(std::iter::repeat_n(0, T::BITS)).into())
114 }
115
116 fn accumulate(&self, mut state: ListValue, input: T) -> ListValue {
117 let counts = state.as_i64_mut_slice().expect("invalid state");
118 for (i, count) in counts.iter_mut().enumerate() {
119 if input.get_bit(i) {
120 *count += 1;
121 }
122 }
123 state
124 }
125
126 fn retract(&self, mut state: ListValue, input: T) -> ListValue {
127 let counts = state.as_i64_mut_slice().expect("invalid state");
128 for (i, count) in counts.iter_mut().enumerate() {
129 if input.get_bit(i) {
130 *count -= 1;
131 }
132 }
133 state
134 }
135
136 fn finalize(&self, state: ListRef<'_>) -> T {
137 let counts = state.as_i64_slice().expect("invalid state");
138 let mut result = T::default();
139 for (i, count) in counts.iter().enumerate() {
140 if *count != 0 {
141 result.set_bit(i);
142 }
143 }
144 result
145 }
146}