risingwave_stream/executor/aggregate/
distinct.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::marker::PhantomData;
17use std::sync::atomic::AtomicU64;
18
19use itertools::Itertools;
20use risingwave_common::array::{ArrayRef, Op};
21use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
22use risingwave_common::row::{self, CompactedRow, RowExt};
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_expr::aggregate::AggCall;
25
26use super::agg_group::GroupKey;
27use crate::cache::ManagedLruCache;
28use crate::common::metrics::MetricsInfo;
29use crate::executor::monitor::AggDistinctDedupMetrics;
30use crate::executor::prelude::*;
31
/// Cache used for deduplication of one distinct column.
///
/// The key is the compacted form of the group key's cache key followed by the distinct datum;
/// the value holds one `i64` occurrence count per agg call that is distinct on the column.
type DedupCache = ManagedLruCache<CompactedRow, Box<[i64]>>;
33
/// Deduplicater for one distinct column.
///
/// It serves all agg calls that are distinct on that column at once: each cached entry stores
/// one occurrence count per such agg call.
struct ColumnDeduplicater<S: StateStore> {
    /// Occurrence counts per distinct key, loaded from the dedup table on cache miss.
    cache: DedupCache,
    /// Metrics for cache lookups, cache misses and the number of cached entries.
    metrics: AggDistinctDedupMetrics,
    /// Marks the state store type `S` as used; no value of type `S` is stored here.
    _phantom: PhantomData<S>,
}
40
41impl<S: StateStore> ColumnDeduplicater<S> {
42    fn new(
43        distinct_col: usize,
44        dedup_table: &StateTable<S>,
45        watermark_sequence: Arc<AtomicU64>,
46        actor_ctx: &ActorContext,
47    ) -> Self {
48        let metrics_info = MetricsInfo::new(
49            actor_ctx.streaming_metrics.clone(),
50            dedup_table.table_id(),
51            actor_ctx.id,
52            format!("distinct dedup column {}", distinct_col),
53        );
54        let metrics = actor_ctx.streaming_metrics.new_agg_distinct_dedup_metrics(
55            dedup_table.table_id(),
56            actor_ctx.id,
57            actor_ctx.fragment_id,
58        );
59
60        Self {
61            cache: ManagedLruCache::unbounded(watermark_sequence, metrics_info),
62            metrics,
63            _phantom: PhantomData,
64        }
65    }
66
67    async fn dedup(
68        &mut self,
69        ops: &[Op],
70        column: &ArrayRef,
71        mut visibilities: Vec<&mut Bitmap>,
72        dedup_table: &mut StateTable<S>,
73        group_key: Option<&GroupKey>,
74    ) -> StreamExecutorResult<()> {
75        let n_calls = visibilities.len();
76
77        let mut prev_counts_map = HashMap::new(); // also serves as changeset
78
79        // inverted masks for visibilities, 1 means hidden, 0 means visible
80        let mut vis_masks_inv = (0..visibilities.len())
81            .map(|_| BitmapBuilder::zeroed(column.len()))
82            .collect_vec();
83        for (datum_idx, (op, datum)) in ops.iter().zip_eq_fast(column.iter()).enumerate() {
84            // skip if this item is hidden to all agg calls (this is likely to happen)
85            if !visibilities.iter().any(|vis| vis.is_set(datum_idx)) {
86                continue;
87            }
88
89            // get counts of the distinct key of all agg calls that distinct on this column
90            let row_prefix = group_key.map(GroupKey::table_row).chain(row::once(datum));
91            let table_pk = group_key.map(GroupKey::table_pk).chain(row::once(datum));
92            let cache_key =
93                CompactedRow::from(group_key.map(GroupKey::cache_key).chain(row::once(datum)));
94
95            self.metrics.agg_distinct_total_cache_count.inc();
96            // TODO(yuhao): avoid this `contains`.
97            // https://github.com/risingwavelabs/risingwave/issues/9233
98            let mut counts = if self.cache.contains(&cache_key) {
99                self.cache.get_mut(&cache_key).unwrap()
100            } else {
101                self.metrics.agg_distinct_cache_miss_count.inc();
102                // load from table into the cache
103                let counts = if let Some(counts_row) =
104                    dedup_table.get_row(&table_pk).await? as Option<OwnedRow>
105                {
106                    counts_row
107                        .iter()
108                        .map(|v| v.map_or(0, ScalarRefImpl::into_int64))
109                        .collect()
110                } else {
111                    // ensure there is a row in the dedup table for this distinct key
112                    dedup_table.insert(
113                        (&row_prefix).chain(row::repeat_n(Some(ScalarImpl::from(0i64)), n_calls)),
114                    );
115                    vec![0; n_calls].into_boxed_slice()
116                };
117                self.cache.put(cache_key.clone(), counts); // TODO(rc): can we avoid this clone?
118
119                self.cache.get_mut(&cache_key).unwrap()
120            };
121            debug_assert_eq!(counts.len(), visibilities.len());
122
123            // snapshot the counts as prev counts when first time seeing this distinct key
124            prev_counts_map
125                .entry(datum)
126                .or_insert_with(|| counts.to_owned());
127
128            match op {
129                Op::Insert | Op::UpdateInsert => {
130                    // iterate over vis of each distinct agg call, count up for visible datum
131                    for (i, vis) in visibilities.iter().enumerate() {
132                        if vis.is_set(datum_idx) {
133                            counts[i] += 1;
134                            if counts[i] > 1 {
135                                // duplicate, hide this one
136                                vis_masks_inv[i].set(datum_idx, true);
137                            }
138                        }
139                    }
140                }
141                Op::Delete | Op::UpdateDelete => {
142                    // iterate over vis of each distinct agg call, count down for visible datum
143                    for (i, vis) in visibilities.iter().enumerate() {
144                        if vis.is_set(datum_idx) {
145                            counts[i] -= 1;
146                            debug_assert!(counts[i] >= 0);
147                            if counts[i] > 0 {
148                                // still exists at least one duplicate, hide this one
149                                vis_masks_inv[i].set(datum_idx, true);
150                            }
151                        }
152                    }
153                }
154            }
155        }
156
157        // flush changes to dedup table
158        prev_counts_map
159            .into_iter()
160            .for_each(|(datum, prev_counts)| {
161                let row_prefix = group_key.map(GroupKey::table_row).chain(row::once(datum));
162                let cache_key =
163                    CompactedRow::from(group_key.map(GroupKey::cache_key).chain(row::once(datum)));
164                let new_counts = OwnedRow::new(
165                    self.cache
166                        .get(&cache_key)
167                        .expect("distinct key in `prev_counts_map` must also exist in `self.cache`")
168                        .iter()
169                        .map(|&v| Some(v.into()))
170                        .collect(),
171                );
172                let old_counts =
173                    OwnedRow::new(prev_counts.iter().map(|&v| Some(v.into())).collect());
174                dedup_table.update(row_prefix.chain(old_counts), row_prefix.chain(new_counts));
175            });
176
177        for (vis, vis_mask_inv) in visibilities.iter_mut().zip_eq(vis_masks_inv.into_iter()) {
178            // update visibility
179            **vis &= !vis_mask_inv.finish();
180        }
181
182        // if we determine to flush to the table when processing every chunk instead of barrier
183        // coming, we can evict all including current epoch data.
184        self.cache.evict();
185
186        Ok(())
187    }
188
189    /// Flush the deduplication table.
190    fn flush(&mut self, _dedup_table: &StateTable<S>) {
191        // TODO(rc): now we flush the table in `dedup` method.
192        // WARN: if you want to change to batching the write to table. please remember to change
193        // `self.cache.evict()` too.
194        self.cache.evict();
195
196        self.metrics
197            .agg_distinct_cached_entry_count
198            .set(self.cache.len() as i64);
199    }
200}
201
/// Returns mutable references to the elements of `slice` at the given `indices`, in the
/// order the indices are listed.
///
/// # Panics
///
/// Panics if any index is out of bounds for `slice`. In debug builds it also panics if
/// `indices` contains duplicates.
///
/// # Safety
///
/// There must not be duplicate items in `indices`, otherwise the returned vector would hold
/// aliasing mutable references to the same element.
unsafe fn get_many_mut_from_slice<'a, T>(slice: &'a mut [T], indices: &[usize]) -> Vec<&'a mut T> {
    // Best-effort check of the caller's contract; quadratic, hence debug builds only.
    debug_assert!(
        indices
            .iter()
            .enumerate()
            .all(|(pos, idx)| !indices[..pos].contains(idx)),
        "`indices` must not contain duplicates"
    );

    let len = slice.len();
    let ptr = slice.as_mut_ptr();
    indices
        .iter()
        .map(|&idx| {
            // Offsetting the pointer past the slice would be undefined behaviour, so turn an
            // out-of-bounds index into a panic instead.
            assert!(
                idx < len,
                "index {} out of bounds for slice of length {}",
                idx,
                len
            );
            // SAFETY: `idx < len` was just checked, so the pointer stays inside `slice`; the
            // caller guarantees the indices are unique, so no two references alias.
            unsafe { &mut *ptr.add(idx) }
        })
        .collect()
}
215
/// Deduplicates input rows for all `DISTINCT` agg calls of an aggregation.
///
/// Agg calls that are distinct on the same column share one `ColumnDeduplicater`.
pub struct DistinctDeduplicater<S: StateStore> {
    /// Key: distinct column index;
    /// Value: (agg call indices that distinct on the column, deduplicater for the column).
    deduplicaters: HashMap<usize, (Box<[usize]>, ColumnDeduplicater<S>)>,
}
221
222impl<S: StateStore> DistinctDeduplicater<S> {
223    pub fn new(
224        agg_calls: &[AggCall],
225        watermark_epoch: Arc<AtomicU64>,
226        distinct_dedup_tables: &HashMap<usize, StateTable<S>>,
227        actor_ctx: &ActorContext,
228    ) -> Self {
229        let deduplicaters: HashMap<_, _> = agg_calls
230            .iter()
231            .enumerate()
232            .filter(|(_, call)| call.distinct) // only distinct agg calls need dedup table
233            .into_group_map_by(|(_, call)| call.args.val_indices()[0])
234            .into_iter()
235            .map(|(distinct_col, indices_and_calls)| {
236                let dedup_table = distinct_dedup_tables.get(&distinct_col).unwrap();
237                let call_indices: Box<[_]> = indices_and_calls.into_iter().map(|v| v.0).collect();
238                let deduplicater = ColumnDeduplicater::new(
239                    distinct_col,
240                    dedup_table,
241                    watermark_epoch.clone(),
242                    actor_ctx,
243                );
244                (distinct_col, (call_indices, deduplicater))
245            })
246            .collect();
247        Self { deduplicaters }
248    }
249
250    pub fn dedup_caches_mut(&mut self) -> impl Iterator<Item = &mut DedupCache> {
251        self.deduplicaters
252            .values_mut()
253            .map(|(_, deduplicater)| &mut deduplicater.cache)
254    }
255
256    /// Deduplicate the chunk for each agg call, by returning new visibilities
257    /// that hide duplicate rows.
258    pub async fn dedup_chunk(
259        &mut self,
260        ops: &[Op],
261        columns: &[ArrayRef],
262        mut visibilities: Vec<Bitmap>,
263        dedup_tables: &mut HashMap<usize, StateTable<S>>,
264        group_key: Option<&GroupKey>,
265    ) -> StreamExecutorResult<Vec<Bitmap>> {
266        for (distinct_col, (call_indices, deduplicater)) in &mut self.deduplicaters {
267            let column = &columns[*distinct_col];
268            let dedup_table = dedup_tables.get_mut(distinct_col).unwrap();
269            // Select visibilities (as mutable references) of distinct agg calls that distinct on
270            // `distinct_col` so that `Deduplicater` doesn't need to care about index mapping.
271            // SAFETY: all items in `agg_call_indices` are unique by nature, see `new`.
272            let visibilities = unsafe { get_many_mut_from_slice(&mut visibilities, call_indices) };
273            deduplicater
274                .dedup(ops, column, visibilities, dedup_table, group_key)
275                .await?;
276        }
277        Ok(visibilities)
278    }
279
280    /// Flush dedup state caches to dedup tables.
281    pub fn flush(
282        &mut self,
283        dedup_tables: &mut HashMap<usize, StateTable<S>>,
284    ) -> StreamExecutorResult<()> {
285        for (distinct_col, (_, deduplicater)) in &mut self.deduplicaters {
286            let dedup_table = dedup_tables.get_mut(distinct_col).unwrap();
287            deduplicater.flush(dedup_table);
288        }
289        Ok(())
290    }
291}
292
293#[cfg(test)]
294mod tests {
295    use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId};
296    use risingwave_common::test_prelude::StreamChunkTestExt;
297    use risingwave_common::util::epoch::{EpochPair, test_epoch};
298    use risingwave_common::util::sort_util::OrderType;
299    use risingwave_storage::memory::MemoryStateStore;
300
301    use super::*;
302    use crate::common::table::test_utils::gen_pbtable_with_value_indices;
303
304    async fn infer_dedup_tables<S: StateStore>(
305        agg_calls: &[AggCall],
306        group_key_types: &[DataType],
307        store: S,
308    ) -> HashMap<usize, StateTable<S>> {
309        // corresponding to `Agg::infer_distinct_dedup_table` in frontend
310        let mut dedup_tables = HashMap::new();
311
312        for (distinct_col, indices_and_calls) in agg_calls
313            .iter()
314            .enumerate()
315            .filter(|(_, call)| call.distinct) // only distinct agg calls need dedup table
316            .into_group_map_by(|(_, call)| call.args.val_indices()[0])
317        {
318            let mut columns = vec![];
319            let mut order_types = vec![];
320
321            let mut next_column_id = 0;
322            let mut add_column_desc = |data_type: DataType| {
323                columns.push(ColumnDesc::unnamed(
324                    ColumnId::new(next_column_id),
325                    data_type,
326                ));
327                next_column_id += 1;
328            };
329
330            // group key columns
331            for data_type in group_key_types {
332                add_column_desc(data_type.clone());
333                order_types.push(OrderType::ascending());
334            }
335
336            // distinct key column
337            add_column_desc(indices_and_calls[0].1.args.arg_types()[0].clone());
338            order_types.push(OrderType::ascending());
339
340            // count columns
341            for (_, _) in indices_and_calls {
342                add_column_desc(DataType::Int64);
343            }
344
345            let pk_indices = (0..(group_key_types.len() + 1)).collect::<Vec<_>>();
346            let value_indices = ((group_key_types.len() + 1)..columns.len()).collect::<Vec<_>>();
347            let table_pb = gen_pbtable_with_value_indices(
348                TableId::new(2333 + distinct_col as u32),
349                columns,
350                order_types,
351                pk_indices,
352                0,
353                value_indices,
354            );
355            let table = StateTable::from_table_catalog(&table_pb, store.clone(), None).await;
356            dedup_tables.insert(distinct_col, table);
357        }
358
359        dedup_tables
360    }
361
    /// Deduplication without a group key, across three chunks and one simulated recovery.
    ///
    /// Chunks 1 and 2 insert rows, chunk 3 deletes rows after the deduplicater has been
    /// dropped and re-created over the same dedup tables.
    #[tokio::test]
    async fn test_distinct_deduplicater() {
        // Schema:
        // a: int, b int, c int
        // Agg calls:
        // count(a), count(distinct a), sum(distinct a), count(distinct b)
        // Group keys:
        // empty

        let agg_calls = [
            AggCall::from_pretty("(count:int8 $0:int8)"), // count(a)
            AggCall::from_pretty("(count:int8 $0:int8 distinct)"), // count(distinct a)
            AggCall::from_pretty("(  sum:int8 $0:int8 distinct)"), // sum(distinct a)
            AggCall::from_pretty("(count:int8 $1:int8 distinct)"), // count(distinct b)
        ];

        let store = MemoryStateStore::new();
        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
        let mut dedup_tables = infer_dedup_tables(&agg_calls, &[], store).await;
        for table in dedup_tables.values_mut() {
            table.init_epoch(epoch).await.unwrap()
        }

        let mut deduplicater = DistinctDeduplicater::new(
            &agg_calls,
            Arc::new(AtomicU64::new(0)),
            &dedup_tables,
            &ActorContext::for_test(0),
        );

        // --- chunk 1 ---
        // Two inserts with the same `a` but different `b`.

        let chunk = StreamChunk::from_pretty(
            " I   I     I
            + 1  10   100
            + 1  11   101",
        );
        let (ops, columns, visibility) = chunk.into_inner();

        // every agg call starts from the chunk's own visibility
        let visibilities = std::iter::repeat_n(visibility, agg_calls.len()).collect_vec();
        let visibilities = deduplicater
            .dedup_chunk(&ops, &columns, visibilities, &mut dedup_tables, None)
            .await
            .unwrap();
        assert_eq!(
            visibilities[0].iter().collect_vec(),
            vec![true, true] // same as original chunk
        );
        assert_eq!(
            visibilities[1].iter().collect_vec(),
            vec![true, false] // distinct on a
        );
        assert_eq!(
            visibilities[2].iter().collect_vec(),
            vec![true, false] // distinct on a, same as above
        );
        assert_eq!(
            visibilities[3].iter().collect_vec(),
            vec![true, true] // distinct on b
        );

        deduplicater.flush(&mut dedup_tables).unwrap();

        epoch.inc_for_test();
        for table in dedup_tables.values_mut() {
            table.commit_for_test(epoch).await.unwrap();
        }

        // --- chunk 2 ---
        // The first row repeats values already counted in chunk 1; the second row is hidden
        // in the input chunk, so the third row is the first visible occurrence of its values.

        let chunk = StreamChunk::from_pretty(
            " I   I     I
            + 1  11  -102
            + 2  12   103  D
            + 2  12  -104",
        );
        let (ops, columns, visibility) = chunk.into_inner();

        let visibilities = std::iter::repeat_n(visibility, agg_calls.len()).collect_vec();
        let visibilities = deduplicater
            .dedup_chunk(&ops, &columns, visibilities, &mut dedup_tables, None)
            .await
            .unwrap();
        assert_eq!(
            visibilities[0].iter().collect_vec(),
            vec![true, false, true] // same as original chunk
        );
        assert_eq!(
            visibilities[1].iter().collect_vec(),
            vec![false, false, true] // distinct on a
        );
        assert_eq!(
            visibilities[2].iter().collect_vec(),
            vec![false, false, true] // distinct on a, same as above
        );
        assert_eq!(
            visibilities[3].iter().collect_vec(),
            vec![false, false, true] // distinct on b
        );

        deduplicater.flush(&mut dedup_tables).unwrap();

        epoch.inc_for_test();
        for table in dedup_tables.values_mut() {
            table.commit_for_test(epoch).await.unwrap();
        }

        drop(deduplicater);

        // test recovery
        // The new deduplicater starts with empty caches, so the counts used below have to
        // come from the dedup tables written above.
        let mut deduplicater = DistinctDeduplicater::new(
            &agg_calls,
            Arc::new(AtomicU64::new(0)),
            &dedup_tables,
            &ActorContext::for_test(0),
        );

        // --- chunk 3 ---
        // Deletes: `a = 1` was inserted three times and is deleted twice here, so it never
        // disappears; `b = 11` was inserted twice, so the second delete removes it.

        let chunk = StreamChunk::from_pretty(
            " I   I     I
            - 1  10   100  D
            - 1  11   101
            - 1  11  -102",
        );
        let (ops, columns, visibility) = chunk.into_inner();

        let visibilities = std::iter::repeat_n(visibility, agg_calls.len()).collect_vec();
        let visibilities = deduplicater
            .dedup_chunk(&ops, &columns, visibilities, &mut dedup_tables, None)
            .await
            .unwrap();
        assert_eq!(
            visibilities[0].iter().collect_vec(),
            vec![false, true, true] // same as original chunk
        );
        assert_eq!(
            visibilities[1].iter().collect_vec(),
            // distinct on a
            vec![
                false, // hidden in original chunk
                false, // not the last one
                false, // not the last one
            ]
        );
        assert_eq!(
            visibilities[2].iter().collect_vec(),
            // distinct on a, same as above
            vec![
                false, // hidden in original chunk
                false, // not the last one
                false, // not the last one
            ]
        );
        assert_eq!(
            visibilities[3].iter().collect_vec(),
            // distinct on b
            vec![
                false, // hidden in original chunk
                false, // not the last one
                true,  // is the last one
            ]
        );

        deduplicater.flush(&mut dedup_tables).unwrap();

        epoch.inc_for_test();
        for table in dedup_tables.values_mut() {
            table.commit_for_test(epoch).await.unwrap();
        }
    }
