// risingwave_stream/executor/approx_percentile/global_state.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, Bound};
16use std::mem;
17
18use risingwave_common::array::Op;
19use risingwave_common::bail;
20use risingwave_common::row::Row;
21use risingwave_common::types::{Datum, ToOwnedDatum};
22use risingwave_common::util::epoch::EpochPair;
23use risingwave_storage::StateStore;
24use risingwave_storage::store::PrefetchOptions;
25
26use crate::executor::StreamExecutorResult;
27use crate::executor::prelude::*;
28
/// The global approx percentile state.
pub struct GlobalApproxPercentileState<S: StateStore> {
    // Quantile to compute; used as `(row_count - 1) * quantile` to locate the target rank.
    quantile: f64,
    // Base `y` of the exponential buckets; bucket `i` represents the value `±2 * y^i / (y + 1)`.
    base: f64,
    // Total number of input rows accumulated so far; persisted in `count_state_table`.
    row_count: i64,
    // Persistent storage of bucket counts, keyed by `(sign, bucket_id)`.
    bucket_state_table: StateTable<S>,
    // Persistent storage holding the single `row_count` row.
    count_state_table: StateTable<S>,
    // In-memory mirror of the entire `bucket_state_table`.
    cache: BucketTableCache,
    // Output emitted at the previous barrier; `None` before the first output.
    last_output: Option<Datum>,
    // Whether any non-zero delta was applied since the last `get_output` call.
    output_changed: bool,
}
40
41// Initialization
42impl<S: StateStore> GlobalApproxPercentileState<S> {
43    pub fn new(
44        quantile: f64,
45        base: f64,
46        bucket_state_table: StateTable<S>,
47        count_state_table: StateTable<S>,
48    ) -> Self {
49        Self {
50            quantile,
51            base,
52            row_count: 0,
53            bucket_state_table,
54            count_state_table,
55            cache: BucketTableCache::new(),
56            last_output: None,
57            output_changed: false,
58        }
59    }
60
61    pub async fn init(&mut self, init_epoch: EpochPair) -> StreamExecutorResult<()> {
62        // Init state tables.
63        self.count_state_table.init_epoch(init_epoch).await?;
64        self.bucket_state_table.init_epoch(init_epoch).await?;
65
66        // Refill row_count
67        let row_count_state = self.get_row_count_state().await?;
68        let row_count = Self::decode_row_count(&row_count_state)?;
69        self.row_count = row_count;
70        tracing::debug!(?row_count, "recovered row_count");
71
72        // Refill cache
73        self.refill_cache().await?;
74
75        // Update the last output downstream
76        let last_output = if row_count_state.is_none() {
77            None
78        } else {
79            Some(self.cache.get_output(row_count, self.quantile, self.base))
80        };
81        tracing::debug!(?last_output, "recovered last_output");
82        self.last_output = last_output;
83        Ok(())
84    }
85
86    async fn refill_cache(&mut self) -> StreamExecutorResult<()> {
87        let bounds: (Bound<OwnedRow>, Bound<OwnedRow>) = (Bound::Unbounded, Bound::Unbounded);
88        #[for_await]
89        for keyed_row in self
90            .bucket_state_table
91            .iter_with_prefix(&[Datum::None; 0], &bounds, PrefetchOptions::default())
92            .await?
93        {
94            let row = keyed_row?.into_owned_row();
95            let sign = row.datum_at(0).unwrap().into_int16();
96            let bucket_id = row.datum_at(1).unwrap().into_int32();
97            let count = row.datum_at(2).unwrap().into_int64();
98            match sign {
99                -1 => {
100                    self.cache.neg_buckets.insert(bucket_id, count as i64);
101                }
102                0 => {
103                    self.cache.zeros = count as i64;
104                }
105                1 => {
106                    self.cache.pos_buckets.insert(bucket_id, count as i64);
107                }
108                _ => {
109                    bail!("Invalid sign: {}", sign);
110                }
111            }
112        }
113        Ok(())
114    }
115
116    async fn get_row_count_state(&self) -> StreamExecutorResult<Option<OwnedRow>> {
117        self.count_state_table.get_row(&[Datum::None; 0]).await
118    }
119
120    fn decode_row_count(row_count_state: &Option<OwnedRow>) -> StreamExecutorResult<i64> {
121        if let Some(row) = row_count_state.as_ref() {
122            let Some(datum) = row.datum_at(0) else {
123                bail!("Invalid row count state: {:?}", row)
124            };
125            Ok(datum.into_int64())
126        } else {
127            Ok(0)
128        }
129    }
130}
131
132// Update
133impl<S: StateStore> GlobalApproxPercentileState<S> {
134    pub fn apply_chunk(&mut self, chunk: StreamChunk) -> StreamExecutorResult<()> {
135        // Op is ignored here, because we only check the `delta` column inside the row.
136        // The sign of the `delta` column will tell us if we need to decrease or increase the
137        // count of the bucket.
138        for (_op, row) in chunk.rows() {
139            debug_assert_eq!(_op, Op::Insert);
140            self.apply_row(row)?;
141        }
142        Ok(())
143    }
144
145    pub fn apply_row(&mut self, row: impl Row) -> StreamExecutorResult<()> {
146        // Decoding
147        let sign_datum = row.datum_at(0);
148        let sign = sign_datum.unwrap().into_int16();
149        let sign_datum = sign_datum.to_owned_datum();
150        let bucket_id_datum = row.datum_at(1);
151        let bucket_id = bucket_id_datum.unwrap().into_int32();
152        let bucket_id_datum = bucket_id_datum.to_owned_datum();
153        let delta_datum = row.datum_at(2);
154        let delta: i32 = delta_datum.unwrap().into_int32();
155
156        if delta == 0 {
157            return Ok(());
158        }
159
160        self.output_changed = true;
161
162        // Updates
163        self.row_count = self.row_count.checked_add(delta as i64).unwrap();
164        tracing::debug!("updated row_count: {}", self.row_count);
165
166        let (is_new_entry, old_count, new_count) = match sign {
167            -1 => {
168                let count_entry = self.cache.neg_buckets.get(&bucket_id).copied();
169                let old_count = count_entry.unwrap_or(0);
170                let new_count = old_count.checked_add(delta as i64).unwrap();
171                let is_new_entry = count_entry.is_none();
172                if new_count != 0 {
173                    self.cache.neg_buckets.insert(bucket_id, new_count);
174                } else {
175                    self.cache.neg_buckets.remove(&bucket_id);
176                }
177                (is_new_entry, old_count, new_count)
178            }
179            0 => {
180                let old_count = self.cache.zeros;
181                let new_count = old_count.checked_add(delta as i64).unwrap();
182                let is_new_entry = old_count == 0;
183                if new_count != 0 {
184                    self.cache.zeros = new_count;
185                }
186                (is_new_entry, old_count, new_count)
187            }
188            1 => {
189                let count_entry = self.cache.pos_buckets.get(&bucket_id).copied();
190                let old_count = count_entry.unwrap_or(0);
191                let new_count = old_count.checked_add(delta as i64).unwrap();
192                let is_new_entry = count_entry.is_none();
193                if new_count != 0 {
194                    self.cache.pos_buckets.insert(bucket_id, new_count);
195                } else {
196                    self.cache.pos_buckets.remove(&bucket_id);
197                }
198                (is_new_entry, old_count, new_count)
199            }
200            _ => bail!("Invalid sign: {}", sign),
201        };
202
203        let old_row = &[
204            sign_datum.clone(),
205            bucket_id_datum.clone(),
206            Datum::from(ScalarImpl::Int64(old_count)),
207        ];
208        if new_count == 0 && !is_new_entry {
209            self.bucket_state_table.delete(old_row);
210        } else if new_count > 0 {
211            let new_row = &[
212                sign_datum,
213                bucket_id_datum,
214                Datum::from(ScalarImpl::Int64(new_count)),
215            ];
216            if is_new_entry {
217                self.bucket_state_table.insert(new_row);
218            } else {
219                self.bucket_state_table.update(old_row, new_row);
220            }
221        } else {
222            bail!("invalid state, new_count = 0 and is_new_entry is true")
223        }
224
225        Ok(())
226    }
227
228    pub async fn commit(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
229        // Commit row count state.
230        let row_count_datum = Datum::from(ScalarImpl::Int64(self.row_count));
231        let row_count_row = &[row_count_datum];
232        let last_row_count_state = self.count_state_table.get_row(&[Datum::None; 0]).await?;
233        match last_row_count_state {
234            None => self.count_state_table.insert(row_count_row),
235            Some(last_row_count_state) => self
236                .count_state_table
237                .update(last_row_count_state, row_count_row),
238        }
239        self.count_state_table
240            .commit_assert_no_update_vnode_bitmap(epoch)
241            .await?;
242        self.bucket_state_table
243            .commit_assert_no_update_vnode_bitmap(epoch)
244            .await?;
245        Ok(())
246    }
247}
248
249// Read
impl<S: StateStore> GlobalApproxPercentileState<S> {
    /// Builds the output chunk for the current barrier and remembers the new
    /// value as `last_output` for the next call.
    ///
    /// Emits a single `Insert` for the very first output, otherwise an
    /// `UpdateDelete`/`UpdateInsert` pair. Note that a pair is emitted even
    /// when the value did not change (both sides then carry the old value).
    /// NOTE(review): presumably downstream expects one record per barrier —
    /// confirm against the consuming executor.
    pub fn get_output(&mut self) -> StreamChunk {
        // Move the previous output out; it is re-stored as `Some(new_output)` below.
        let last_output = mem::take(&mut self.last_output);
        let new_output = if !self.output_changed {
            // No delta applied since the last barrier: reuse the previous value.
            tracing::debug!("last_output: {:#?}", last_output);
            last_output.clone().flatten()
        } else {
            self.cache
                .get_output(self.row_count, self.quantile, self.base)
        };
        self.last_output = Some(new_output.clone());
        let output_chunk = match last_output {
            // First-ever output: plain insert.
            None => StreamChunk::from_rows(&[(Op::Insert, &[new_output])], &[DataType::Float64]),
            // Unchanged output: update pair carrying the same value twice.
            Some(last_output) if !self.output_changed => StreamChunk::from_rows(
                &[
                    (Op::UpdateDelete, &[last_output.clone()]),
                    (Op::UpdateInsert, &[last_output]),
                ],
                &[DataType::Float64],
            ),
            // Changed output: retract the old value, insert the new one.
            Some(last_output) => StreamChunk::from_rows(
                &[
                    (Op::UpdateDelete, &[last_output.clone()]),
                    (Op::UpdateInsert, &[new_output.clone()]),
                ],
                &[DataType::Float64],
            ),
        };
        tracing::debug!("get_output: {:#?}", output_chunk,);
        // Reset the dirty flag for the next barrier interval.
        self.output_changed = false;
        output_chunk
    }
}
283
// Number of values accumulated in a bucket.
type Count = i64;
// Exponent `i` identifying a bucket; the bucket stands for the value `±2 * base^i / (base + 1)`.
type BucketId = i32;

