risingwave_expr_impl/window_function/aggregate.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeSet;
16
17use educe::Educe;
18use futures_util::FutureExt;
19use risingwave_common::array::{DataChunk, Op, StreamChunk};
20use risingwave_common::types::{DataType, Datum};
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_common::{bail, must_match};
23use risingwave_common_estimate_size::{EstimateSize, KvSize};
24use risingwave_expr::Result;
25use risingwave_expr::aggregate::{
26    AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
27    PbAggKind, build_append_only,
28};
29use risingwave_expr::sig::FUNCTION_REGISTRY;
30use risingwave_expr::window_function::{
31    BoxedWindowState, FrameBounds, StateEvictHint, StateKey, StatePos, WindowFuncCall,
32    WindowFuncKind, WindowState,
33};
34use smallvec::SmallVec;
35
36use super::buffer::{RangeWindow, RowsWindow, SessionWindow, WindowBuffer, WindowImpl};
37
38type StateValue = SmallVec<[Datum; 2]>;
39
40struct AggregateState<W>
41where
42    W: WindowImpl<Key = StateKey, Value = StateValue>,
43{
44    agg_impl: AggImpl,
45    arg_data_types: Vec<DataType>,
46    ignore_nulls: bool,
47    buffer: WindowBuffer<W>,
48    buffer_heap_size: KvSize,
49}
50
51pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
52    if call.frame.bounds.validate().is_err() {
53        bail!("the window frame must be valid");
54    }
55    let agg_type = must_match!(&call.kind, WindowFuncKind::Aggregate(agg_type) => agg_type);
56    let arg_data_types = call.args.arg_types().to_vec();
57    let agg_call = AggCall {
58        agg_type: agg_type.clone(),
59        args: call.args.clone(),
60        return_type: call.return_type.clone(),
61        column_orders: Vec::new(), // the input is already sorted
62        // TODO(rc): support filter on window function call
63        filter: None,
64        // TODO(rc): support distinct on window function call? PG doesn't support it either.
65        distinct: false,
66        direct_args: vec![],
67    };
68
69    let (agg_impl, enable_delta) = match agg_type {
70        AggType::Builtin(PbAggKind::FirstValue) => (AggImpl::Shortcut(Shortcut::FirstValue), false),
71        AggType::Builtin(PbAggKind::LastValue) => (AggImpl::Shortcut(Shortcut::LastValue), false),
72        AggType::Builtin(kind) => {
73            let agg_func_sig = FUNCTION_REGISTRY
74                .get(*kind, &arg_data_types, &call.return_type)
75                .expect("the agg func must exist");
76            let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
77            let (agg_impl, enable_delta) =
78                if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
79                    let init_state = agg_func.create_state()?;
80                    (AggImpl::Incremental(agg_func, init_state), true)
81                } else {
82                    (AggImpl::Full(agg_func), false)
83                };
84            (agg_impl, enable_delta)
85        }
86        AggType::UserDefined(_) => {
87            // TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation
88            (AggImpl::Full(build_append_only(&agg_call)?), false)
89        }
90        AggType::WrapScalar(_) => {
91            // we have to feed the wrapped scalar function with all the rows in the window,
92            // instead of doing incremental aggregation
93            (AggImpl::Full(build_append_only(&agg_call)?), false)
94        }
95    };
96
97    let this = match &call.frame.bounds {
98        FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
99            agg_impl,
100            arg_data_types,
101            ignore_nulls: call.ignore_nulls,
102            buffer: WindowBuffer::<RowsWindow<StateKey, StateValue>>::new(
103                RowsWindow::new(frame_bounds.clone()),
104                call.frame.exclusion,
105                enable_delta,
106            ),
107            buffer_heap_size: KvSize::new(),
108        }) as BoxedWindowState,
109        FrameBounds::Range(frame_bounds) => Box::new(AggregateState {
110            agg_impl,
111            arg_data_types,
112            ignore_nulls: call.ignore_nulls,
113            buffer: WindowBuffer::<RangeWindow<StateValue>>::new(
114                RangeWindow::new(frame_bounds.clone()),
115                call.frame.exclusion,
116                enable_delta,
117            ),
118            buffer_heap_size: KvSize::new(),
119        }) as BoxedWindowState,
120        FrameBounds::Session(frame_bounds) => Box::new(AggregateState {
121            agg_impl,
122            arg_data_types,
123            ignore_nulls: call.ignore_nulls,
124            buffer: WindowBuffer::<SessionWindow<StateValue>>::new(
125                SessionWindow::new(frame_bounds.clone()),
126                call.frame.exclusion,
127                enable_delta,
128            ),
129            buffer_heap_size: KvSize::new(),
130        }) as BoxedWindowState,
131    };
132    Ok(this)
133}
134
135impl<W> AggregateState<W>
136where
137    W: WindowImpl<Key = StateKey, Value = StateValue>,
138{
139    fn slide_inner(&mut self) -> StateEvictHint {
140        let removed_keys: BTreeSet<_> = self
141            .buffer
142            .slide()
143            .map(|(k, v)| {
144                v.iter().for_each(|arg| {
145                    self.buffer_heap_size.sub_val(arg);
146                });
147                self.buffer_heap_size.sub_val(&k);
148                k
149            })
150            .collect();
151        if removed_keys.is_empty() {
152            StateEvictHint::CannotEvict(
153                self.buffer
154                    .smallest_key()
155                    .expect("sliding without removing, must have some entry in the buffer")
156                    .clone(),
157            )
158        } else {
159            StateEvictHint::CanEvict(removed_keys)
160        }
161    }
162}
163
164impl<W> WindowState for AggregateState<W>
165where
166    W: WindowImpl<Key = StateKey, Value = StateValue>,
167{
168    fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
169        args.iter().for_each(|arg| {
170            self.buffer_heap_size.add_val(arg);
171        });
172        self.buffer_heap_size.add_val(&key);
173        self.buffer.append(key, args);
174    }
175
176    fn curr_window(&self) -> StatePos<'_> {
177        let window = self.buffer.curr_window();
178        StatePos {
179            key: window.key,
180            is_ready: window.following_saturated,
181        }
182    }
183
184    fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
185        let output = match self.agg_impl {
186            AggImpl::Full(ref agg_func) => {
187                let wrapper = AggregatorWrapper {
188                    agg_func: agg_func.as_ref(),
189                    arg_data_types: &self.arg_data_types,
190                };
191                wrapper.aggregate(self.buffer.curr_window_values())
192            }
193            AggImpl::Incremental(ref agg_func, ref mut state) => {
194                let wrapper = AggregatorWrapper {
195                    agg_func: agg_func.as_ref(),
196                    arg_data_types: &self.arg_data_types,
197                };
198                wrapper.update(state, self.buffer.consume_curr_window_values_delta())
199            }
200            AggImpl::Shortcut(shortcut) => match shortcut {
201                Shortcut::FirstValue => Ok(if !self.ignore_nulls {
202                    // no `IGNORE NULLS`
203                    self.buffer
204                        .curr_window_values()
205                        .next()
206                        .and_then(|args| args[0].clone())
207                } else {
208                    // filter out NULLs
209                    self.buffer
210                        .curr_window_values()
211                        .find(|args| args[0].is_some())
212                        .and_then(|args| args[0].clone())
213                }),
214                Shortcut::LastValue => Ok(if !self.ignore_nulls {
215                    self.buffer
216                        .curr_window_values()
217                        .next_back()
218                        .and_then(|args| args[0].clone())
219                } else {
220                    self.buffer
221                        .curr_window_values()
222                        .rev()
223                        .find(|args| args[0].is_some())
224                        .and_then(|args| args[0].clone())
225                }),
226            },
227        }?;
228        let evict_hint = self.slide_inner();
229        Ok((output, evict_hint))
230    }
231
232    fn slide_no_output(&mut self) -> Result<StateEvictHint> {
233        match self.agg_impl {
234            AggImpl::Full(..) => {}
235            AggImpl::Incremental(ref agg_func, ref mut state) => {
236                // for incremental agg, we need to update the state even if the caller doesn't need
237                // the output
238                let wrapper = AggregatorWrapper {
239                    agg_func: agg_func.as_ref(),
240                    arg_data_types: &self.arg_data_types,
241                };
242                wrapper.update(state, self.buffer.consume_curr_window_values_delta())?;
243            }
244            AggImpl::Shortcut(..) => {}
245        };
246        Ok(self.slide_inner())
247    }
248}
249
250impl<W> EstimateSize for AggregateState<W>
251where
252    W: WindowImpl<Key = StateKey, Value = StateValue>,
253{
254    fn estimated_heap_size(&self) -> usize {
255        // estimate `VecDeque` of `StreamWindowBuffer` internal size
256        // https://github.com/risingwavelabs/risingwave/issues/9713
257        self.arg_data_types.estimated_heap_size() + self.buffer_heap_size.size()
258    }
259}
260
261#[derive(Educe)]
262#[educe(Debug)]
263enum AggImpl {
264    Incremental(#[educe(Debug(ignore))] BoxedAggregateFunction, AggImplState),
265    Full(#[educe(Debug(ignore))] BoxedAggregateFunction),
266    Shortcut(Shortcut),
267}
268
269#[derive(Debug, Clone, Copy)]
270enum Shortcut {
271    FirstValue,
272    LastValue,
273}
274
275struct AggregatorWrapper<'a> {
276    agg_func: &'a dyn AggregateFunction,
277    arg_data_types: &'a [DataType],
278}
279
280impl AggregatorWrapper<'_> {
281    fn aggregate<V>(&self, values: impl IntoIterator<Item = V>) -> Result<Datum>
282    where
283        V: AsRef<[Datum]>,
284    {
285        let mut state = self.agg_func.create_state()?;
286        self.update(
287            &mut state,
288            values.into_iter().map(|args| (Op::Insert, args)),
289        )
290    }
291
292    fn update<V>(
293        &self,
294        state: &mut AggImplState,
295        delta: impl IntoIterator<Item = (Op, V)>,
296    ) -> Result<Datum>
297    where
298        V: AsRef<[Datum]>,
299    {
300        let mut args_builders = self
301            .arg_data_types
302            .iter()
303            .map(|data_type| data_type.create_array_builder(0 /* bad! */))
304            .collect::<Vec<_>>();
305        let mut ops = Vec::new();
306        let mut n_rows = 0;
307        for (op, value) in delta {
308            n_rows += 1;
309            ops.push(op);
310            for (builder, datum) in args_builders.iter_mut().zip_eq_fast(value.as_ref()) {
311                builder.append(datum);
312            }
313        }
314        let columns = args_builders
315            .into_iter()
316            .map(|builder| builder.finish().into())
317            .collect::<Vec<_>>();
318        let chunk = StreamChunk::from_parts(ops, DataChunk::new(columns, n_rows));
319
320        self.agg_func
321            .update(state, &chunk)
322            .now_or_never()
323            .expect("we don't support UDAF currently, so the function should return immediately")?;
324        self.agg_func
325            .get_result(state)
326            .now_or_never()
327            .expect("we don't support UDAF currently, so the function should return immediately")
328    }
329}
330
331#[cfg(test)]
332mod tests {
333    // TODO(rc): need to add some unit tests
334}