use std::collections::HashMap;
use risingwave_common::array::stream_record::Record;
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::aggregate::{build_retractable, AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use super::agg_common::{AggExecutorArgs, SimpleAggExecutorExtraArgs};
use super::aggregation::{
agg_call_filter_res, iter_table_storage, AggStateStorage, AlwaysOutput, DistinctDeduplicater,
};
use crate::executor::aggregation::AggGroup;
use crate::executor::prelude::*;
/// Streaming aggregation executor that maintains exactly one [`AggGroup`] for its
/// whole input stream.
///
/// Input chunks are applied to the aggregation state as they arrive; the resulting
/// change (if any) is computed and emitted when a barrier is received, right before
/// the barrier itself is forwarded.
pub struct SimpleAggExecutor<S: StateStore> {
/// Upstream executor whose message stream is consumed by `execute_inner`.
input: Executor,
/// Configuration, state storages and state tables; kept apart from `input` so that
/// `execute_inner` can turn the input into a stream and still borrow the rest mutably.
inner: ExecutorInner<S>,
}
/// Everything `SimpleAggExecutor` owns besides its input: configuration captured at
/// construction time plus the state storages and state tables used while executing.
struct ExecutorInner<S: StateStore> {
/// Aggregation node version; forwarded to `AggGroup::create`.
version: PbAggNodeVersion,
/// Actor context; handed to `DistinctDeduplicater::new`.
actor_ctx: ActorContextRef,
/// Info of this executor; its schema supplies the data types of the output chunk.
info: ExecutorInfo,
/// Pk indices copied from the input executor's info; forwarded to `AggGroup::create`.
input_pk_indices: Vec<usize>,
/// Schema copied from the input executor's info; forwarded to `AggGroup::create`.
input_schema: Schema,
/// The aggregate calls evaluated by this executor.
agg_calls: Vec<AggCall>,
/// Retractable aggregate functions, built one per entry of `agg_calls`, in the same order.
agg_funcs: Vec<BoxedAggregateFunction>,
/// Forwarded to `AggGroup::create`.
/// NOTE(review): presumably the position of the row-count call within `agg_calls` — confirm.
row_count_index: usize,
/// State storages, zipped one-to-one with the per-call visibilities in `apply_chunk`.
storages: Vec<AggStateStorage<S>>,
/// Table that receives the encoded aggregation states in `flush_data`.
intermediate_state_table: StateTable<S>,
/// State tables used by the distinct deduplicater.
/// NOTE(review): the `usize` key is presumably a distinct column index — confirm.
distinct_dedup_tables: HashMap<usize, StateTable<S>>,
/// Epoch handle cloned into `DistinctDeduplicater::new`.
watermark_epoch: AtomicU64Ref,
/// Cache size forwarded to `AggGroup::create`.
extreme_cache_size: usize,
/// When `false`, `flush_data` drops an update whose old and new rows are equal; when
/// `true`, every change produced by the agg group is emitted.
must_output_per_barrier: bool,
}
impl<S: StateStore> ExecutorInner<S> {
    /// Iterates mutably over every state table owned by the executor.
    ///
    /// The order is: the tables yielded by `iter_table_storage` for `storages`, then the
    /// distinct dedup tables, and finally the intermediate state table.
    fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
        // Split `self` into disjoint field borrows so the three sources can be chained.
        let Self {
            storages,
            distinct_dedup_tables,
            intermediate_state_table,
            ..
        } = self;

        let storage_tables = iter_table_storage(storages);
        let dedup_tables = distinct_dedup_tables.values_mut();
        let intermediate = std::iter::once(intermediate_state_table);

        storage_tables.chain(dedup_tables).chain(intermediate)
    }
}
/// Mutable state that exists only while the executor is running; it is created inside
/// `execute_inner` after the state tables have been initialized with the first epoch.
struct ExecutionVars<S: StateStore> {
/// The single aggregation group that all input chunks are applied to.
agg_group: AggGroup<S, AlwaysOutput>,
/// Deduplicater that adjusts the per-call visibilities of each chunk in `apply_chunk`.
distinct_dedup: DistinctDeduplicater<S>,
/// Set by `apply_chunk` once a non-empty chunk has been applied; reset by `flush_data`
/// after the state tables are committed.
state_changed: bool,
}
impl<S: StateStore> Execute for SimpleAggExecutor<S> {
    /// Consumes the boxed executor and returns its output as a boxed message stream.
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        let stream = self.execute_inner();
        stream.boxed()
    }
}
impl<S: StateStore> SimpleAggExecutor<S> {
pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
let input_info = args.input.info().clone();
Ok(Self {
input: args.input,
inner: ExecutorInner {
version: args.version,
actor_ctx: args.actor_ctx,
info: args.info,
input_pk_indices: input_info.pk_indices,
input_schema: input_info.schema,
agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
agg_calls: args.agg_calls,
row_count_index: args.row_count_index,
storages: args.storages,
intermediate_state_table: args.intermediate_state_table,
distinct_dedup_tables: args.distinct_dedup_tables,
watermark_epoch: args.watermark_epoch,
extreme_cache_size: args.extreme_cache_size,
must_output_per_barrier: args.extra.must_output_per_barrier,
},
})
}
/// Applies one input chunk to the aggregation state.
///
/// Steps: compute the filter visibility of every agg call, run the visibilities through
/// the distinct deduplicater, write the projected chunk into every materialized-input
/// storage, and finally feed the chunk to the agg group. An empty chunk is a no-op and
/// leaves `state_changed` untouched.
async fn apply_chunk(
    this: &mut ExecutorInner<S>,
    vars: &mut ExecutionVars<S>,
    chunk: StreamChunk,
) -> StreamExecutorResult<()> {
    if chunk.cardinality() == 0 {
        return Ok(());
    }

    // One visibility per agg call, in call order.
    let mut filtered = Vec::with_capacity(this.agg_calls.len());
    for call in this.agg_calls.iter() {
        filtered.push(agg_call_filter_res(call, &chunk).await?);
    }

    // Let the distinct deduplicater adjust the visibilities.
    let deduped = vars
        .distinct_dedup
        .dedup_chunk(
            chunk.ops(),
            chunk.columns(),
            filtered,
            &mut this.distinct_dedup_tables,
            None,
        )
        .await?;

    // Only materialized-input storages keep a copy of the (projected) input rows.
    let storage_with_vis = this.storages.iter_mut().zip_eq_fast(deduped.iter());
    for (storage, vis) in storage_with_vis {
        let AggStateStorage::MaterializedInput { table, mapping, .. } = storage else {
            continue;
        };
        let projected = chunk.project_with_vis(mapping.upstream_columns(), vis.clone());
        table.write_chunk(projected);
    }

    vars.agg_group
        .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, deduped)
        .await?;

    // Remember that the state must be persisted at the next barrier.
    vars.state_changed = true;
    Ok(())
}
/// Persists the aggregation state for `epoch` and returns the chunk to emit, if any.
///
/// When the state changed since the last flush, or the agg group reports itself as
/// uninitialized, the distinct dedup state is flushed and the encoded agg states are
/// written to the intermediate state table. All state tables are committed in either case.
///
/// Unless `must_output_per_barrier` is set, an update whose old and new rows are equal
/// yields `None` instead of a chunk.
async fn flush_data(
    this: &mut ExecutorInner<S>,
    vars: &mut ExecutionVars<S>,
    epoch: EpochPair,
) -> StreamExecutorResult<Option<StreamChunk>> {
    let needs_persist = vars.state_changed || vars.agg_group.is_uninitialized();
    if needs_persist {
        vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
        let encoded_states = vars.agg_group.encode_states(&this.agg_funcs)?;
        this.intermediate_state_table
            .update_without_old_value(encoded_states);
    }

    let (change, _stats) = vars
        .agg_group
        .build_change(&this.storages, &this.agg_funcs)
        .await?;

    // Keep the change if output is forced, or if it is anything but a no-op update.
    let output = change
        .filter(|change| {
            this.must_output_per_barrier
                || !matches!(
                    change,
                    Record::Update { old_row, new_row } if old_row == new_row
                )
        })
        .map(|change| change.to_stream_chunk(&this.info.schema.data_types()));

    let commits = this.all_state_tables_mut().map(|table| table.commit(epoch));
    futures::future::try_join_all(commits).await?;

    // Only cleared once every table has been committed successfully.
    vars.state_changed = false;
    Ok(output)
}
/// Calls `try_flush` on every state table and waits for all of them, failing on the
/// first error.
async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
    let flushes = this.all_state_tables_mut().map(|table| table.try_flush());
    futures::future::try_join_all(flushes).await?;
    Ok(())
}
/// Main loop of the executor, exposed as a stream of messages.
///
/// The first barrier is forwarded as soon as it is received, then every state table is
/// initialized with its epoch and the runtime state is created. Afterwards:
/// - chunks are applied to the agg group, followed by a `try_flush` of the state tables;
/// - barriers trigger `flush_data`; the resulting chunk (if any) is yielded right before
///   the barrier;
/// - watermarks are dropped.
#[try_stream(ok = Message, error = StreamExecutorError)]
async fn execute_inner(self) {
    let Self {
        input: upstream,
        inner: mut this,
    } = self;
    let mut input = upstream.execute();

    let first_barrier = expect_first_barrier(&mut input).await?;
    let first_epoch = first_barrier.epoch;
    yield Message::Barrier(first_barrier);

    for table in this.all_state_tables_mut() {
        table.init_epoch(first_epoch).await?;
    }

    let distinct_dedup = DistinctDeduplicater::new(
        &this.agg_calls,
        this.watermark_epoch.clone(),
        &this.distinct_dedup_tables,
        &this.actor_ctx,
    );
    // NOTE(review): the `None` second argument is presumably "no group key" — confirm
    // against `AggGroup::create`.
    let agg_group = AggGroup::create(
        this.version,
        None,
        &this.agg_calls,
        &this.agg_funcs,
        &this.storages,
        &this.intermediate_state_table,
        &this.input_pk_indices,
        this.row_count_index,
        this.extreme_cache_size,
        &this.input_schema,
    )
    .await?;
    let mut vars = ExecutionVars {
        agg_group,
        distinct_dedup,
        state_changed: false,
    };

    #[for_await]
    for msg in input {
        match msg? {
            Message::Chunk(chunk) => {
                Self::apply_chunk(&mut this, &mut vars, chunk).await?;
                Self::try_flush_data(&mut this).await?;
            }
            Message::Barrier(barrier) => {
                let output = Self::flush_data(&mut this, &mut vars, barrier.epoch).await?;
                if let Some(chunk) = output {
                    yield Message::Chunk(chunk);
                }
                yield Message::Barrier(barrier);
            }
            // Watermarks are neither used nor forwarded.
            Message::Watermark(_) => {}
        }
    }
}
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use risingwave_common::array::stream_chunk::StreamChunkTestExt;
use risingwave_common::catalog::Field;
use risingwave_common::types::*;
use risingwave_common::util::epoch::test_epoch;
use risingwave_storage::memory::MemoryStateStore;
use super::*;
use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
use crate::executor::test_utils::*;
/// Runs the generic simple-aggregation scenario against the in-memory state store.
#[tokio::test]
async fn test_simple_aggregation_in_memory() {
    let store = MemoryStateStore::new();
    test_simple_aggregation(store).await;
}
/// End-to-end scenario for the simple agg executor on top of `store`: two chunks
/// separated by barriers are aggregated with `count`, two `sum`s and a `min`, and the
/// emitted chunks are checked after each barrier.
async fn test_simple_aggregation<S: StateStore>(store: S) {
// Three Int64 input columns.
let schema = Schema {
fields: vec![
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
],
};
let (mut tx, source) = MockSource::channel();
let source = source.into_executor(schema, vec![2]);
// Input script: init barrier, an empty epoch, three inserts, then a mixed chunk.
tx.push_barrier(test_epoch(1), false);
tx.push_barrier(test_epoch(2), false);
tx.push_chunk(StreamChunk::from_pretty(
" I I I
+ 100 200 1001
+ 10 14 1002
+ 4 300 1003",
));
tx.push_barrier(test_epoch(3), false);
// NOTE(review): the row tagged `D` is presumably invisible; the expected result below
// (count 2, min 10) is consistent with that row not being applied — confirm against
// `from_pretty`.
tx.push_chunk(StreamChunk::from_pretty(
" I I I
- 100 200 1001
- 10 14 1002 D
- 4 300 1003
+ 104 500 1004",
));
tx.push_barrier(test_epoch(4), false);
// count(*), sum($0), sum($1), min($0).
let agg_calls = vec![
AggCall::from_pretty("(count:int8)"),
AggCall::from_pretty("(sum:int8 $0:int8)"),
AggCall::from_pretty("(sum:int8 $1:int8)"),
AggCall::from_pretty("(min:int8 $0:int8)"),
];
// NOTE(review): the meaning of the positional arguments is defined by the test helper;
// the trailing `false` is presumably `must_output_per_barrier` — confirm.
let simple_agg = new_boxed_simple_agg_executor(
ActorContext::for_test(123),
store,
source,
false,
agg_calls,
0,
vec![2],
1,
false,
)
.await;
let mut simple_agg = simple_agg.execute();
// The init barrier is forwarded first.
simple_agg.next().await.unwrap().unwrap();
// Second barrier, no data yet: the initial row (count 0, NULL aggregates) is emitted.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
+ 0 . . . "
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
// After the three inserts: count 3, sums 114 and 514, min 4.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
U- 0 . . .
U+ 3 114 514 4"
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
// After the mixed chunk: count 2, sums unchanged, min becomes 10.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
U- 3 114 514 4
U+ 2 114 514 10"
)
);
}
/// Feeds a chunk whose insert and delete cancel each other out and expects an update
/// chunk with identical old and new rows after every subsequent barrier, including a
/// barrier that follows no data at all.
#[tokio::test]
async fn test_simple_aggregation_always_output_per_epoch() {
let store = MemoryStateStore::new();
// Three Int64 input columns.
let schema = Schema {
fields: vec![
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
],
};
let (mut tx, source) = MockSource::channel();
let source = source.into_executor(schema, vec![2]);
// Input script: init barrier, an empty epoch, a self-cancelling chunk, two more barriers.
tx.push_barrier(test_epoch(1), false);
tx.push_barrier(test_epoch(2), false);
tx.push_chunk(StreamChunk::from_pretty(
" I I I
+ 100 200 1001
- 100 200 1001",
));
tx.push_barrier(test_epoch(3), false);
tx.push_barrier(test_epoch(4), false);
// count(*), sum($0), sum($1), min($0).
let agg_calls = vec![
AggCall::from_pretty("(count:int8)"),
AggCall::from_pretty("(sum:int8 $0:int8)"),
AggCall::from_pretty("(sum:int8 $1:int8)"),
AggCall::from_pretty("(min:int8 $0:int8)"),
];
// NOTE(review): the trailing `true` is presumably `must_output_per_barrier` — confirm
// against the test helper's signature.
let simple_agg = new_boxed_simple_agg_executor(
ActorContext::for_test(123),
store,
source,
false,
agg_calls,
0,
vec![2],
1,
true,
)
.await;
let mut simple_agg = simple_agg.execute();
// The init barrier is forwarded first.
simple_agg.next().await.unwrap().unwrap();
// Second barrier, no data yet: the initial row is emitted.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
+ 0 . . . "
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
// The chunk cancelled itself out, yet an update with equal rows is still emitted.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
U- 0 . . .
U+ 0 . . ."
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
// No data since the previous barrier: the same no-op update is emitted again.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
U- 0 . . .
U+ 0 . . ."
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
}
/// Same input as the always-output scenario, but expects that after the initial row no
/// further chunk is emitted: only barriers follow, because the aggregation result never
/// changes.
#[tokio::test]
async fn test_simple_aggregation_omit_noop_update() {
let store = MemoryStateStore::new();
// Three Int64 input columns.
let schema = Schema {
fields: vec![
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
Field::unnamed(DataType::Int64),
],
};
let (mut tx, source) = MockSource::channel();
let source = source.into_executor(schema, vec![2]);
// Input script: init barrier, an empty epoch, a self-cancelling chunk, two more barriers.
tx.push_barrier(test_epoch(1), false);
tx.push_barrier(test_epoch(2), false);
tx.push_chunk(StreamChunk::from_pretty(
" I I I
+ 100 200 1001
- 100 200 1001",
));
tx.push_barrier(test_epoch(3), false);
tx.push_barrier(test_epoch(4), false);
// count(*), sum($0), sum($1), min($0).
let agg_calls = vec![
AggCall::from_pretty("(count:int8)"),
AggCall::from_pretty("(sum:int8 $0:int8)"),
AggCall::from_pretty("(sum:int8 $1:int8)"),
AggCall::from_pretty("(min:int8 $0:int8)"),
];
// NOTE(review): the trailing `false` is presumably `must_output_per_barrier` — confirm
// against the test helper's signature.
let simple_agg = new_boxed_simple_agg_executor(
ActorContext::for_test(123),
store,
source,
false,
agg_calls,
0,
vec![2],
1,
false,
)
.await;
let mut simple_agg = simple_agg.execute();
// The init barrier is forwarded first.
simple_agg.next().await.unwrap().unwrap();
// Second barrier, no data yet: the initial row is emitted.
let msg = simple_agg.next().await.unwrap().unwrap();
assert_eq!(
*msg.as_chunk().unwrap(),
StreamChunk::from_pretty(
" I I I I
+ 0 . . . "
)
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
// From here on the no-op updates are omitted, so only barriers come out.
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
assert_matches!(
simple_agg.next().await.unwrap().unwrap(),
Message::Barrier { .. }
);
}
}