533
    /// Deduplication with a group key: one chunk of inserts followed by one chunk of deletes,
    /// both processed for the single group `c = 100`.
    #[tokio::test]
    async fn test_distinct_deduplicater_with_group() {
        // Schema:
        // a: int, b int, c int
        // Agg calls:
        // count(a), count(distinct a), count(distinct b)
        // Group keys:
        // c

        let agg_calls = [
            AggCall::from_pretty("(count:int8 $0:int8)"), // count(a)
            AggCall::from_pretty("(count:int8 $0:int8 distinct)"), // count(distinct a)
            AggCall::from_pretty("(count:int8 $1:int8 distinct)"), // count(distinct b)
        ];

        // the dedup tables get one leading Int64 column for the group key `c`
        let group_key_types = [DataType::Int64];
        let group_key = GroupKey::new(OwnedRow::new(vec![Some(100.into())]), None);

        let store = MemoryStateStore::new();
        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
        let mut dedup_tables = infer_dedup_tables(&agg_calls, &group_key_types, store).await;
        for table in dedup_tables.values_mut() {
            table.init_epoch(epoch).await.unwrap()
        }

        let mut deduplicater = DistinctDeduplicater::new(
            &agg_calls,
            Arc::new(AtomicU64::new(0)),
            &dedup_tables,
            &ActorContext::for_test(0),
        );

        // Inserts; the row with `c = 200` is hidden in the input chunk.
        let chunk = StreamChunk::from_pretty(
            " I   I     I
            + 1  10   100
            + 1  11   100
            + 1  11   100
            + 2  12   200  D
            + 2  12   100",
        );
        let (ops, columns, visibility) = chunk.into_inner();

        // every agg call starts from the chunk's own visibility
        let visibilities = std::iter::repeat_n(visibility, agg_calls.len()).collect_vec();
        let visibilities = deduplicater
            .dedup_chunk(
                &ops,
                &columns,
                visibilities,
                &mut dedup_tables,
                Some(&group_key),
            )
            .await
            .unwrap();
        assert_eq!(
            visibilities[0].iter().collect_vec(),
            vec![true, true, true, false, true] // same as original chunk
        );
        assert_eq!(
            visibilities[1].iter().collect_vec(),
            vec![true, false, false, false, true] // distinct on a
        );
        assert_eq!(
            visibilities[2].iter().collect_vec(),
            vec![true, true, false, false, true] // distinct on b
        );

        deduplicater.flush(&mut dedup_tables).unwrap();

        epoch.inc_for_test();
        for table in dedup_tables.values_mut() {
            table.commit_for_test(epoch).await.unwrap();
        }

        // Deletes: `a = 1` was inserted three times and is deleted twice here, so it never
        // disappears; `b = 11` was inserted twice, so the second delete removes it.
        let chunk = StreamChunk::from_pretty(
            " I   I     I
            - 1  10   100  D
            - 1  11   100
            - 1  11   100",
        );
        let (ops, columns, visibility) = chunk.into_inner();

        let visibilities = std::iter::repeat_n(visibility, agg_calls.len()).collect_vec();
        let visibilities = deduplicater
            .dedup_chunk(
                &ops,
                &columns,
                visibilities,
                &mut dedup_tables,
                Some(&group_key),
            )
            .await
            .unwrap();
        assert_eq!(
            visibilities[0].iter().collect_vec(),
            vec![false, true, true] // same as original chunk
        );
        assert_eq!(
            visibilities[1].iter().collect_vec(),
            // distinct on a
            vec![
                false, // hidden in original chunk
                false, // not the last one
                false, // not the last one
            ]
        );
        assert_eq!(
            visibilities[2].iter().collect_vec(),
            // distinct on b
            vec![
                false, // hidden in original chunk
                false, // not the last one
                true,  // is the last one
            ]
        );

        deduplicater.flush(&mut dedup_tables).unwrap();

        epoch.inc_for_test();
        for table in dedup_tables.values_mut() {
            table.commit_for_test(epoch).await.unwrap();
        }
    }
656}