risingwave_expr_impl/window_function/
aggregate.rs1use std::collections::BTreeSet;
16
17use educe::Educe;
18use futures_util::FutureExt;
19use risingwave_common::array::{DataChunk, Op, StreamChunk};
20use risingwave_common::types::{DataType, Datum};
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_common::{bail, must_match};
23use risingwave_common_estimate_size::{EstimateSize, KvSize};
24use risingwave_expr::Result;
25use risingwave_expr::aggregate::{
26 AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
27 PbAggKind, build_append_only,
28};
29use risingwave_expr::sig::FUNCTION_REGISTRY;
30use risingwave_expr::window_function::{
31 BoxedWindowState, FrameBounds, StateEvictHint, StateKey, StatePos, WindowFuncCall,
32 WindowFuncKind, WindowState,
33};
34use smallvec::SmallVec;
35
36use super::buffer::{RangeWindow, RowsWindow, SessionWindow, WindowBuffer, WindowImpl};
37
38type StateValue = SmallVec<[Datum; 2]>;
39
40struct AggregateState<W>
41where
42 W: WindowImpl<Key = StateKey, Value = StateValue>,
43{
44 agg_impl: AggImpl,
45 arg_data_types: Vec<DataType>,
46 ignore_nulls: bool,
47 buffer: WindowBuffer<W>,
48 buffer_heap_size: KvSize,
49}
50
51pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
52 if call.frame.bounds.validate().is_err() {
53 bail!("the window frame must be valid");
54 }
55 let agg_type = must_match!(&call.kind, WindowFuncKind::Aggregate(agg_type) => agg_type);
56 let arg_data_types = call.args.arg_types().to_vec();
57 let agg_call = AggCall {
58 agg_type: agg_type.clone(),
59 args: call.args.clone(),
60 return_type: call.return_type.clone(),
61 column_orders: Vec::new(), filter: None,
64 distinct: false,
66 direct_args: vec![],
67 };
68
69 let (agg_impl, enable_delta) = match agg_type {
70 AggType::Builtin(PbAggKind::FirstValue) => (AggImpl::Shortcut(Shortcut::FirstValue), false),
71 AggType::Builtin(PbAggKind::LastValue) => (AggImpl::Shortcut(Shortcut::LastValue), false),
72 AggType::Builtin(kind) => {
73 let agg_func_sig = FUNCTION_REGISTRY
74 .get(*kind, &arg_data_types, &call.return_type)
75 .expect("the agg func must exist");
76 let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
77 let (agg_impl, enable_delta) =
78 if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
79 let init_state = agg_func.create_state()?;
80 (AggImpl::Incremental(agg_func, init_state), true)
81 } else {
82 (AggImpl::Full(agg_func), false)
83 };
84 (agg_impl, enable_delta)
85 }
86 AggType::UserDefined(_) => {
87 (AggImpl::Full(build_append_only(&agg_call)?), false)
89 }
90 AggType::WrapScalar(_) => {
91 (AggImpl::Full(build_append_only(&agg_call)?), false)
94 }
95 };
96
97 let this = match &call.frame.bounds {
98 FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
99 agg_impl,
100 arg_data_types,
101 ignore_nulls: call.ignore_nulls,
102 buffer: WindowBuffer::<RowsWindow<StateKey, StateValue>>::new(
103 RowsWindow::new(frame_bounds.clone()),
104 call.frame.exclusion,
105 enable_delta,
106 ),
107 buffer_heap_size: KvSize::new(),
108 }) as BoxedWindowState,
109 FrameBounds::Range(frame_bounds) => Box::new(AggregateState {
110 agg_impl,
111 arg_data_types,
112 ignore_nulls: call.ignore_nulls,
113 buffer: WindowBuffer::<RangeWindow<StateValue>>::new(
114 RangeWindow::new(frame_bounds.clone()),
115 call.frame.exclusion,
116 enable_delta,
117 ),
118 buffer_heap_size: KvSize::new(),
119 }) as BoxedWindowState,
120 FrameBounds::Session(frame_bounds) => Box::new(AggregateState {
121 agg_impl,
122 arg_data_types,
123 ignore_nulls: call.ignore_nulls,
124 buffer: WindowBuffer::<SessionWindow<StateValue>>::new(
125 SessionWindow::new(frame_bounds.clone()),
126 call.frame.exclusion,
127 enable_delta,
128 ),
129 buffer_heap_size: KvSize::new(),
130 }) as BoxedWindowState,
131 };
132 Ok(this)
133}
134
135impl<W> AggregateState<W>
136where
137 W: WindowImpl<Key = StateKey, Value = StateValue>,
138{
139 fn slide_inner(&mut self) -> StateEvictHint {
140 let removed_keys: BTreeSet<_> = self
141 .buffer
142 .slide()
143 .map(|(k, v)| {
144 v.iter().for_each(|arg| {
145 self.buffer_heap_size.sub_val(arg);
146 });
147 self.buffer_heap_size.sub_val(&k);
148 k
149 })
150 .collect();
151 if removed_keys.is_empty() {
152 StateEvictHint::CannotEvict(
153 self.buffer
154 .smallest_key()
155 .expect("sliding without removing, must have some entry in the buffer")
156 .clone(),
157 )
158 } else {
159 StateEvictHint::CanEvict(removed_keys)
160 }
161 }
162}
163
164impl<W> WindowState for AggregateState<W>
165where
166 W: WindowImpl<Key = StateKey, Value = StateValue>,
167{
168 fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
169 args.iter().for_each(|arg| {
170 self.buffer_heap_size.add_val(arg);
171 });
172 self.buffer_heap_size.add_val(&key);
173 self.buffer.append(key, args);
174 }
175
176 fn curr_window(&self) -> StatePos<'_> {
177 let window = self.buffer.curr_window();
178 StatePos {
179 key: window.key,
180 is_ready: window.following_saturated,
181 }
182 }
183
184 fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
185 let output = match self.agg_impl {
186 AggImpl::Full(ref agg_func) => {
187 let wrapper = AggregatorWrapper {
188 agg_func: agg_func.as_ref(),
189 arg_data_types: &self.arg_data_types,
190 };
191 wrapper.aggregate(self.buffer.curr_window_values())
192 }
193 AggImpl::Incremental(ref agg_func, ref mut state) => {
194 let wrapper = AggregatorWrapper {
195 agg_func: agg_func.as_ref(),
196 arg_data_types: &self.arg_data_types,
197 };
198 wrapper.update(state, self.buffer.consume_curr_window_values_delta())
199 }
200 AggImpl::Shortcut(shortcut) => match shortcut {
201 Shortcut::FirstValue => Ok(if !self.ignore_nulls {
202 self.buffer
204 .curr_window_values()
205 .next()
206 .and_then(|args| args[0].clone())
207 } else {
208 self.buffer
210 .curr_window_values()
211 .find(|args| args[0].is_some())
212 .and_then(|args| args[0].clone())
213 }),
214 Shortcut::LastValue => Ok(if !self.ignore_nulls {
215 self.buffer
216 .curr_window_values()
217 .next_back()
218 .and_then(|args| args[0].clone())
219 } else {
220 self.buffer
221 .curr_window_values()
222 .rev()
223 .find(|args| args[0].is_some())
224 .and_then(|args| args[0].clone())
225 }),
226 },
227 }?;
228 let evict_hint = self.slide_inner();
229 Ok((output, evict_hint))
230 }
231
232 fn slide_no_output(&mut self) -> Result<StateEvictHint> {
233 match self.agg_impl {
234 AggImpl::Full(..) => {}
235 AggImpl::Incremental(ref agg_func, ref mut state) => {
236 let wrapper = AggregatorWrapper {
239 agg_func: agg_func.as_ref(),
240 arg_data_types: &self.arg_data_types,
241 };
242 wrapper.update(state, self.buffer.consume_curr_window_values_delta())?;
243 }
244 AggImpl::Shortcut(..) => {}
245 };
246 Ok(self.slide_inner())
247 }
248}
249
250impl<W> EstimateSize for AggregateState<W>
251where
252 W: WindowImpl<Key = StateKey, Value = StateValue>,
253{
254 fn estimated_heap_size(&self) -> usize {
255 self.arg_data_types.estimated_heap_size() + self.buffer_heap_size.size()
258 }
259}
260
261#[derive(Educe)]
262#[educe(Debug)]
263enum AggImpl {
264 Incremental(#[educe(Debug(ignore))] BoxedAggregateFunction, AggImplState),
265 Full(#[educe(Debug(ignore))] BoxedAggregateFunction),
266 Shortcut(Shortcut),
267}
268
/// Aggregates whose result is read directly from the window buffer.
#[derive(Clone, Copy, Debug)]
enum Shortcut {
    /// First argument of the first row in the current window.
    FirstValue,
    /// First argument of the last row in the current window.
    LastValue,
}
274
275struct AggregatorWrapper<'a> {
276 agg_func: &'a dyn AggregateFunction,
277 arg_data_types: &'a [DataType],
278}
279
280impl AggregatorWrapper<'_> {
281 fn aggregate<V>(&self, values: impl IntoIterator<Item = V>) -> Result<Datum>
282 where
283 V: AsRef<[Datum]>,
284 {
285 let mut state = self.agg_func.create_state()?;
286 self.update(
287 &mut state,
288 values.into_iter().map(|args| (Op::Insert, args)),
289 )
290 }
291
292 fn update<V>(
293 &self,
294 state: &mut AggImplState,
295 delta: impl IntoIterator<Item = (Op, V)>,
296 ) -> Result<Datum>
297 where
298 V: AsRef<[Datum]>,
299 {
300 let mut args_builders = self
301 .arg_data_types
302 .iter()
303 .map(|data_type| data_type.create_array_builder(0 ))
304 .collect::<Vec<_>>();
305 let mut ops = Vec::new();
306 let mut n_rows = 0;
307 for (op, value) in delta {
308 n_rows += 1;
309 ops.push(op);
310 for (builder, datum) in args_builders.iter_mut().zip_eq_fast(value.as_ref()) {
311 builder.append(datum);
312 }
313 }
314 let columns = args_builders
315 .into_iter()
316 .map(|builder| builder.finish().into())
317 .collect::<Vec<_>>();
318 let chunk = StreamChunk::from_parts(ops, DataChunk::new(columns, n_rows));
319
320 self.agg_func
321 .update(state, &chunk)
322 .now_or_never()
323 .expect("we don't support UDAF currently, so the function should return immediately")?;
324 self.agg_func
325 .get_result(state)
326 .now_or_never()
327 .expect("we don't support UDAF currently, so the function should return immediately")
328 }
329}
330
/// Unit tests for this module; none are defined in this section.
#[cfg(test)]
mod tests {}