risingwave_stream/executor/approx_percentile/global_state.rs

use std::collections::{BTreeMap, Bound};
use std::mem;

use risingwave_common::array::Op;
use risingwave_common::bail;
use risingwave_common::row::Row;
use risingwave_common::types::{Datum, ToOwnedDatum};
use risingwave_common::util::epoch::EpochPair;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;

use crate::executor::StreamExecutorResult;
use crate::executor::prelude::*;
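/// Global state for the approximate-percentile aggregation: merges bucket
/// count deltas, tracks the total row count, and caches the last emitted
/// output. The two state tables persist the buckets and the row count for
/// fault tolerance.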
pub struct GlobalApproxPercentileState<S: StateStore> {
    quantile: f64,
    base: f64,
    row_count: i64,
    bucket_state_table: StateTable<S>,
    count_state_table: StateTable<S>,
    cache: BucketTableCache,
    last_output: Option<Datum>,
    output_changed: bool,
}

impl<S: StateStore> GlobalApproxPercentileState<S> {
    pub fn new(
        quantile: f64,
        base: f64,
        bucket_state_table: StateTable<S>,
        count_state_table: StateTable<S>,
    ) -> Self {
        Self {
            quantile,
            base,
            row_count: 0,
            bucket_state_table,
            count_state_table,
            cache: BucketTableCache::new(),
            last_output: None,
            output_changed: false,
        }
    }

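    /// Restores in-memory state on (re)start: initializes both table epochs,
    /// recovers the row count, refills the bucket cache, and recomputes the
    /// last output if any state was persisted.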
    pub async fn init(&mut self, init_epoch: EpochPair) -> StreamExecutorResult<()> {
        self.count_state_table.init_epoch(init_epoch).await?;
        self.bucket_state_table.init_epoch(init_epoch).await?;

        let row_count_state = self.get_row_count_state().await?;
        let row_count = Self::decode_row_count(&row_count_state)?;
        self.row_count = row_count;
        tracing::debug!(?row_count, "recovered row_count");

        self.refill_cache().await?;

        let last_output = if row_count_state.is_none() {
            None
        } else {
            Some(self.cache.get_output(row_count, self.quantile, self.base))
        };
        tracing::debug!(?last_output, "recovered last_output");
        self.last_output = last_output;
        Ok(())
    }

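    /// Rebuilds the in-memory cache from a full scan of the bucket state
    /// table, routing each row into the negative, zero, or positive section
    /// by its sign column.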
    async fn refill_cache(&mut self) -> StreamExecutorResult<()> {
        let bounds: (Bound<OwnedRow>, Bound<OwnedRow>) = (Bound::Unbounded, Bound::Unbounded);
        #[for_await]
        for keyed_row in self
            .bucket_state_table
            .iter_with_prefix(&[Datum::None; 0], &bounds, PrefetchOptions::default())
            .await?
        {
            let row = keyed_row?.into_owned_row();
            let sign = row.datum_at(0).unwrap().into_int16();
            let bucket_id = row.datum_at(1).unwrap().into_int32();
            let count = row.datum_at(2).unwrap().into_int64();
            match sign {
                -1 => {
                    self.cache.neg_buckets.insert(bucket_id, count);
                }
                0 => {
                    self.cache.zeros = count;
                }
                1 => {
                    self.cache.pos_buckets.insert(bucket_id, count);
                }
                _ => {
                    bail!("Invalid sign: {}", sign);
                }
            }
        }
        Ok(())
    }

    async fn get_row_count_state(&self) -> StreamExecutorResult<Option<OwnedRow>> {
        self.count_state_table.get_row(&[Datum::None; 0]).await
    }

    fn decode_row_count(row_count_state: &Option<OwnedRow>) -> StreamExecutorResult<i64> {
        if let Some(row) = row_count_state.as_ref() {
            let Some(datum) = row.datum_at(0) else {
                bail!("Invalid row count state: {:?}", row)
            };
            Ok(datum.into_int64())
        } else {
            Ok(0)
        }
    }
}

impl<S: StateStore> GlobalApproxPercentileState<S> {
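    /// Applies a chunk of bucket deltas from upstream. Every row is expected
    /// to be an `Insert`; the `_`-prefixed binding keeps the op available for
    /// the debug-only assertion without tripping unused-variable lints in
    /// release builds.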
    pub fn apply_chunk(&mut self, chunk: StreamChunk) -> StreamExecutorResult<()> {
        for (_op, row) in chunk.rows() {
            debug_assert_eq!(_op, Op::Insert);
            self.apply_row(row)?;
        }
        Ok(())
    }

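    /// Applies a single `(sign, bucket_id, delta_count)` row, keeping the
    /// row count, the in-memory cache, and the bucket state table in sync.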
    pub fn apply_row(&mut self, row: impl Row) -> StreamExecutorResult<()> {
        let sign_datum = row.datum_at(0);
        let sign = sign_datum.unwrap().into_int16();
        let sign_datum = sign_datum.to_owned_datum();
        let bucket_id_datum = row.datum_at(1);
        let bucket_id = bucket_id_datum.unwrap().into_int32();
        let bucket_id_datum = bucket_id_datum.to_owned_datum();
        let delta_datum = row.datum_at(2);
        let delta: i32 = delta_datum.unwrap().into_int32();

        if delta == 0 {
            return Ok(());
        }

        self.output_changed = true;

        self.row_count = self.row_count.checked_add(delta as i64).unwrap();
        tracing::debug!("updated row_count: {}", self.row_count);

        let (is_new_entry, old_count, new_count) = match sign {
            -1 => {
                let count_entry = self.cache.neg_buckets.get(&bucket_id).copied();
                let old_count = count_entry.unwrap_or(0);
                let new_count = old_count.checked_add(delta as i64).unwrap();
                let is_new_entry = count_entry.is_none();
                if new_count != 0 {
                    self.cache.neg_buckets.insert(bucket_id, new_count);
                } else {
                    self.cache.neg_buckets.remove(&bucket_id);
                }
                (is_new_entry, old_count, new_count)
            }
            0 => {
                let old_count = self.cache.zeros;
                let new_count = old_count.checked_add(delta as i64).unwrap();
                let is_new_entry = old_count == 0;
                // Always write back, so the cache stays in sync with the
                // state table even when the zero bucket drains to 0.
                self.cache.zeros = new_count;
                (is_new_entry, old_count, new_count)
            }
            1 => {
                let count_entry = self.cache.pos_buckets.get(&bucket_id).copied();
                let old_count = count_entry.unwrap_or(0);
                let new_count = old_count.checked_add(delta as i64).unwrap();
                let is_new_entry = count_entry.is_none();
                if new_count != 0 {
                    self.cache.pos_buckets.insert(bucket_id, new_count);
                } else {
                    self.cache.pos_buckets.remove(&bucket_id);
                }
                (is_new_entry, old_count, new_count)
            }
            _ => bail!("Invalid sign: {}", sign),
        };

        let old_row = &[
            sign_datum.clone(),
            bucket_id_datum.clone(),
            Datum::from(ScalarImpl::Int64(old_count)),
        ];
        if new_count == 0 && !is_new_entry {
            self.bucket_state_table.delete(old_row);
        } else if new_count > 0 {
            let new_row = &[
                sign_datum,
                bucket_id_datum,
                Datum::from(ScalarImpl::Int64(new_count)),
            ];
            if is_new_entry {
                self.bucket_state_table.insert(new_row);
            } else {
                self.bucket_state_table.update(old_row, new_row);
            }
        } else {
            // Reachable only if a bucket count would go negative; the
            // `new_count == 0 && is_new_entry` case is excluded by the
            // zero-delta early return above.
            bail!(
                "invalid state: new_count = {}, is_new_entry = {}",
                new_count,
                is_new_entry
            )
        }

        Ok(())
    }

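    /// Persists the latest row count and commits both state tables at the
    /// given epoch.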
    pub async fn commit(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        let row_count_datum = Datum::from(ScalarImpl::Int64(self.row_count));
        let row_count_row = &[row_count_datum];
        let last_row_count_state = self.count_state_table.get_row(&[Datum::None; 0]).await?;
        match last_row_count_state {
            None => self.count_state_table.insert(row_count_row),
            Some(last_row_count_state) => self
                .count_state_table
                .update(last_row_count_state, row_count_row),
        }
        self.count_state_table
            .commit_assert_no_update_vnode_bitmap(epoch)
            .await?;
        self.bucket_state_table
            .commit_assert_no_update_vnode_bitmap(epoch)
            .await?;
        Ok(())
    }
}

impl<S: StateStore> GlobalApproxPercentileState<S> {
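    /// Builds the output chunk for this barrier: an `Insert` for the very
    /// first output, otherwise an `UpdateDelete`/`UpdateInsert` pair. When
    /// nothing changed, the previous value is re-emitted as a no-op update.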
    pub fn get_output(&mut self) -> StreamChunk {
        let last_output = mem::take(&mut self.last_output);
        let new_output = if !self.output_changed {
            tracing::debug!("last_output: {:#?}", last_output);
            last_output.clone().flatten()
        } else {
            self.cache
                .get_output(self.row_count, self.quantile, self.base)
        };
        self.last_output = Some(new_output.clone());
        let output_chunk = match last_output {
            None => StreamChunk::from_rows(&[(Op::Insert, &[new_output])], &[DataType::Float64]),
            Some(last_output) if !self.output_changed => StreamChunk::from_rows(
                &[
                    (Op::UpdateDelete, &[last_output.clone()]),
                    (Op::UpdateInsert, &[last_output]),
                ],
                &[DataType::Float64],
            ),
            Some(last_output) => StreamChunk::from_rows(
                &[
                    (Op::UpdateDelete, &[last_output.clone()]),
                    (Op::UpdateInsert, &[new_output.clone()]),
                ],
                &[DataType::Float64],
            ),
        };
        tracing::debug!("get_output: {:#?}", output_chunk);
        self.output_changed = false;
        output_chunk
    }
}

type Count = i64;
type BucketId = i32;

type BucketMap = BTreeMap<BucketId, Count>;

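/// In-memory mirror of the bucket state table, split into negative, zero,
/// and positive sections so buckets can be scanned in ascending value order.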
struct BucketTableCache {
    neg_buckets: BucketMap,
    zeros: Count,
    pos_buckets: BucketMap,
}

impl BucketTableCache {
    pub fn new() -> Self {
        Self {
            neg_buckets: BucketMap::new(),
            zeros: 0,
            pos_buckets: BucketMap::new(),
        }
    }

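    /// Scans buckets in ascending value order (most negative first) until the
    /// accumulated count passes the target rank, then returns that bucket's
    /// representative value. `2 * base^i / (base + 1)` is the harmonic mean
    /// of `base^(i-1)` and `base^i`, presumably the bounds of bucket `i` as
    /// assigned by the local executor.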
    pub fn get_output(&self, row_count: i64, quantile: f64, base: f64) -> Datum {
        let quantile_count = ((row_count - 1) as f64 * quantile).floor() as i64;
        let mut acc_count = 0;
        for (bucket_id, count) in self.neg_buckets.iter().rev() {
            acc_count += count;
            if acc_count > quantile_count {
                let approx_percentile = -2.0 * base.powi(*bucket_id) / (base + 1.0);
                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
                return Datum::from(approx_percentile);
            }
        }
        acc_count += self.zeros;
        if acc_count > quantile_count {
            return Datum::from(ScalarImpl::Float64(0.0.into()));
        }
        for (bucket_id, count) in &self.pos_buckets {
            acc_count += count;
            if acc_count > quantile_count {
                let approx_percentile = 2.0 * base.powi(*bucket_id) / (base + 1.0);
                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
                return Datum::from(approx_percentile);
            }
        }
        Datum::None
    }
}
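
// A minimal sketch (not part of the original file) showing how the cache
// resolves a quantile; the module name, bucket id, and constants below are
// illustrative assumptions, not values taken from the source.
#[cfg(test)]
mod cache_sketch_tests {
    use super::*;

    #[test]
    fn median_of_positive_bucket() {
        let mut cache = BucketTableCache::new();
        // Three rows land in positive bucket 0; with base = 2.0 its
        // representative value is 2 * 2^0 / (2 + 1) = 2/3.
        cache.pos_buckets.insert(0, 3);
        // Target rank: floor((3 - 1) * 0.5) = 1, which this bucket covers.
        let output = cache.get_output(3, 0.5, 2.0);
        assert_eq!(
            output,
            Datum::from(ScalarImpl::Float64((2.0 / 3.0).into()))
        );
    }
}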