risingwave_batch_executors/executor/aggregation/
projection.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use risingwave_common::array::StreamChunk;
18use risingwave_common::types::{DataType, Datum};
19use risingwave_expr::Result;
20use risingwave_expr::aggregate::{AggregateFunction, AggregateState, BoxedAggregateFunction};
21
22pub struct Projection {
23    inner: BoxedAggregateFunction,
24    indices: Vec<usize>,
25}
26
27impl Projection {
28    pub fn new(indices: Vec<usize>, inner: BoxedAggregateFunction) -> Self {
29        Self { inner, indices }
30    }
31}
32
33#[async_trait::async_trait]
34impl AggregateFunction for Projection {
35    fn return_type(&self) -> DataType {
36        self.inner.return_type()
37    }
38
39    fn create_state(&self) -> Result<AggregateState> {
40        self.inner.create_state()
41    }
42
43    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
44        self.inner
45            .update(state, &input.project(&self.indices))
46            .await
47    }
48
49    async fn update_range(
50        &self,
51        state: &mut AggregateState,
52        input: &StreamChunk,
53        range: Range<usize>,
54    ) -> Result<()> {
55        self.inner
56            .update_range(state, &input.project(&self.indices), range)
57            .await
58    }
59
60    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
61        self.inner.get_result(state).await
62    }
63}