risingwave_batch_executors/executor/aggregation/
orderby.rs1use std::ops::Range;
16
17use anyhow::Context;
18use risingwave_common::array::{Op, RowRef, StreamChunk};
19use risingwave_common::row::{OwnedRow, Row, RowExt};
20use risingwave_common::types::{DataType, Datum};
21use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
22use risingwave_common::util::memcmp_encoding;
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_expr::Result;
26use risingwave_expr::aggregate::{
27 AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFunction,
28};
29
/// Aggregate function wrapper that buffers its input rows and, when the result
/// is requested, feeds them to `inner` sorted by the configured order columns.
pub struct ProjectionOrderBy {
    /// The wrapped aggregate function; it receives the rows in sorted order.
    inner: BoxedAggregateFunction,
    /// Data types of the projected argument columns, used to build the chunks
    /// handed to `inner`.
    arg_types: Vec<DataType>,
    /// Indices of the input columns that are projected out and forwarded to `inner`.
    arg_indices: Vec<usize>,
    /// Indices of the input columns that make up the sort key.
    order_col_indices: Vec<usize>,
    /// Order type of each sort-key column; built pairwise with `order_col_indices`.
    order_types: Vec<OrderType>,
}
39
/// Aggregation state of `ProjectionOrderBy`: the rows received so far, not yet sorted.
#[derive(Debug)]
struct State {
    /// Buffered `(encoded sort key, projected argument row)` pairs, in arrival order.
    unordered_values: Vec<(OrderKey, OwnedRow)>,
    /// Running sum, over all buffered pairs, of the key length plus the row's
    /// estimated heap size. Maintained incrementally as rows are pushed.
    unordered_values_estimated_heap_size: usize,
}
45
46impl EstimateSize for State {
47 fn estimated_heap_size(&self) -> usize {
48 self.unordered_values.capacity() * std::mem::size_of::<(OrderKey, OwnedRow)>()
49 + self.unordered_values_estimated_heap_size
50 }
51}
52
// Empty impl: `State` adds nothing beyond what the trait already provides.
// NOTE(review): presumably this is what allows `State` to be boxed into
// `AggregateState::Any` and recovered with `downcast_ref`/`downcast_mut` — confirm.
impl AggStateDyn for State {}
54
/// Sort key of a buffered row: the bytes produced by `memcmp_encoding::encode_row`
/// for the order columns. Keys are compared directly as byte slices when sorting.
type OrderKey = Box<[u8]>;
56
57impl ProjectionOrderBy {
58 pub fn new(
59 arg_types: Vec<DataType>,
60 arg_indices: Vec<usize>,
61 column_orders: Vec<ColumnOrder>,
62 inner: BoxedAggregateFunction,
63 ) -> Self {
64 let (order_col_indices, order_types) = column_orders
65 .into_iter()
66 .map(|c| (c.column_index, c.order_type))
67 .unzip();
68 Self {
69 inner,
70 arg_types,
71 arg_indices,
72 order_col_indices,
73 order_types,
74 }
75 }
76
77 fn push_row(&self, state: &mut State, row: RowRef<'_>) -> Result<()> {
78 let key =
79 memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types)
80 .context("failed to encode row")?;
81 let projected_row = row.project(&self.arg_indices).to_owned_row();
82
83 state.unordered_values_estimated_heap_size +=
84 key.len() + projected_row.estimated_heap_size();
85 state.unordered_values.push((key.into(), projected_row));
86 Ok(())
87 }
88}
89
90#[async_trait::async_trait]
91impl AggregateFunction for ProjectionOrderBy {
92 fn return_type(&self) -> DataType {
93 self.inner.return_type()
94 }
95
96 fn create_state(&self) -> Result<AggregateState> {
97 Ok(AggregateState::Any(Box::new(State {
98 unordered_values: vec![],
99 unordered_values_estimated_heap_size: 0,
100 })))
101 }
102
103 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
104 let state = state.downcast_mut::<State>();
105 state.unordered_values.reserve(input.cardinality());
106 for (op, row) in input.rows() {
107 assert_eq!(op, Op::Insert, "only support append");
108 self.push_row(state, row)?;
109 }
110 Ok(())
111 }
112
113 async fn update_range(
114 &self,
115 state: &mut AggregateState,
116 input: &StreamChunk,
117 range: Range<usize>,
118 ) -> Result<()> {
119 let state = state.downcast_mut::<State>();
120 state.unordered_values.reserve(range.len());
121 for (op, row) in input.rows_in(range) {
122 assert_eq!(op, Op::Insert, "only support append");
123 self.push_row(state, row)?;
124 }
125 Ok(())
126 }
127
128 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
129 let state = state.downcast_ref::<State>();
130 let mut inner_state = self.inner.create_state()?;
131 let mut rows = state.unordered_values.clone();
133 rows.sort_unstable_by(|(key_a, _), (key_b, _)| key_a.cmp(key_b));
134 let mut chunk_builder = DataChunkBuilder::new(self.arg_types.clone(), 1024);
136 for (_, row) in rows {
137 if let Some(data_chunk) = chunk_builder.append_one_row(row) {
138 let chunk = StreamChunk::from(data_chunk);
139 self.inner.update(&mut inner_state, &chunk).await?;
140 }
141 }
142 if let Some(data_chunk) = chunk_builder.consume_all() {
143 let chunk = StreamChunk::from(data_chunk);
144 self.inner.update(&mut inner_state, &chunk).await?;
145 }
146 self.inner.get_result(&inner_state).await
147 }
148}
149
#[cfg(test)]
mod tests {
    use risingwave_common::array::{ListValue, StreamChunk};
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_expr::aggregate::AggCall;

    use super::super::build;

    /// `array_agg` over column 0, ordered by column 1 ascending and then
    /// column 0 descending.
    #[tokio::test]
    async fn array_agg_with_order() {
        let chunk = StreamChunk::from_pretty(
            " i i
            + 123 3
            + 456 2
            + 789 2
            + 321 9",
        );
        let agg = build(&AggCall::from_pretty(
            "(array_agg:int4[] $0:int4 orderby $1:asc $0:desc)",
        ))
        .unwrap();
        let mut state = agg.create_state().unwrap();
        agg.update(&mut state, &chunk).await.unwrap();
        // The two rows with order value 2 come first, larger value (789) first;
        // then the rows with order values 3 and 9.
        assert_eq!(
            agg.get_result(&state).await.unwrap(),
            Some(ListValue::from_iter([789, 456, 123, 321]).into())
        );
    }

    /// `string_agg` of column 0 with column 1 as the second argument, ordered
    /// by column 2 ascending, column 3 descending, then column 0 descending.
    #[tokio::test]
    async fn string_agg_with_order() {
        let chunk = StreamChunk::from_pretty(
            " T T i i
            + aaa _ 1 3
            + bbb _ 0 4
            + ccc _ 0 8
            + ddd _ 1 3",
        );
        let agg = build(&AggCall::from_pretty(
            "(string_agg:varchar $0:varchar $1:varchar orderby $2:asc $3:desc $0:desc)",
        ))
        .unwrap();
        let mut state = agg.create_state().unwrap();
        agg.update(&mut state, &chunk).await.unwrap();
        // `aaa` and `ddd` tie on both integer columns, so the final
        // `$0:desc` key decides their order (`ddd` before `aaa`).
        assert_eq!(
            agg.get_result(&state).await.unwrap(),
            Some("ccc_bbb_ddd_aaa".into())
        );
    }
}