risingwave_expr_impl/aggregate/bit_and.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::BitAnd;
17
18use risingwave_common::array::I64Array;
19use risingwave_common::types::{ListRef, ListValue};
20use risingwave_expr::aggregate;
21
/// Computes the bitwise AND of all non-null input values (append-only variant, no retraction).
23///
24/// # Example
25///
26/// ```slt
27/// statement ok
28/// create table t (a int2, b int4, c int8);
29///
30/// query III
31/// select bit_and(a), bit_and(b), bit_and(c) from t;
32/// ----
33/// NULL NULL NULL
34///
35/// statement ok
36/// insert into t values
37/// (6, 6, 6),
38/// (3, 3, 3),
39/// (null, null, null);
40///
41/// query III
42/// select bit_and(a), bit_and(b), bit_and(c) from t;
43/// ----
44/// 2 2 2
45///
46/// statement ok
47/// drop table t;
48/// ```
49// XXX: state = "ref" is required so that
50// for the first non-null value, the state is set to that value.
51#[aggregate("bit_and(*int) -> auto", state = "ref")]
52fn bit_and_append_only<T>(state: T, input: T) -> T
53where
54 T: BitAnd<Output = T>,
55{
56 state.bitand(input)
57}
58
/// Computes the bitwise AND of all non-null input values, with support for retraction.
60///
61/// # Example
62///
63/// ```slt
64/// statement ok
65/// create table t (a int2, b int4, c int8);
66///
67/// statement ok
68/// create materialized view mv as
69/// select bit_and(a) a, bit_and(b) b, bit_and(c) c from t;
70///
71/// query III
72/// select * from mv;
73/// ----
74/// NULL NULL NULL
75///
76/// statement ok
77/// insert into t values
78/// (6, 6, 6),
79/// (3, 3, 3),
80/// (null, null, null);
81///
82/// query III
83/// select * from mv;
84/// ----
85/// 2 2 2
86///
87/// statement ok
88/// delete from t where a = 3;
89///
90/// query III
91/// select * from mv;
92/// ----
93/// 6 6 6
94///
95/// statement ok
96/// drop materialized view mv;
97///
98/// statement ok
99/// drop table t;
100/// ```
101#[derive(Debug, Default, Clone)]
102struct BitAndUpdatable<T> {
103 _phantom: PhantomData<T>,
104}
105
106#[aggregate("bit_and(int2) -> int2", state = "int8[]", generic = "i16")]
107#[aggregate("bit_and(int4) -> int4", state = "int8[]", generic = "i32")]
108#[aggregate("bit_and(int8) -> int8", state = "int8[]", generic = "i64")]
109impl<T: Bits> BitAndUpdatable<T> {
110 // state is the number of 0s for each bit.
111
112 fn create_state(&self) -> ListValue {
113 ListValue::new(I64Array::from_iter(std::iter::repeat_n(0, T::BITS)).into())
114 }
115
116 fn accumulate(&self, mut state: ListValue, input: T) -> ListValue {
117 let counts = state.as_i64_mut_slice().expect("invalid state");
118 for (i, count) in counts.iter_mut().enumerate() {
119 if !input.get_bit(i) {
120 *count += 1;
121 }
122 }
123 state
124 }
125
126 fn retract(&self, mut state: ListValue, input: T) -> ListValue {
127 let counts = state.as_i64_mut_slice().expect("invalid state");
128 for (i, count) in counts.iter_mut().enumerate() {
129 if !input.get_bit(i) {
130 *count -= 1;
131 }
132 }
133 state
134 }
135
136 fn finalize(&self, state: ListRef<'_>) -> T {
137 let counts = state.as_i64_slice().expect("invalid state");
138 let mut result = T::default();
139 for (i, count) in counts.iter().enumerate() {
140 if *count == 0 {
141 result.set_bit(i);
142 }
143 }
144 result
145 }
146}
147
/// Bit-level access to a fixed-width integer, as used by the retractable
/// `bit_and` aggregate state, which keeps one counter per bit position.
pub trait Bits: Default {
    /// Number of bits in the type (16 for `i16`, 32 for `i32`, 64 for `i64`).
    const BITS: usize;

    /// Returns whether bit `i` (0 = least significant) is 1.
    ///
    /// `i` is expected to be less than [`Bits::BITS`].
    fn get_bit(&self, i: usize) -> bool;

    /// Sets bit `i` (0 = least significant) to 1, leaving the other bits unchanged.
    ///
    /// `i` is expected to be less than [`Bits::BITS`].
    fn set_bit(&mut self, i: usize);
}

/// Implements [`Bits`] for primitive integers. The width is taken from the
/// primitive's own `BITS` constant, so it cannot drift from the actual type.
macro_rules! impl_bits {
    ($($t:ty),* $(,)?) => {
        $(
            impl Bits for $t {
                // `<$t>::BITS` is a `u32` of at most 64, which always fits in `usize`.
                const BITS: usize = <$t>::BITS as usize;

                fn get_bit(&self, i: usize) -> bool {
                    // Shift bit `i` down to position 0, then mask it off.
                    (*self >> i) & 1 == 1
                }

                fn set_bit(&mut self, i: usize) {
                    *self |= 1 << i;
                }
            }
        )*
    };
}

impl_bits!(i16, i32, i64);