risingwave_stream/executor/
filter.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::{Array, ArrayImpl, Op};
16use risingwave_common::bitmap::BitmapBuilder;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_expr::expr::NonStrictExpression;
19
20use crate::executor::prelude::*;
21
22/// `FilterExecutor` filters data with the `expr`. The `expr` takes a chunk of data,
23/// and returns a boolean array on whether each item should be retained. And then,
24/// `FilterExecutor` will insert, delete or update element into next executor according
25/// to the result of the expression.
26pub struct FilterExecutor {
27    _ctx: ActorContextRef,
28    input: Executor,
29
30    /// Expression of the current filter, note that the filter must always have the same output for
31    /// the same input.
32    expr: NonStrictExpression,
33}
34
35impl FilterExecutor {
36    pub fn new(ctx: ActorContextRef, input: Executor, expr: NonStrictExpression) -> Self {
37        Self {
38            _ctx: ctx,
39            input,
40            expr,
41        }
42    }
43
44    pub(super) fn filter(
45        chunk: StreamChunk,
46        filter: Arc<ArrayImpl>,
47    ) -> StreamExecutorResult<Option<StreamChunk>> {
48        let (data_chunk, ops) = chunk.into_parts();
49
50        let (columns, vis) = data_chunk.into_parts();
51
52        let n = ops.len();
53
54        // TODO: Can we update ops and visibility inplace?
55        let mut new_ops = Vec::with_capacity(n);
56        let mut new_visibility = BitmapBuilder::with_capacity(n);
57        let mut last_res = false;
58
59        assert_eq!(vis.len(), n);
60
61        let ArrayImpl::Bool(bool_array) = &*filter else {
62            panic!("unmatched type: filter expr returns a non-null array");
63        };
64        for (&op, res) in ops.iter().zip_eq_fast(bool_array.iter()) {
65            // SAFETY: ops.len() == pred_output.len() == visibility.len()
66            let res = res.unwrap_or(false);
67            match op {
68                Op::Insert | Op::Delete => {
69                    new_ops.push(op);
70                    if res {
71                        new_visibility.append(true);
72                    } else {
73                        new_visibility.append(false);
74                    }
75                }
76                Op::UpdateDelete => {
77                    last_res = res;
78                }
79                Op::UpdateInsert => match (last_res, res) {
80                    (true, false) => {
81                        new_ops.push(Op::Delete);
82                        new_ops.push(Op::UpdateInsert);
83                        new_visibility.append(true);
84                        new_visibility.append(false);
85                    }
86                    (false, true) => {
87                        new_ops.push(Op::UpdateDelete);
88                        new_ops.push(Op::Insert);
89                        new_visibility.append(false);
90                        new_visibility.append(true);
91                    }
92                    (true, true) => {
93                        new_ops.push(Op::UpdateDelete);
94                        new_ops.push(Op::UpdateInsert);
95                        new_visibility.append(true);
96                        new_visibility.append(true);
97                    }
98                    (false, false) => {
99                        new_ops.push(Op::UpdateDelete);
100                        new_ops.push(Op::UpdateInsert);
101                        new_visibility.append(false);
102                        new_visibility.append(false);
103                    }
104                },
105            }
106        }
107
108        let new_visibility = new_visibility.finish();
109
110        Ok(if new_visibility.count_ones() > 0 {
111            let new_chunk = StreamChunk::with_visibility(new_ops, columns, new_visibility);
112            Some(new_chunk)
113        } else {
114            None
115        })
116    }
117}
118
119impl Debug for FilterExecutor {
120    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("FilterExecutor")
122            .field("expr", &self.expr)
123            .finish()
124    }
125}
126
127impl Execute for FilterExecutor {
128    fn execute(self: Box<Self>) -> BoxedMessageStream {
129        self.execute_inner().boxed()
130    }
131}
132
133impl FilterExecutor {
134    #[try_stream(ok = Message, error = StreamExecutorError)]
135    async fn execute_inner(self) {
136        let input = self.input.execute();
137        #[for_await]
138        for msg in input {
139            let msg = msg?;
140            match msg {
141                Message::Watermark(w) => yield Message::Watermark(w),
142                Message::Chunk(chunk) => {
143                    let chunk = chunk.compact();
144
145                    let pred_output = self.expr.eval_infallible(chunk.data_chunk()).await;
146
147                    match Self::filter(chunk, pred_output)? {
148                        Some(new_chunk) => yield Message::Chunk(new_chunk),
149                        None => continue,
150                    }
151                }
152                m => yield m,
153            }
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;

    use super::super::test_utils::MockSource;
    use super::super::test_utils::expr::build_from_pretty;
    use super::super::*;
    use super::*;

    /// Feeds one chunk of plain inserts/deletes and one chunk of update pairs through
    /// a `$0 > $1` filter and checks the rewritten ops and visibility (`D` marks a
    /// hidden row in the expected chunks).
    #[tokio::test]
    async fn test_filter() {
        // Plain inserts and a delete.
        let plain_chunk = StreamChunk::from_pretty(
            " I I
            + 1 4
            + 5 2
            + 6 6
            - 7 5",
        );
        // One update pair for each combination of (old row passes, new row passes).
        let update_chunk = StreamChunk::from_pretty(
            "  I I
            U- 5 3  // true -> true
            U+ 7 5  // expect UpdateDelete, UpdateInsert
            U- 5 3  // true -> false
            U+ 3 5  // expect Delete
            U- 3 5  // false -> true
            U+ 5 3  // expect Insert
            U- 3 5  // false -> false
            U+ 4 6  // expect nothing",
        );
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        };
        let pk_indices = PkIndices::new();
        let upstream = MockSource::with_chunks(vec![plain_chunk, update_chunk])
            .into_executor(schema.clone(), pk_indices.clone());

        let predicate = build_from_pretty("(greater_than:boolean $0:int8 $1:int8)");

        let executor = FilterExecutor::new(ActorContext::for_test(123), upstream, predicate);
        let mut stream = executor.boxed().execute();

        // First chunk: rows failing the predicate are hidden, ops are unchanged.
        let first = stream.next().await.unwrap().unwrap().into_chunk().unwrap();
        let expected_first = StreamChunk::from_pretty(
            " I I
                + 1 4 D
                + 5 2
                + 6 6 D
                - 7 5",
        );
        assert_eq!(first, expected_first);

        // Second chunk: update pairs are kept, downgraded or hidden.
        let second = stream.next().await.unwrap().unwrap().into_chunk().unwrap();
        let expected_second = StreamChunk::from_pretty(
            "  I I
                U- 5 3
                U+ 7 5
                -  5 3
                U+ 3 5 D
                U- 3 5 D
                +  5 3
                U- 3 5 D
                U+ 4 6 D",
        );
        assert_eq!(second, expected_second);

        // The mock source ends with a stop message.
        let last = stream.next().await.unwrap().unwrap();
        assert!(last.is_stop());
    }
}