// Bucket counts keyed (and therefore iterated) in bucket-id order.
type BucketMap = BTreeMap<BucketId, Count>;

/// Keeps the entire bucket state table contents in-memory.
struct BucketTableCache {
    neg_buckets: BucketMap, // Buckets for negative inputs (sign = -1).
    zeros: Count, // If Count is 0, it means this bucket has not been inserted into before.
    pos_buckets: BucketMap, // Buckets for positive inputs (sign = 1).
}
295
296impl BucketTableCache {
297    pub fn new() -> Self {
298        Self {
299            neg_buckets: BucketMap::new(),
300            zeros: 0,
301            pos_buckets: BucketMap::new(),
302        }
303    }
304
305    pub fn get_output(&self, row_count: i64, quantile: f64, base: f64) -> Datum {
306        let quantile_count = ((row_count - 1) as f64 * quantile).floor() as i64;
307        let mut acc_count = 0;
308        for (bucket_id, count) in self.neg_buckets.iter().rev() {
309            acc_count += count;
310            if acc_count > quantile_count {
311                // approx value = -2 * y^i / (y + 1)
312                let approx_percentile = -2.0 * base.powi(*bucket_id) / (base + 1.0);
313                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
314                return Datum::from(approx_percentile);
315            }
316        }
317        acc_count += self.zeros;
318        if acc_count > quantile_count {
319            return Datum::from(ScalarImpl::Float64(0.0.into()));
320        }
321        for (bucket_id, count) in &self.pos_buckets {
322            acc_count += count;
323            if acc_count > quantile_count {
324                // approx value = 2 * y^i / (y + 1)
325                let approx_percentile = 2.0 * base.powi(*bucket_id) / (base + 1.0);
326                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
327                return Datum::from(approx_percentile);
328            }
329        }
330        Datum::None
331    }
332}