// risingwave_stream/executor/aggregate/agg_state_cache.rs
use risingwave_common::array::StreamChunk;
18use risingwave_common::row::Row;
19use risingwave_common::types::{DataType, Datum, ToOwnedDatum};
20use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
21use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
22use risingwave_common::util::row_serde::OrderedRowSerde;
23use risingwave_common_estimate_size::EstimateSize;
24use smallvec::SmallVec;
25
26use crate::common::state_cache::{StateCache, StateCacheFiller};
27
28type CacheKey = MemcmpEncoded;
30
31#[derive(Debug)]
32pub struct CacheValue(SmallVec<[Datum; 2]>);
33
34pub trait AggStateCache: EstimateSize {
36 fn is_synced(&self) -> bool;
38
39 fn apply_batch(
41 &mut self,
42 chunk: &StreamChunk,
43 cache_key_serializer: &OrderedRowSerde,
44 arg_col_indices: &[usize],
45 order_col_indices: &[usize],
46 );
47
48 fn begin_syncing(&mut self) -> Box<dyn AggStateCacheFiller + Send + Sync + '_>;
50
51 fn output_batches(&self, chunk_size: usize) -> Box<dyn Iterator<Item = StreamChunk> + '_>;
53
54 fn output_first(&self) -> Datum;
56}
57
58pub trait AggStateCacheFiller {
60 fn capacity(&self) -> Option<usize>;
62
63 fn append(&mut self, key: CacheKey, value: CacheValue);
66
67 fn finish(self: Box<Self>);
69}
70
71#[derive(EstimateSize)]
73pub struct GenericAggStateCache<C>
74where
75 C: StateCache<Key = CacheKey, Value = CacheValue>,
76{
77 state_cache: C,
78 input_types: Vec<DataType>,
79}
80
81impl<C> GenericAggStateCache<C>
82where
83 C: StateCache<Key = CacheKey, Value = CacheValue>,
84{
85 pub fn new(state_cache: C, input_types: &[DataType]) -> Self {
86 Self {
87 state_cache,
88 input_types: input_types.to_vec(),
89 }
90 }
91}
92
93impl<C> AggStateCache for GenericAggStateCache<C>
94where
95 C: StateCache<Key = CacheKey, Value = CacheValue>,
96 for<'a> C::Filler<'a>: Send + Sync,
97{
98 fn is_synced(&self) -> bool {
99 self.state_cache.is_synced()
100 }
101
102 fn apply_batch(
103 &mut self,
104 chunk: &StreamChunk,
105 cache_key_serializer: &OrderedRowSerde,
106 arg_col_indices: &[usize],
107 order_col_indices: &[usize],
108 ) {
109 let rows = chunk.rows().map(|(op, row)| {
110 let key = {
111 let mut key = Vec::new();
112 cache_key_serializer.serialize_datums(
113 order_col_indices
114 .iter()
115 .map(|col_idx| row.datum_at(*col_idx)),
116 &mut key,
117 );
118 key.into()
119 };
120 let value = CacheValue(
121 arg_col_indices
122 .iter()
123 .map(|col_idx| row.datum_at(*col_idx).to_owned_datum())
124 .collect(),
125 );
126 (op, key, value)
127 });
128 self.state_cache.apply_batch(rows);
129 }
130
131 fn begin_syncing(&mut self) -> Box<dyn AggStateCacheFiller + Send + Sync + '_> {
132 Box::new(GenericAggStateCacheFiller::<'_, C> {
133 cache_filler: self.state_cache.begin_syncing(),
134 })
135 }
136
137 fn output_batches(&self, chunk_size: usize) -> Box<dyn Iterator<Item = StreamChunk> + '_> {
138 let mut values = self.state_cache.values();
139 Box::new(std::iter::from_fn(move || {
140 let mut builder = DataChunkBuilder::new(self.input_types.clone(), chunk_size);
142 for row in &mut values {
143 if let Some(chunk) = builder.append_one_row(row.0.as_slice()) {
144 return Some(chunk.into());
145 }
146 }
147 builder.consume_all().map(|chunk| chunk.into())
148 }))
149 }
150
151 fn output_first(&self) -> Datum {
152 let value = self.state_cache.values().next()?;
153 value.0[0].clone()
154 }
155}
156
157pub struct GenericAggStateCacheFiller<'filler, C>
158where
159 C: StateCache<Key = CacheKey, Value = CacheValue> + 'filler,
160{
161 cache_filler: C::Filler<'filler>,
162}
163
164impl<C> AggStateCacheFiller for GenericAggStateCacheFiller<'_, C>
165where
166 C: StateCache<Key = CacheKey, Value = CacheValue>,
167{
168 fn capacity(&self) -> Option<usize> {
169 self.cache_filler.capacity()
170 }
171
172 fn append(&mut self, key: CacheKey, value: CacheValue) {
173 self.cache_filler.insert_unchecked(key, value);
174 }
175
176 fn finish(self: Box<Self>) {
177 self.cache_filler.finish()
178 }
179}
180
181impl FromIterator<Datum> for CacheValue {
182 fn from_iter<T: IntoIterator<Item = Datum>>(iter: T) -> Self {
183 Self(iter.into_iter().collect())
184 }
185}
186
187impl EstimateSize for CacheValue {
188 fn estimated_heap_size(&self) -> usize {
189 let data_heap_size: usize = self.0.iter().map(|datum| datum.estimated_heap_size()).sum();
190 self.0.len() * std::mem::size_of::<Datum>() + data_heap_size
191 }
192}