risingwave_batch_executors/executor/aggregation/orderby.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use anyhow::Context;
18use risingwave_common::array::{Op, RowRef, StreamChunk};
19use risingwave_common::row::{OwnedRow, Row, RowExt};
20use risingwave_common::types::{DataType, Datum};
21use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
22use risingwave_common::util::memcmp_encoding;
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::Result;
26use risingwave_expr::aggregate::{
27    AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
28};
29
30/// `ProjectionOrderBy` is a wrapper of `AggregateFunction` that sorts rows by given columns and
31/// then projects columns.
32pub struct ProjectionOrderBy {
33    inner: BoxedAggregateFunction,
34    arg_types: Vec<DataType>,
35    arg_indices: Vec<usize>,
36    order_col_indices: Vec<usize>,
37    order_types: Vec<OrderType>,
38}
39
40#[derive(Debug)]
41struct State {
42    unordered_values: Vec<(OrderKey, OwnedRow)>,
43    unordered_values_estimated_heap_size: usize,
44}
45
46impl EstimateSize for State {
47    fn estimated_heap_size(&self) -> usize {
48        self.unordered_values.capacity() * std::mem::size_of::<(OrderKey, OwnedRow)>()
49            + self.unordered_values_estimated_heap_size
50    }
51}
52
53impl AggStateDyn for State {}
54
55type OrderKey = Box<[u8]>;
56
57impl ProjectionOrderBy {
58    pub fn new(
59        arg_types: Vec<DataType>,
60        arg_indices: Vec<usize>,
61        column_orders: Vec<ColumnOrder>,
62        inner: BoxedAggregateFunction,
63    ) -> Self {
64        let (order_col_indices, order_types) = column_orders
65            .into_iter()
66            .map(|c| (c.column_index, c.order_type))
67            .unzip();
68        Self {
69            inner,
70            arg_types,
71            arg_indices,
72            order_col_indices,
73            order_types,
74        }
75    }
76
77    fn push_row(&self, state: &mut State, row: RowRef<'_>) -> Result<()> {
78        let key =
79            memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types)
80                .context("failed to encode row")?;
81        let projected_row = row.project(&self.arg_indices).to_owned_row();
82
83        state.unordered_values_estimated_heap_size +=
84            key.len() + projected_row.estimated_heap_size();
85        state.unordered_values.push((key.into(), projected_row));
86        Ok(())
87    }
88}
89
90#[async_trait::async_trait]
91impl AggregateFunction for ProjectionOrderBy {
92    fn return_type(&self) -> DataType {
93        self.inner.return_type()
94    }
95
96    fn create_state(&self) -> Result<AggregateState> {
97        Ok(AggregateState::Any(Box::new(State {
98            unordered_values: vec![],
99            unordered_values_estimated_heap_size: 0,
100        })))
101    }
102
103    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
104        let state = state.downcast_mut::<State>();
105        state.unordered_values.reserve(input.cardinality());
106        for (op, row) in input.rows() {
107            assert_eq!(op, Op::Insert, "only support append");
108            self.push_row(state, row)?;
109        }
110        Ok(())
111    }
112
113    async fn update_range(
114        &self,
115        state: &mut AggregateState,
116        input: &StreamChunk,
117        range: Range<usize>,
118    ) -> Result<()> {
119        let state = state.downcast_mut::<State>();
120        state.unordered_values.reserve(range.len());
121        for (op, row) in input.rows_in(range) {
122            assert_eq!(op, Op::Insert, "only support append");
123            self.push_row(state, row)?;
124        }
125        Ok(())
126    }
127
128    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
129        let state = state.downcast_ref::<State>();
130        let mut inner_state = self.inner.create_state()?;
131        // sort
132        let mut rows = state.unordered_values.clone();
133        rows.sort_unstable_by(|(key_a, _), (key_b, _)| key_a.cmp(key_b));
134        // build chunk
135        let mut chunk_builder = DataChunkBuilder::new(self.arg_types.clone(), 1024);
136        for (_, row) in rows {
137            if let Some(data_chunk) = chunk_builder.append_one_row(row) {
138                let chunk = StreamChunk::from(data_chunk);
139                self.inner.update(&mut inner_state, &chunk).await?;
140            }
141        }
142        if let Some(data_chunk) = chunk_builder.consume_all() {
143            let chunk = StreamChunk::from(data_chunk);
144            self.inner.update(&mut inner_state, &chunk).await?;
145        }
146        self.inner.get_result(&inner_state).await
147    }
148}
149
#[cfg(test)]
mod tests {
    use risingwave_common::array::{ListValue, StreamChunk};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_expr::aggregate::AggCall;

    use super::super::build;

    #[tokio::test]
    async fn array_agg_with_order() {
        // Ordered by `$1` ascending. The tie on `$1 = 2` (456, 789) is broken by `$0`
        // descending, which gives 789, 456, 123, 321.
        let agg = build(&AggCall::from_pretty(
            "(array_agg:int4[] $0:int4 orderby $1:asc $0:desc)",
        ))
        .unwrap();
        let input = StreamChunk::from_pretty(
            " i    i
            + 123  3
            + 456  2
            + 789  2
            + 321  9",
        );

        let mut agg_state = agg.create_state().unwrap();
        agg.update(&mut agg_state, &input).await.unwrap();
        let result = agg.get_result(&agg_state).await.unwrap();

        assert_eq!(result, Some(ListValue::from_iter([789, 456, 123, 321]).into()));
    }

    #[tokio::test]
    async fn string_agg_with_order() {
        // `$2` asc puts bbb/ccc (0) before aaa/ddd (1).
        // `$3` desc puts ccc (8) before bbb (4).
        // aaa and ddd tie on `$3` (3), so `$0` desc orders them ddd, aaa.
        // Every row's delimiter column `$1` is `_`.
        let agg = build(&AggCall::from_pretty(
            "(string_agg:varchar $0:varchar $1:varchar orderby $2:asc $3:desc $0:desc)",
        ))
        .unwrap();
        let input = StreamChunk::from_pretty(
            " T   T i i
            + aaa _ 1 3
            + bbb _ 0 4
            + ccc _ 0 8
            + ddd _ 1 3",
        );

        let mut agg_state = agg.create_state().unwrap();
        agg.update(&mut agg_state, &input).await.unwrap();
        let result = agg.get_result(&agg_state).await.unwrap();

        assert_eq!(result, Some("ccc_bbb_ddd_aaa".into()));
    }
}