risingwave_expr_impl/aggregate/
bit_and.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::BitAnd;
17
18use risingwave_common::array::I64Array;
19use risingwave_common::types::{ListRef, ListValue};
20use risingwave_expr::aggregate;
21
22/// Computes the bitwise AND of all non-null input values.
23///
24/// # Example
25///
26/// ```slt
27/// statement ok
28/// create table t (a int2, b int4, c int8);
29///
30/// query III
31/// select bit_and(a), bit_and(b), bit_and(c) from t;
32/// ----
33/// NULL NULL NULL
34///
35/// statement ok
36/// insert into t values
37///    (6, 6, 6),
38///    (3, 3, 3),
39///    (null, null, null);
40///
41/// query III
42/// select bit_and(a), bit_and(b), bit_and(c) from t;
43/// ----
44/// 2 2 2
45///
46/// statement ok
47/// drop table t;
48/// ```
49// XXX: state = "ref" is required so that
50// for the first non-null value, the state is set to that value.
51#[aggregate("bit_and(*int) -> auto", state = "ref")]
52fn bit_and_append_only<T>(state: T, input: T) -> T
53where
54    T: BitAnd<Output = T>,
55{
56    state.bitand(input)
57}
58
/// Computes the bitwise AND of all non-null input values, supporting retraction of
/// previously accumulated values (e.g. rows removed by `DELETE`).
60///
61/// # Example
62///
63/// ```slt
64/// statement ok
65/// create table t (a int2, b int4, c int8);
66///
67/// statement ok
68/// create materialized view mv as
69/// select bit_and(a) a, bit_and(b) b, bit_and(c) c from t;
70///
71/// query III
72/// select * from mv;
73/// ----
74/// NULL NULL NULL
75///
76/// statement ok
77/// insert into t values
78///    (6, 6, 6),
79///    (3, 3, 3),
80///    (null, null, null);
81///
82/// query III
83/// select * from mv;
84/// ----
85/// 2 2 2
86///
87/// statement ok
88/// delete from t where a = 3;
89///
90/// query III
91/// select * from mv;
92/// ----
93/// 6 6 6
94///
95/// statement ok
96/// drop materialized view mv;
97///
98/// statement ok
99/// drop table t;
100/// ```
// Zero-sized marker type: all per-group data lives in the `int8[]` state
// declared by the `#[aggregate]` attributes on its `impl`, so the struct itself
// only needs to carry the element type `T`.
#[derive(Clone, Debug, Default)]
struct BitAndUpdatable<T> {
    _phantom: PhantomData<T>,
}
105
106#[aggregate("bit_and(int2) -> int2", state = "int8[]", generic = "i16")]
107#[aggregate("bit_and(int4) -> int4", state = "int8[]", generic = "i32")]
108#[aggregate("bit_and(int8) -> int8", state = "int8[]", generic = "i64")]
109impl<T: Bits> BitAndUpdatable<T> {
110    // state is the number of 0s for each bit.
111
112    fn create_state(&self) -> ListValue {
113        ListValue::new(I64Array::from_iter(std::iter::repeat_n(0, T::BITS)).into())
114    }
115
116    fn accumulate(&self, mut state: ListValue, input: T) -> ListValue {
117        let counts = state.as_i64_mut_slice().expect("invalid state");
118        for (i, count) in counts.iter_mut().enumerate() {
119            if !input.get_bit(i) {
120                *count += 1;
121            }
122        }
123        state
124    }
125
126    fn retract(&self, mut state: ListValue, input: T) -> ListValue {
127        let counts = state.as_i64_mut_slice().expect("invalid state");
128        for (i, count) in counts.iter_mut().enumerate() {
129            if !input.get_bit(i) {
130                *count -= 1;
131            }
132        }
133        state
134    }
135
136    fn finalize(&self, state: ListRef<'_>) -> T {
137        let counts = state.as_i64_slice().expect("invalid state");
138        let mut result = T::default();
139        for (i, count) in counts.iter().enumerate() {
140            if *count == 0 {
141                result.set_bit(i);
142            }
143        }
144        result
145    }
146}
147
/// Fixed-width integer types whose individual bits can be read and set.
///
/// Used by `BitAndUpdatable` to keep one counter per bit position.
pub trait Bits: Default {
    /// Number of bits in the type's representation.
    const BITS: usize;

    /// Returns whether bit `i` (0 = least significant) is set.
    ///
    /// Callers must pass `i < Self::BITS`; larger values overflow the shift
    /// (which panics in debug builds).
    fn get_bit(&self, i: usize) -> bool;

    /// Sets bit `i` (0 = least significant) to 1, leaving other bits unchanged.
    ///
    /// Callers must pass `i < Self::BITS`; larger values overflow the shift
    /// (which panics in debug builds).
    fn set_bit(&mut self, i: usize);
}

/// Implements [`Bits`] for primitive integers. The bit width is taken from the
/// primitive's own inherent `BITS` constant rather than a hand-written literal,
/// so it cannot drift out of sync with the type.
macro_rules! impl_bits {
    ($($t:ty),* $(,)?) => {
        $(
            impl Bits for $t {
                const BITS: usize = <$t>::BITS as usize;

                fn get_bit(&self, i: usize) -> bool {
                    // In Rust `&` binds tighter than `==`, so this is
                    // `((self >> i) & 1) == 1`. The arithmetic shift of a negative
                    // value still isolates bit `i` correctly.
                    (*self >> i) & 1 == 1
                }

                fn set_bit(&mut self, i: usize) {
                    *self |= 1 << i;
                }
            }
        )*
    };
}

impl_bits!(i16, i32, i64);