risingwave_stream/executor/aggregate/
agg_state.rs1use risingwave_common::array::StreamChunk;
16use risingwave_common::bitmap::Bitmap;
17use risingwave_common::catalog::Schema;
18use risingwave_common::must_match;
19use risingwave_common::types::Datum;
20use risingwave_common::util::sort_util::ColumnOrder;
21use risingwave_common_estimate_size::EstimateSize;
22use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction};
23use risingwave_pb::stream_plan::PbAggNodeVersion;
24use risingwave_storage::StateStore;
25
26use super::minput::MaterializedInputState;
27use crate::common::StateTableColumnMapping;
28use crate::common::table::state_table::StateTable;
29use crate::executor::aggregate::agg_group::{AggStateCacheStats, GroupKey};
30use crate::executor::{PkIndices, StreamExecutorResult};
31
32pub enum AggStateStorage<S: StateStore> {
34 Value,
36
37 MaterializedInput {
41 table: StateTable<S>,
42 mapping: StateTableColumnMapping,
43 order_columns: Vec<ColumnOrder>,
44 },
45}
46
/// In-memory state of a single aggregation call.
///
/// The variant is chosen by [`AggState::create`] from the matching [`AggStateStorage`] variant.
pub enum AggState {
    /// State held directly as an [`AggregateState`]. Built for [`AggStateStorage::Value`].
    Value(AggregateState),

    /// State backed by a materialized-input state table. Built for
    /// [`AggStateStorage::MaterializedInput`]. It is boxed, presumably to keep `AggState`
    /// itself small. NOTE(review): confirm the size motivation.
    MaterializedInput(Box<MaterializedInputState>),
}
57
58impl EstimateSize for AggState {
59 fn estimated_heap_size(&self) -> usize {
60 match self {
61 Self::Value(state) => state.estimated_heap_size(),
62 Self::MaterializedInput(state) => state.estimated_size(),
63 }
64 }
65}
66
67impl AggState {
68 #[allow(clippy::too_many_arguments)]
70 pub fn create(
71 version: PbAggNodeVersion,
72 agg_call: &AggCall,
73 agg_func: &BoxedAggregateFunction,
74 storage: &AggStateStorage<impl StateStore>,
75 encoded_state: Option<&Datum>,
76 pk_indices: &PkIndices,
77 extreme_cache_size: usize,
78 input_schema: &Schema,
79 ) -> StreamExecutorResult<Self> {
80 Ok(match storage {
81 AggStateStorage::Value => {
82 let state = match encoded_state {
83 Some(encoded) => agg_func.decode_state(encoded.clone())?,
84 None => agg_func.create_state()?,
85 };
86 Self::Value(state)
87 }
88 AggStateStorage::MaterializedInput {
89 mapping,
90 order_columns,
91 ..
92 } => Self::MaterializedInput(Box::new(MaterializedInputState::new(
93 version,
94 agg_call,
95 pk_indices,
96 order_columns,
97 mapping,
98 extreme_cache_size,
99 input_schema,
100 )?)),
101 })
102 }
103
104 pub async fn apply_chunk(
106 &mut self,
107 chunk: &StreamChunk,
108 call: &AggCall,
109 func: &BoxedAggregateFunction,
110 visibility: Bitmap,
111 ) -> StreamExecutorResult<()> {
112 match self {
113 Self::Value(state) => {
114 let chunk = chunk.project_with_vis(call.args.val_indices(), visibility);
115 func.update(state, &chunk).await?;
116 Ok(())
117 }
118 Self::MaterializedInput(state) => {
119 let chunk = chunk.clone_with_vis(visibility);
121 state.apply_chunk(&chunk)
122 }
123 }
124 }
125
126 pub async fn get_output(
128 &mut self,
129 storage: &AggStateStorage<impl StateStore>,
130 func: &BoxedAggregateFunction,
131 group_key: Option<&GroupKey>,
132 ) -> StreamExecutorResult<(Datum, AggStateCacheStats)> {
133 match self {
134 Self::Value(state) => {
135 debug_assert!(matches!(storage, AggStateStorage::Value));
136 Ok((func.get_result(state).await?, AggStateCacheStats::default()))
137 }
138 Self::MaterializedInput(state) => {
139 let state_table = must_match!(
140 storage,
141 AggStateStorage::MaterializedInput { table, .. } => table
142 );
143 state.get_output(state_table, group_key, func).await
144 }
145 }
146 }
147
148 pub fn reset(&mut self, func: &BoxedAggregateFunction) -> StreamExecutorResult<()> {
150 if let Self::Value(state) = self {
151 *state = func.create_state()?;
153 }
154 Ok(())
155 }
156}