use std::collections::HashMap;
use std::marker::PhantomData;
use futures::stream;
use itertools::Itertools;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common_estimate_size::collections::EstimatedHashMap;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{build_retractable, AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use super::agg_common::{AggExecutorArgs, HashAggExecutorExtraArgs};
use super::aggregation::{
agg_call_filter_res, iter_table_storage, AggStateCacheStats, AggStateStorage,
DistinctDeduplicater, GroupKey, OnlyOutputIfHasInput,
};
use super::monitor::HashAggMetrics;
use super::sort_buffer::SortBuffer;
use crate::cache::{cache_may_stale, ManagedLruCache};
use crate::common::metrics::MetricsInfo;
use crate::executor::aggregation::AggGroup as GenericAggGroup;
use crate::executor::prelude::*;
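
/// With the `OnlyOutputIfHasInput` policy, a group emits an output row only while its row
/// count is positive.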
type AggGroup<S> = GenericAggGroup<S, OnlyOutputIfHasInput>;
type BoxedAggGroup<S> = Box<AggGroup<S>>;
impl<S: StateStore> EstimateSize for BoxedAggGroup<S> {
fn estimated_heap_size(&self) -> usize {
self.as_ref().estimated_size()
}
}
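
/// LRU cache from hash key to [`AggGroup`]. The value is `None` while the group is taken
/// out into `dirty_groups`.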
type AggGroupCache<K, S> = ManagedLruCache<K, Option<BoxedAggGroup<S>>, PrecomputedBuildHasher>;
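
/// [`HashAggExecutor`] aggregates the input stream by the group key columns. It works as
/// follows:
///
/// * For each input chunk, rows are dispatched to groups by hash key; the touched
///   [`AggGroup`]s are moved from the cache (or recovered from the state tables) into
///   `dirty_groups` and updated in memory.
/// * On barrier, or when the heap size of `dirty_groups` exceeds the configured limit, the
///   dirty states are flushed into the state tables and the group changes are emitted as
///   stream chunks.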
pub struct HashAggExecutor<K: HashKey, S: StateStore> {
input: Executor,
inner: ExecutorInner<K, S>,
}
struct ExecutorInner<K: HashKey, S: StateStore> {
    _phantom: PhantomData<K>,
    /// Version of the agg node plan.
    version: PbAggNodeVersion,

    actor_ctx: ActorContextRef,
    info: ExecutorInfo,

    /// Pk indices from input.
    input_pk_indices: Vec<usize>,
    /// Schema from input.
    input_schema: Schema,

    /// Indices of the input columns that form the group key.
    group_key_indices: Vec<usize>,
    /// Projection mapping the group key to the pk prefix of the intermediate state table.
    group_key_table_pk_projection: Arc<[usize]>,

    /// Aggregate calls and the functions built from them.
    agg_calls: Vec<AggCall>,
    agg_funcs: Vec<BoxedAggregateFunction>,
    /// Index of the row count agg call (`count(*)`) in the call list.
    row_count_index: usize,

    /// State storage for each agg call.
    storages: Vec<AggStateStorage<S>>,
    /// Intermediate state table collecting the states of all value-state agg calls on flush.
    intermediate_state_table: StateTable<S>,
    /// State tables for deduplicating rows on the distinct key of distinct agg calls.
    distinct_dedup_tables: HashMap<usize, StateTable<S>>,

    /// Watermark epoch for LRU cache eviction.
    watermark_sequence: AtomicU64Ref,
    /// Cache size limit for min/max-like "extreme" agg states.
    extreme_cache_size: usize,

    /// Maximum number of rows in an output chunk.
    chunk_size: usize,
    /// Heap size limit for `dirty_groups`; exceeding it triggers an early flush.
    max_dirty_groups_heap_size: usize,
    /// Whether results should only be emitted when the window closes (by watermark).
    emit_on_window_close: bool,
}
impl<K: HashKey, S: StateStore> ExecutorInner<K, S> {
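    /// Iterate over all state tables: agg state storages, distinct dedup tables, and the
    /// intermediate state table.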
fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
iter_table_storage(&mut self.storages)
.chain(self.distinct_dedup_tables.values_mut())
.chain(std::iter::once(&mut self.intermediate_state_table))
}
}
struct ExecutionVars<K: HashKey, S: StateStore> {
    metrics: HashAggMetrics,
    stats: ExecutionStats,
    /// Cache for [`AggGroup`]s. `None` means the group is currently dirty.
    agg_group_cache: AggGroupCache<K, S>,
    /// [`AggGroup`]s changed in the current epoch, taken out of the cache until flushed.
    dirty_groups: EstimatedHashMap<K, BoxedAggGroup<S>>,
    /// Deduplicates input rows for distinct agg calls.
    distinct_dedup: DistinctDeduplicater<S>,
    /// Watermarks on group key columns, buffered since the last barrier.
    buffered_watermarks: Vec<Option<Watermark>>,
    /// Latest watermark on the window column.
    window_watermark: Option<ScalarImpl>,
    /// Builder for output chunks.
    chunk_builder: StreamChunkBuilder,
    /// Sort buffer for emit-on-window-close output.
    buffer: SortBuffer<S>,
}
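
/// Counters accumulated between metric flushes; drained into metrics in `flush_metrics`.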
#[derive(Debug, Default)]
struct ExecutionStats {
lookup_miss_count: u64,
total_lookup_count: u64,
chunk_lookup_miss_count: u64,
chunk_total_lookup_count: u64,
agg_state_cache_lookup_count: u64,
agg_state_cache_miss_count: u64,
}
impl ExecutionStats {
fn merge_state_cache_stats(&mut self, other: AggStateCacheStats) {
self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
}
}
impl<K: HashKey, S: StateStore> Execute for HashAggExecutor<K, S> {
fn execute(self: Box<Self>) -> BoxedMessageStream {
self.execute_inner().boxed()
}
}
impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
pub fn new(args: AggExecutorArgs<S, HashAggExecutorExtraArgs>) -> StreamResult<Self> {
let input_info = args.input.info().clone();
let group_key_len = args.extra.group_key_indices.len();
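        // The first `group_key_len` columns of the table pk are expected to be exactly the
        // group key columns (in some order).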
let group_key_table_pk_projection =
&args.intermediate_state_table.pk_indices()[..group_key_len];
assert!(group_key_table_pk_projection
.iter()
.sorted()
.copied()
.eq(0..group_key_len));
Ok(Self {
input: args.input,
inner: ExecutorInner {
_phantom: PhantomData,
version: args.version,
actor_ctx: args.actor_ctx,
info: args.info,
input_pk_indices: input_info.pk_indices,
input_schema: input_info.schema,
group_key_indices: args.extra.group_key_indices,
group_key_table_pk_projection: group_key_table_pk_projection.to_vec().into(),
agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
agg_calls: args.agg_calls,
row_count_index: args.row_count_index,
storages: args.storages,
intermediate_state_table: args.intermediate_state_table,
distinct_dedup_tables: args.distinct_dedup_tables,
watermark_sequence: args.watermark_epoch,
extreme_cache_size: args.extreme_cache_size,
chunk_size: args.extra.chunk_size,
max_dirty_groups_heap_size: args.extra.max_dirty_groups_heap_size,
emit_on_window_close: args.extra.emit_on_window_close,
},
})
}
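
    /// Compute, for each group key appearing in the chunk, a visibility bitmap selecting the
    /// rows of that group. Rows invisible in `base_visibility` are excluded.
    ///
    /// * `keys`: hash keys of all rows in the chunk.
    /// * `base_visibility`: visibility of the chunk.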
fn get_group_visibilities(keys: Vec<K>, base_visibility: &Bitmap) -> Vec<(K, Bitmap)> {
let n_rows = keys.len();
let mut vis_builders = HashMap::new();
for (row_idx, key) in keys
.into_iter()
.enumerate()
.filter(|(row_idx, _)| base_visibility.is_set(*row_idx))
{
vis_builders
.entry(key)
.or_insert_with(|| BitmapBuilder::zeroed(n_rows))
.set(row_idx, true);
}
vis_builders
.into_iter()
.map(|(key, vis_builder)| (key, vis_builder.finish()))
.collect()
}
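
    /// Touch the [`AggGroup`]s for the given keys: move them from the cache into
    /// `dirty_groups`, creating or recovering them from the state tables on cache miss.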
async fn touch_agg_groups(
this: &ExecutorInner<K, S>,
vars: &mut ExecutionVars<K, S>,
keys: impl IntoIterator<Item = &K>,
) -> StreamExecutorResult<()> {
let group_key_types = &this.info.schema.data_types()[..this.group_key_indices.len()];
let futs = keys
.into_iter()
.filter_map(|key| {
vars.stats.total_lookup_count += 1;
if vars.dirty_groups.contains_key(key) {
return None;
}
match vars.agg_group_cache.get_mut(key) {
Some(mut agg_group) => {
let agg_group: &mut Option<_> = &mut agg_group;
assert!(
agg_group.is_some(),
"invalid state: AggGroup is None in cache but not dirty"
);
vars.dirty_groups
.insert(key.clone(), agg_group.take().unwrap());
                        None
                    }
None => {
vars.stats.lookup_miss_count += 1;
Some(async {
let agg_group = AggGroup::create(
this.version,
Some(GroupKey::new(
key.deserialize(group_key_types)?,
Some(this.group_key_table_pk_projection.clone()),
)),
&this.agg_calls,
&this.agg_funcs,
&this.storages,
&this.intermediate_state_table,
&this.input_pk_indices,
this.row_count_index,
this.extreme_cache_size,
&this.input_schema,
)
.await?;
Ok::<_, StreamExecutorError>((key.clone(), Box::new(agg_group)))
})
}
}
})
            .collect_vec(); // collect is necessary to end the closure's borrow of `vars`

        vars.stats.chunk_total_lookup_count += 1;
if !futs.is_empty() {
vars.stats.chunk_lookup_miss_count += 1;
let mut buffered = stream::iter(futs).buffer_unordered(10).fuse();
while let Some(result) = buffered.next().await {
let (key, agg_group) = result?;
let none = vars.dirty_groups.insert(key, agg_group);
debug_assert!(none.is_none());
}
}
Ok(())
}
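
    /// Apply a chunk: write materialized inputs to their state tables and update the
    /// in-memory states of all groups touched by the chunk.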
async fn apply_chunk(
this: &mut ExecutorInner<K, S>,
vars: &mut ExecutionVars<K, S>,
chunk: StreamChunk,
) -> StreamExecutorResult<()> {
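        // Find the groups present in this chunk and a visibility bitmap for each of them.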
let keys = K::build_many(&this.group_key_indices, chunk.data_chunk());
let group_visibilities = Self::get_group_visibilities(keys, chunk.visibility());
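
        // Make sure all touched `AggGroup`s are loaded into `dirty_groups`.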
Self::touch_agg_groups(this, vars, group_visibilities.iter().map(|(k, _)| k)).await?;
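
        // Compute the row visibility for every agg call by evaluating its filter.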
let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
for agg_call in &this.agg_calls {
let agg_call_filter_res = agg_call_filter_res(agg_call, &chunk).await?;
call_visibilities.push(agg_call_filter_res);
}
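
        // Materialize the input chunk for materialized-input states of non-distinct calls.
        // Distinct calls are handled per group after deduplication below.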
for ((call, storage), visibility) in (this.agg_calls.iter())
.zip_eq_fast(&mut this.storages)
.zip_eq_fast(call_visibilities.iter())
{
if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage
&& !call.distinct
{
let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
table.write_chunk(chunk);
}
}
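
        // Apply the chunk to the state of each agg call, group by group.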
for (key, visibility) in group_visibilities {
let agg_group: &mut BoxedAggGroup<_> = &mut vars.dirty_groups.get_mut(&key).unwrap();
let visibilities = call_visibilities
.iter()
.map(|call_vis| call_vis & &visibility)
.collect();
let visibilities = vars
.distinct_dedup
.dedup_chunk(
chunk.ops(),
chunk.columns(),
visibilities,
&mut this.distinct_dedup_tables,
agg_group.group_key(),
)
.await?;
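
            // Materialize the input chunk for distinct calls with the deduplicated visibilities.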
for ((call, storage), visibility) in (this.agg_calls.iter())
.zip_eq_fast(&mut this.storages)
.zip_eq_fast(visibilities.iter())
{
if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage
&& call.distinct
{
let chunk =
chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
table.write_chunk(chunk);
}
}
agg_group
.apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
.await?;
}
vars.metrics
.agg_dirty_groups_count
.set(vars.dirty_groups.len() as i64);
vars.metrics
.agg_dirty_groups_heap_size
.set(vars.dirty_groups.estimated_heap_size() as i64);
Ok(())
}
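
    /// Flush the states of dirty groups into the state tables, move the groups back into the
    /// cache, and yield the resulting changes as output chunks.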
#[try_stream(ok = StreamChunk, error = StreamExecutorError)]
async fn flush_data<'a>(this: &'a mut ExecutorInner<K, S>, vars: &'a mut ExecutionVars<K, S>) {
let window_watermark = vars.window_watermark.take();
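
        // Flush the changed states into the intermediate state table (or the EOWC sort buffer).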
for agg_group in vars.dirty_groups.values() {
let encoded_states = agg_group.encode_states(&this.agg_funcs)?;
if this.emit_on_window_close {
vars.buffer
.update_without_old_value(encoded_states, &mut this.intermediate_state_table);
} else {
this.intermediate_state_table
.update_without_old_value(encoded_states);
}
}
if this.emit_on_window_close {
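            // Remove all groups below the watermark and emit their results.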
if let Some(watermark) = window_watermark.as_ref() {
#[for_await]
for row in vars
.buffer
.consume(watermark.clone(), &mut this.intermediate_state_table)
{
let row = row?;
let group_key = row
.clone()
.into_iter()
.take(this.group_key_indices.len())
.collect();
let states = row.into_iter().skip(this.group_key_indices.len()).collect();
let mut agg_group = AggGroup::create_eowc(
this.version,
Some(GroupKey::new(
group_key,
Some(this.group_key_table_pk_projection.clone()),
)),
&this.agg_calls,
&this.agg_funcs,
&this.storages,
&states,
&this.input_pk_indices,
this.row_count_index,
this.extreme_cache_size,
&this.input_schema,
)?;
let (change, stats) = agg_group
.build_change(&this.storages, &this.agg_funcs)
.await?;
vars.stats.merge_state_cache_stats(stats);
if let Some(change) = change {
if let Some(chunk) = vars.chunk_builder.append_record(change) {
yield chunk;
}
}
}
}
} else {
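            // Emit on update: build and emit a change for every dirty group.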
for mut agg_group in vars.dirty_groups.values_mut() {
let agg_group = agg_group.as_mut();
let (change, stats) = agg_group
.build_change(&this.storages, &this.agg_funcs)
.await?;
vars.stats.merge_state_cache_stats(stats);
if let Some(change) = change {
if let Some(chunk) = vars.chunk_builder.append_record(change) {
yield chunk;
}
}
}
}
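
        // Move dirty groups back into the cache.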
for (key, agg_group) in vars.dirty_groups.drain() {
vars.agg_group_cache.put(key, Some(agg_group));
}
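
        // Yield the remaining rows in the chunk builder.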
if let Some(chunk) = vars.chunk_builder.take() {
yield chunk;
}
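
        // Update the watermark of all state tables for state cleaning.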
if let Some(watermark) = window_watermark {
this.all_state_tables_mut()
.for_each(|table| table.update_watermark(watermark.clone()));
}
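
        // Flush the distinct dedup states, then evict the cache towards its target capacity.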
vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
vars.agg_group_cache.evict();
}
fn flush_metrics(_this: &ExecutorInner<K, S>, vars: &mut ExecutionVars<K, S>) {
vars.metrics
.agg_lookup_miss_count
.inc_by(std::mem::take(&mut vars.stats.lookup_miss_count));
vars.metrics
.agg_total_lookup_count
.inc_by(std::mem::take(&mut vars.stats.total_lookup_count));
vars.metrics
.agg_cached_entry_count
.set(vars.agg_group_cache.len() as i64);
vars.metrics
.agg_chunk_lookup_miss_count
.inc_by(std::mem::take(&mut vars.stats.chunk_lookup_miss_count));
vars.metrics
.agg_chunk_total_lookup_count
.inc_by(std::mem::take(&mut vars.stats.chunk_total_lookup_count));
vars.metrics
.agg_state_cache_lookup_count
.inc_by(std::mem::take(&mut vars.stats.agg_state_cache_lookup_count));
vars.metrics
.agg_state_cache_miss_count
.inc_by(std::mem::take(&mut vars.stats.agg_state_cache_miss_count));
}
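
    /// Commit all state tables with the given epoch.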
async fn commit_state_tables(
this: &mut ExecutorInner<K, S>,
epoch: EpochPair,
) -> StreamExecutorResult<()> {
futures::future::try_join_all(
this.all_state_tables_mut()
.map(|table| async { table.commit(epoch).await }),
)
.await?;
Ok(())
}
async fn try_flush_data(this: &mut ExecutorInner<K, S>) -> StreamExecutorResult<()> {
futures::future::try_join_all(
this.all_state_tables_mut()
.map(|table| async { table.try_flush().await }),
)
.await?;
Ok(())
}
#[try_stream(ok = Message, error = StreamExecutorError)]
async fn execute_inner(self) {
let HashAggExecutor {
input,
inner: mut this,
} = self;
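
        // Under emit-on-window-close, the window column is the group key column that sorts
        // first in the intermediate state table pk.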
let window_col_idx_in_group_key = this.intermediate_state_table.pk_indices()[0];
let window_col_idx = this.group_key_indices[window_col_idx_in_group_key];
let agg_group_cache_metrics_info = MetricsInfo::new(
this.actor_ctx.streaming_metrics.clone(),
this.intermediate_state_table.table_id(),
this.actor_ctx.id,
"agg intermediate state table",
);
let metrics = this.actor_ctx.streaming_metrics.new_hash_agg_metrics(
this.intermediate_state_table.table_id(),
this.actor_ctx.id,
this.actor_ctx.fragment_id,
);
let mut vars = ExecutionVars {
metrics,
stats: ExecutionStats::default(),
agg_group_cache: ManagedLruCache::unbounded_with_hasher(
this.watermark_sequence.clone(),
agg_group_cache_metrics_info,
PrecomputedBuildHasher,
),
dirty_groups: Default::default(),
distinct_dedup: DistinctDeduplicater::new(
&this.agg_calls,
this.watermark_sequence.clone(),
&this.distinct_dedup_tables,
&this.actor_ctx,
),
buffered_watermarks: vec![None; this.group_key_indices.len()],
window_watermark: None,
chunk_builder: StreamChunkBuilder::new(this.chunk_size, this.info.schema.data_types()),
buffer: SortBuffer::new(window_col_idx_in_group_key, &this.intermediate_state_table),
};
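
        // Inverse mapping: input column index -> position of that column in the group key.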
let group_key_invert_idx = {
let mut group_key_invert_idx = vec![None; input.info().schema.len()];
for (group_key_seq, group_key_idx) in this.group_key_indices.iter().enumerate() {
group_key_invert_idx[*group_key_idx] = Some(group_key_seq);
}
group_key_invert_idx
};
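
        // Consume the first barrier and initialize all state tables with its epoch.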
let mut input = input.execute();
let barrier = expect_first_barrier(&mut input).await?;
let first_epoch = barrier.epoch;
yield Message::Barrier(barrier);
for table in this.all_state_tables_mut() {
table.init_epoch(first_epoch).await?;
}
#[for_await]
for msg in input {
let msg = msg?;
vars.agg_group_cache.evict();
match msg {
Message::Watermark(watermark) => {
let group_key_seq = group_key_invert_idx[watermark.col_idx];
if let Some(group_key_seq) = group_key_seq {
if watermark.col_idx == window_col_idx {
vars.window_watermark = Some(watermark.val.clone());
}
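
                        // Buffer the group key watermark; it will be emitted after the flush
                        // on the next barrier.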
vars.buffered_watermarks[group_key_seq] =
Some(watermark.with_idx(group_key_seq));
}
}
Message::Chunk(chunk) => {
Self::apply_chunk(&mut this, &mut vars, chunk).await?;
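
                    // Flush early if the dirty groups exceed the configured heap size.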
if vars.dirty_groups.estimated_heap_size() >= this.max_dirty_groups_heap_size {
#[for_await]
for chunk in Self::flush_data(&mut this, &mut vars) {
yield Message::Chunk(chunk?);
}
}
Self::try_flush_data(&mut this).await?;
}
Message::Barrier(barrier) => {
#[for_await]
for chunk in Self::flush_data(&mut this, &mut vars) {
yield Message::Chunk(chunk?);
}
Self::flush_metrics(&this, &mut vars);
Self::commit_state_tables(&mut this, barrier.epoch).await?;
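
                    // Forward buffered watermarks; under EOWC, only the watermark on the
                    // window column is forwarded.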
if this.emit_on_window_close {
if let Some(watermark) =
vars.buffered_watermarks[window_col_idx_in_group_key].take()
{
yield Message::Watermark(watermark);
}
} else {
for buffered_watermark in &mut vars.buffered_watermarks {
if let Some(watermark) = buffered_watermark.take() {
yield Message::Watermark(watermark);
}
}
}
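
                    // Update the vnode bitmap of all state tables if the actor is rescaled.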
if let Some(vnode_bitmap) = barrier.as_update_vnode_bitmap(this.actor_ctx.id) {
let previous_vnode_bitmap = this.intermediate_state_table.vnodes().clone();
this.all_state_tables_mut().for_each(|table| {
let _ = table.update_vnode_bitmap(vnode_bitmap.clone());
});
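
                        // The cache may be stale after the vnode bitmap change; clear it if so.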
if cache_may_stale(&previous_vnode_bitmap, &vnode_bitmap) {
vars.agg_group_cache.clear();
vars.distinct_dedup.dedup_caches_mut().for_each(|cache| {
cache.clear();
});
}
}
yield Message::Barrier(barrier);
}
}
}
}
}