risingwave_stream/executor/aggregate/agg_state_cache.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Object-safe version of [`StateCache`] for aggregation.
16
17use risingwave_common::array::StreamChunk;
18use risingwave_common::row::Row;
19use risingwave_common::types::{DataType, Datum, ToOwnedDatum};
20use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
21use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
22use risingwave_common::util::row_serde::OrderedRowSerde;
23use risingwave_common_estimate_size::EstimateSize;
24use smallvec::SmallVec;
25
26use crate::common::state_cache::{StateCache, StateCacheFiller};
27
28/// Cache key type.
29type CacheKey = MemcmpEncoded;
30
31#[derive(Debug)]
32pub struct CacheValue(SmallVec<[Datum; 2]>);
33
34/// Trait that defines the interface of state table cache for stateful streaming agg.
35pub trait AggStateCache: EstimateSize {
36    /// Check if the cache is synced with state table.
37    fn is_synced(&self) -> bool;
38
39    /// Apply a batch of updates to the cache.
40    fn apply_batch(
41        &mut self,
42        chunk: &StreamChunk,
43        cache_key_serializer: &OrderedRowSerde,
44        arg_col_indices: &[usize],
45        order_col_indices: &[usize],
46    );
47
48    /// Begin syncing the cache with state table.
49    fn begin_syncing(&mut self) -> Box<dyn AggStateCacheFiller + Send + Sync + '_>;
50
51    /// Output batches from the cache.
52    fn output_batches(&self, chunk_size: usize) -> Box<dyn Iterator<Item = StreamChunk> + '_>;
53
54    /// Output the first value.
55    fn output_first(&self) -> Datum;
56}
57
58/// Trait that defines agg state cache syncing interface.
59pub trait AggStateCacheFiller {
60    /// Get the capacity of the cache to be filled. `None` means unlimited.
61    fn capacity(&self) -> Option<usize>;
62
63    /// Insert an entry to the cache without checking row count, capacity, key order, etc.
64    /// Just insert into the inner cache structure, e.g. `BTreeMap`.
65    fn append(&mut self, key: CacheKey, value: CacheValue);
66
67    /// Mark the cache as synced.
68    fn finish(self: Box<Self>);
69}
70
71/// A wrapper over generic [`StateCache`] that implements [`AggStateCache`].
72#[derive(EstimateSize)]
73pub struct GenericAggStateCache<C>
74where
75    C: StateCache<Key = CacheKey, Value = CacheValue>,
76{
77    state_cache: C,
78    input_types: Vec<DataType>,
79}
80
81impl<C> GenericAggStateCache<C>
82where
83    C: StateCache<Key = CacheKey, Value = CacheValue>,
84{
85    pub fn new(state_cache: C, input_types: &[DataType]) -> Self {
86        Self {
87            state_cache,
88            input_types: input_types.to_vec(),
89        }
90    }
91}
92
93impl<C> AggStateCache for GenericAggStateCache<C>
94where
95    C: StateCache<Key = CacheKey, Value = CacheValue>,
96    for<'a> C::Filler<'a>: Send + Sync,
97{
98    fn is_synced(&self) -> bool {
99        self.state_cache.is_synced()
100    }
101
102    fn apply_batch(
103        &mut self,
104        chunk: &StreamChunk,
105        cache_key_serializer: &OrderedRowSerde,
106        arg_col_indices: &[usize],
107        order_col_indices: &[usize],
108    ) {
109        let rows = chunk.rows().map(|(op, row)| {
110            let key = {
111                let mut key = Vec::new();
112                cache_key_serializer.serialize_datums(
113                    order_col_indices
114                        .iter()
115                        .map(|col_idx| row.datum_at(*col_idx)),
116                    &mut key,
117                );
118                key.into()
119            };
120            let value = CacheValue(
121                arg_col_indices
122                    .iter()
123                    .map(|col_idx| row.datum_at(*col_idx).to_owned_datum())
124                    .collect(),
125            );
126            (op, key, value)
127        });
128        self.state_cache.apply_batch(rows);
129    }
130
131    fn begin_syncing(&mut self) -> Box<dyn AggStateCacheFiller + Send + Sync + '_> {
132        Box::new(GenericAggStateCacheFiller::<'_, C> {
133            cache_filler: self.state_cache.begin_syncing(),
134        })
135    }
136
137    fn output_batches(&self, chunk_size: usize) -> Box<dyn Iterator<Item = StreamChunk> + '_> {
138        let mut values = self.state_cache.values();
139        Box::new(std::iter::from_fn(move || {
140            // build data chunk from rows
141            let mut builder = DataChunkBuilder::new(self.input_types.clone(), chunk_size);
142            for row in &mut values {
143                if let Some(chunk) = builder.append_one_row(row.0.as_slice()) {
144                    return Some(chunk.into());
145                }
146            }
147            builder.consume_all().map(|chunk| chunk.into())
148        }))
149    }
150
151    fn output_first(&self) -> Datum {
152        let value = self.state_cache.values().next()?;
153        value.0[0].clone()
154    }
155}
156
157pub struct GenericAggStateCacheFiller<'filler, C>
158where
159    C: StateCache<Key = CacheKey, Value = CacheValue> + 'filler,
160{
161    cache_filler: C::Filler<'filler>,
162}
163
164impl<C> AggStateCacheFiller for GenericAggStateCacheFiller<'_, C>
165where
166    C: StateCache<Key = CacheKey, Value = CacheValue>,
167{
168    fn capacity(&self) -> Option<usize> {
169        self.cache_filler.capacity()
170    }
171
172    fn append(&mut self, key: CacheKey, value: CacheValue) {
173        self.cache_filler.insert_unchecked(key, value);
174    }
175
176    fn finish(self: Box<Self>) {
177        self.cache_filler.finish()
178    }
179}
180
181impl FromIterator<Datum> for CacheValue {
182    fn from_iter<T: IntoIterator<Item = Datum>>(iter: T) -> Self {
183        Self(iter.into_iter().collect())
184    }
185}
186
187impl EstimateSize for CacheValue {
188    fn estimated_heap_size(&self) -> usize {
189        let data_heap_size: usize = self.0.iter().map(|datum| datum.estimated_heap_size()).sum();
190        self.0.len() * std::mem::size_of::<Datum>() + data_heap_size
191    }
192}