risingwave_stream/executor/approx_percentile/
local.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::iter;
17
18use risingwave_common::array::Op;
19use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
20
21use crate::executor::prelude::*;
22
23pub struct LocalApproxPercentileExecutor {
24    _ctx: ActorContextRef,
25    pub input: Executor,
26    pub base: f64,
27    pub percentile_index: usize,
28    pub chunk_size: usize,
29}
30
31impl LocalApproxPercentileExecutor {
32    pub fn new(
33        _ctx: ActorContextRef,
34        input: Executor,
35        base: f64,
36        percentile_index: usize,
37        chunk_size: usize,
38    ) -> Self {
39        Self {
40            _ctx,
41            input,
42            base,
43            percentile_index,
44            chunk_size,
45        }
46    }
47
48    #[try_stream(ok = Message, error = StreamExecutorError)]
49    async fn execute_inner(self) {
50        let percentile_index = self.percentile_index;
51        #[for_await]
52        for message in self.input.execute() {
53            match message? {
54                Message::Chunk(chunk) => {
55                    let mut builder = DataChunkBuilder::new(
56                        vec![DataType::Int16, DataType::Int32, DataType::Int32],
57                        self.chunk_size,
58                    );
59                    let chunk = chunk.project(&[percentile_index]);
60                    let mut pos_counts = HashMap::new();
61                    let mut neg_counts = HashMap::new();
62                    let mut zero_count = 0;
63                    for (op, row) in chunk.rows() {
64                        let value = row.datum_at(0).unwrap();
65                        let value: f64 = value.into_float64().into_inner();
66                        if value < 0.0 {
67                            let value = -value;
68                            let bucket = value.log(self.base).ceil() as i32; // TODO(kwannoel): should this be floor??
69                            let count = neg_counts.entry(bucket).or_insert(0);
70                            match op {
71                                Op::Insert | Op::UpdateInsert => *count += 1,
72                                Op::Delete | Op::UpdateDelete => *count -= 1,
73                            }
74                        } else if value > 0.0 {
75                            let bucket = value.log(self.base).ceil() as i32;
76                            let count = pos_counts.entry(bucket).or_insert(0);
77                            match op {
78                                Op::Insert | Op::UpdateInsert => *count += 1,
79                                Op::Delete | Op::UpdateDelete => *count -= 1,
80                            }
81                        } else {
82                            match op {
83                                Op::Insert | Op::UpdateInsert => zero_count += 1,
84                                Op::Delete | Op::UpdateDelete => zero_count -= 1,
85                            }
86                        }
87                    }
88
89                    for (sign, bucket, count) in neg_counts
90                        .into_iter()
91                        .map(|(b, c)| (-1, b, c))
92                        .chain(pos_counts.into_iter().map(|(b, c)| (1, b, c)))
93                        .chain(iter::once((0, 0, zero_count)))
94                    {
95                        let row = [
96                            Datum::from(ScalarImpl::Int16(sign)),
97                            Datum::from(ScalarImpl::Int32(bucket)),
98                            Datum::from(ScalarImpl::Int32(count)),
99                        ];
100                        if let Some(data_chunk) = builder.append_one_row(&row) {
101                            // NOTE(kwannoel): The op here is simply ignored.
102                            // The downstream global_approx_percentile will always just update its bucket counts.
103                            let ops = vec![Op::Insert; data_chunk.cardinality()];
104                            let chunk = StreamChunk::from_parts(ops, data_chunk);
105                            yield Message::Chunk(chunk);
106                        }
107                    }
108                    if !builder.is_empty() {
109                        let data_chunk = builder.finish();
110                        let ops = vec![Op::Insert; data_chunk.cardinality()];
111                        let chunk = StreamChunk::from_parts(ops, data_chunk);
112                        yield Message::Chunk(chunk);
113                    }
114                }
115                b @ Message::Barrier(_) => yield b,
116                Message::Watermark(_) => {}
117            }
118        }
119    }
120}
121
122impl Execute for LocalApproxPercentileExecutor {
123    fn execute(self: Box<Self>) -> BoxedMessageStream {
124        self.execute_inner().boxed()
125    }
126}