risingwave_stream/executor/
filter.rs1use risingwave_common::array::{Array, ArrayImpl, Op};
16use risingwave_common::bitmap::BitmapBuilder;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_expr::expr::NonStrictExpression;
19
20use crate::executor::prelude::*;
21
22pub struct FilterExecutorInner<const UPSERT: bool> {
38    _ctx: ActorContextRef,
39    input: Executor,
40
41    expr: NonStrictExpression,
44}
45
46pub type FilterExecutor = FilterExecutorInner<false>;
47pub type UpsertFilterExecutor = FilterExecutorInner<true>;
48
49impl<const UPSERT: bool> FilterExecutorInner<UPSERT> {
50    pub fn new(ctx: ActorContextRef, input: Executor, expr: NonStrictExpression) -> Self {
51        Self {
52            _ctx: ctx,
53            input,
54            expr,
55        }
56    }
57
58    pub(super) fn filter(
59        chunk: StreamChunk,
60        filter: Arc<ArrayImpl>,
61    ) -> StreamExecutorResult<Option<StreamChunk>> {
62        let (data_chunk, ops) = chunk.into_parts();
63
64        let (columns, vis) = data_chunk.into_parts();
65
66        let n = ops.len();
67
68        let mut new_ops = Vec::with_capacity(n);
70        let mut new_visibility = BitmapBuilder::with_capacity(n);
71        let mut last_res = false;
72
73        assert_eq!(vis.len(), n);
74
75        let ArrayImpl::Bool(bool_array) = &*filter else {
76            panic!("unmatched type: filter expr returns a non-null array");
77        };
78        for (&op, res) in ops.iter().zip_eq_fast(bool_array.iter()) {
79            let res = res.unwrap_or(false);
81
82            if UPSERT {
83                match op {
84                    Op::Insert | Op::UpdateInsert => {
85                        if res {
86                            new_ops.push(Op::Insert);
90                            new_visibility.append(true);
91                        } else {
92                            new_ops.push(Op::Delete);
96                            new_visibility.append(true);
97                        }
98                    }
99                    Op::Delete | Op::UpdateDelete => {
100                        new_ops.push(Op::Delete);
104                        new_visibility.append(true);
105                    }
106                }
107            } else {
108                match op {
109                    Op::Insert | Op::Delete => {
110                        new_ops.push(op);
111                        new_visibility.append(res);
112                    }
113                    Op::UpdateDelete => {
114                        last_res = res;
115                    }
116                    Op::UpdateInsert => match (last_res, res) {
117                        (true, false) => {
118                            new_ops.push(Op::Delete);
119                            new_ops.push(Op::UpdateInsert);
120                            new_visibility.append(true);
121                            new_visibility.append(false);
122                        }
123                        (false, true) => {
124                            new_ops.push(Op::UpdateDelete);
125                            new_ops.push(Op::Insert);
126                            new_visibility.append(false);
127                            new_visibility.append(true);
128                        }
129                        (true, true) => {
130                            new_ops.push(Op::UpdateDelete);
131                            new_ops.push(Op::UpdateInsert);
132                            new_visibility.append(true);
133                            new_visibility.append(true);
134                        }
135                        (false, false) => {
136                            new_ops.push(Op::UpdateDelete);
137                            new_ops.push(Op::UpdateInsert);
138                            new_visibility.append(false);
139                            new_visibility.append(false);
140                        }
141                    },
142                }
143            }
144        }
145
146        let new_visibility = new_visibility.finish();
147
148        Ok(if new_visibility.count_ones() > 0 {
149            let new_chunk = StreamChunk::with_visibility(new_ops, columns, new_visibility);
150            Some(new_chunk)
151        } else {
152            None
153        })
154    }
155}
156
157impl<const UPSERT: bool> Debug for FilterExecutorInner<UPSERT> {
158    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("FilterExecutor")
160            .field("expr", &self.expr)
161            .field("upsert", &UPSERT)
162            .finish()
163    }
164}
165
166impl<const UPSERT: bool> Execute for FilterExecutorInner<UPSERT> {
167    fn execute(self: Box<Self>) -> BoxedMessageStream {
168        self.execute_inner().boxed()
169    }
170}
171
172impl<const UPSERT: bool> FilterExecutorInner<UPSERT> {
173    #[try_stream(ok = Message, error = StreamExecutorError)]
174    async fn execute_inner(self) {
175        let input = self.input.execute();
176        #[for_await]
177        for msg in input {
178            let msg = msg?;
179            match msg {
180                Message::Watermark(w) => yield Message::Watermark(w),
181                Message::Chunk(chunk) => {
182                    let chunk = chunk.compact_vis();
183
184                    let pred_output = self.expr.eval_infallible(chunk.data_chunk()).await;
185
186                    match Self::filter(chunk, pred_output)? {
187                        Some(new_chunk) => yield Message::Chunk(new_chunk),
188                        None => continue,
189                    }
190                }
191                m => yield m,
192            }
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;

    use super::super::test_utils::MockSource;
    use super::super::test_utils::expr::build_from_pretty;
    use super::super::*;
    use super::*;

    /// Schema with two unnamed `Int64` columns, shared by both tests.
    fn two_int64_columns() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    /// Non-upsert filter with predicate `$0 > $1`: plain rows are hidden when they fail, and
    /// update pairs are rewritten according to the old/new predicate results.
    #[tokio::test]
    async fn test_filter() {
        let plain_rows = StreamChunk::from_pretty(
            " I I
            + 1 4
            + 5 2
            + 6 6
            - 7 5",
        );
        let update_pairs = StreamChunk::from_pretty(
            "  I I
            U- 5 3  // true -> true
            U+ 7 5  // expect UpdateDelete, UpdateInsert
            U- 5 3  // true -> false
            U+ 3 5  // expect Delete
            U- 3 5  // false -> true
            U+ 5 3  // expect Insert
            U- 3 5  // false -> false
            U+ 4 6  // expect nothing",
        );
        let source = MockSource::with_chunks(vec![plain_rows, update_pairs])
            .into_executor(two_int64_columns(), PkIndices::new());

        let predicate = build_from_pretty("(greater_than:boolean $0:int8 $1:int8)");

        let mut filter = FilterExecutor::new(ActorContext::for_test(123), source, predicate)
            .boxed()
            .execute();

        let first = filter.next().await.unwrap().unwrap().into_chunk().unwrap();
        assert_eq!(
            first,
            StreamChunk::from_pretty(
                " I I
                + 1 4 D
                + 5 2
                + 6 6 D
                - 7 5",
            )
        );

        let second = filter.next().await.unwrap().unwrap().into_chunk().unwrap();
        assert_eq!(
            second,
            StreamChunk::from_pretty(
                "  I I
                U- 5 3
                U+ 7 5
                -  5 3
                U+ 3 5 D
                U- 3 5 D
                +  5 3
                U- 3 5 D
                U+ 4 6 D",
            )
        );

        assert!(filter.next().await.unwrap().unwrap().is_stop());
    }

    /// Upsert filter with predicate `$1 > 10`: every row stays visible, failing inserts and
    /// all deletes come out as `Delete`.
    #[tokio::test]
    async fn test_upsert_filter() {
        let input_chunk = StreamChunk::from_pretty(
            " I  I
            + 10 14
            + 20 5
            + 10 7
            + 20 16
            + 20 18
            - 10 .
            - 20 .
            - 30 .
            ",
        );
        let source =
            MockSource::with_chunks(vec![input_chunk]).into_executor(two_int64_columns(), vec![0]);
        let predicate = build_from_pretty("(greater_than:boolean $1:int8 10:int8)");
        let mut filter = UpsertFilterExecutor::new(ActorContext::for_test(123), source, predicate)
            .boxed()
            .execute();

        let output = filter.next().await.unwrap().unwrap().into_chunk().unwrap();
        assert_eq!(
            output,
            StreamChunk::from_pretty(
                " I  I
                + 10 14
                - 20 5
                - 10 7
                + 20 16
                + 20 18
                - 10 .
                - 20 .
                - 30 .
                ",
            )
        );
    }
}