risingwave_expr_impl/aggregate/
array_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::ArrayBuilderImpl;
16use risingwave_common::types::{Datum, ListValue, ScalarRefImpl};
17use risingwave_common_estimate_size::EstimateSize;
18use risingwave_expr::aggregate;
19use risingwave_expr::aggregate::AggStateDyn;
20use risingwave_expr::expr::Context;
21
22#[aggregate("array_agg(any) -> anyarray")]
23fn array_agg(state: &mut ArrayAggState, value: Option<ScalarRefImpl<'_>>, ctx: &Context) {
24    state
25        .0
26        .get_or_insert_with(|| ctx.arg_types[0].create_array_builder(1))
27        .append(value);
28}
29
30#[derive(Debug, Clone, Default)]
31struct ArrayAggState(Option<ArrayBuilderImpl>);
32
33impl EstimateSize for ArrayAggState {
34    fn estimated_heap_size(&self) -> usize {
35        self.0.estimated_heap_size()
36    }
37}
38
// Empty impl: no trait items are overridden, so `ArrayAggState` relies entirely
// on whatever `AggStateDyn` provides by default.
// NOTE(review): presumably this is what lets the `#[aggregate]` macro use
// `ArrayAggState` as a dynamically-typed state — confirm against `AggStateDyn`.
impl AggStateDyn for ArrayAggState {}
40
41/// Finishes aggregation and returns the result.
42impl From<&ArrayAggState> for Datum {
43    fn from(state: &ArrayAggState) -> Self {
44        state
45            .0
46            .as_ref()
47            .map(|b| ListValue::new(b.clone().finish()).into())
48    }
49}
50
#[cfg(test)]
mod tests {
    use risingwave_common::array::{ListValue, StreamChunk};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_expr::Result;
    use risingwave_expr::aggregate::{AggCall, build_append_only};

    /// Pretty-printed call spec shared by all tests: `array_agg` over an
    /// `int4` column producing `int4[]`.
    const ARRAY_AGG_INT4: &str = "(array_agg:int4[] $0:int4)";

    /// Three non-null rows are collected into a list in input order.
    #[tokio::test]
    async fn test_array_agg_basic() -> Result<()> {
        let input = StreamChunk::from_pretty(
            " i
            + 123
            + 456
            + 789",
        );
        let agg = build_append_only(&AggCall::from_pretty(ARRAY_AGG_INT4))?;
        let mut state = agg.create_state()?;
        agg.update(&mut state, &input).await?;

        let actual = agg.get_result(&state).await?;
        assert_eq!(actual, Some(ListValue::from_iter([123, 456, 789]).into()));
        Ok(())
    }

    /// A fresh state yields NULL; a single NULL row yields a one-element list
    /// containing NULL.
    #[tokio::test]
    async fn test_array_agg_empty() -> Result<()> {
        let agg = build_append_only(&AggCall::from_pretty(ARRAY_AGG_INT4))?;
        let mut state = agg.create_state()?;

        // Nothing has been fed in yet.
        assert_eq!(agg.get_result(&state).await?, None);

        // `.` denotes a NULL value in the pretty chunk format.
        let null_row = StreamChunk::from_pretty(
            " i
            + .",
        );
        agg.update(&mut state, &null_row).await?;
        assert_eq!(
            agg.get_result(&state).await?,
            Some(ListValue::from_iter([None::<i32>]).into())
        );
        Ok(())
    }
}