risingwave_stream/executor/approx_percentile/
global.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::global_state::GlobalApproxPercentileState;
16use crate::executor::prelude::*;
17
18pub struct GlobalApproxPercentileExecutor<S: StateStore> {
19    _ctx: ActorContextRef,
20    pub input: Executor,
21    pub quantile: f64,
22    pub base: f64,
23    pub chunk_size: usize,
24    pub state: GlobalApproxPercentileState<S>,
25}
26
27impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
28    pub fn new(
29        _ctx: ActorContextRef,
30        input: Executor,
31        quantile: f64,
32        base: f64,
33        chunk_size: usize,
34        bucket_state_table: StateTable<S>,
35        count_state_table: StateTable<S>,
36    ) -> Self {
37        let global_state =
38            GlobalApproxPercentileState::new(quantile, base, bucket_state_table, count_state_table);
39        Self {
40            _ctx,
41            input,
42            quantile,
43            base,
44            chunk_size,
45            state: global_state,
46        }
47    }
48
49    /// TODO(kwannoel): Include cache later.
50    #[try_stream(ok = Message, error = StreamExecutorError)]
51    async fn execute_inner(self) {
52        // Initialize state
53        let mut input_stream = self.input.execute();
54        let first_barrier = expect_first_barrier(&mut input_stream).await?;
55        let first_epoch = first_barrier.epoch;
56        yield Message::Barrier(first_barrier);
57        let mut state = self.state;
58        state.init(first_epoch).await?;
59
60        // Get row count state, and row_count.
61        #[for_await]
62        for message in input_stream {
63            match message? {
64                Message::Chunk(chunk) => {
65                    state.apply_chunk(chunk)?;
66                }
67                Message::Barrier(barrier) => {
68                    let output = state.get_output();
69                    yield Message::Chunk(output);
70                    state.commit(barrier.epoch).await?;
71                    yield Message::Barrier(barrier);
72                }
73                Message::Watermark(_) => {}
74            }
75        }
76    }
77}
78
79impl<S: StateStore> Execute for GlobalApproxPercentileExecutor<S> {
80    fn execute(self: Box<Self>) -> BoxedMessageStream {
81        self.execute_inner().boxed()
82    }
83}