use async_trait::async_trait;
use futures::{FutureExt, StreamExt, TryStreamExt};
use futures_async_stream::try_stream;
use risingwave_common::catalog::Schema;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::{test_epoch, EpochExt};
use tokio::sync::mpsc;
use super::error::StreamExecutorError;
use super::{
Barrier, BoxedMessageStream, Execute, Executor, ExecutorInfo, Message, MessageStream,
StreamChunk, StreamExecutorResult, Watermark,
};
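/// Commonly used imports for writing stream executor tests.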
pub mod prelude {
pub use std::sync::atomic::AtomicU64;
pub use std::sync::Arc;
pub use risingwave_common::array::StreamChunk;
pub use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
pub use risingwave_common::test_prelude::StreamChunkTestExt;
pub use risingwave_common::types::DataType;
pub use risingwave_common::util::sort_util::OrderType;
pub use risingwave_storage::memory::MemoryStateStore;
pub use risingwave_storage::StateStore;
pub use crate::common::table::state_table::StateTable;
pub use crate::executor::test_utils::expr::build_from_pretty;
pub use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt};
pub use crate::executor::{ActorContext, BoxedMessageStream, Execute, PkIndices};
}
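/// A mock source executor for unit tests: it forwards every [`Message`] pushed through the
/// paired [`MessageSender`] and, by default, emits a trailing `Stop` barrier once the sender
/// side is dropped.
///
/// A minimal usage sketch; the single-column `Int64` schema and the pushed data are made up
/// for illustration:
///
/// ```ignore
/// let (mut tx, source) = MockSource::channel();
/// let source = source.into_executor(Schema::new(vec![Field::unnamed(DataType::Int64)]), vec![0]);
/// let mut stream = source.execute();
///
/// tx.push_barrier(test_epoch(1), false);
/// tx.push_chunk(StreamChunk::from_pretty(
///     " I
///     + 42",
/// ));
///
/// stream.next_unwrap_ready_barrier().unwrap();
/// let chunk = stream.next_unwrap_ready_chunk().unwrap();
/// ```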
pub struct MockSource {
rx: mpsc::UnboundedReceiver<Message>,
    /// Whether to emit a trailing `Stop` barrier once all pushed messages have been drained.
    stop_on_finish: bool,
}
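/// Test-only handle for pushing [`Message`]s into the paired [`MockSource`].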
pub struct MessageSender(mpsc::UnboundedSender<Message>);
impl MessageSender {
#[allow(dead_code)]
pub fn push_chunk(&mut self, chunk: StreamChunk) {
self.0.send(Message::Chunk(chunk)).unwrap();
}
#[allow(dead_code)]
pub fn push_barrier(&mut self, epoch: u64, stop: bool) {
let mut barrier = Barrier::new_test_barrier(epoch);
if stop {
barrier = barrier.with_stop();
}
self.0.send(Message::Barrier(barrier)).unwrap();
}
pub fn send_barrier(&self, barrier: Barrier) {
self.0.send(Message::Barrier(barrier)).unwrap();
}
#[allow(dead_code)]
pub fn push_barrier_with_prev_epoch_for_test(
&mut self,
cur_epoch: u64,
prev_epoch: u64,
stop: bool,
) {
let mut barrier = Barrier::with_prev_epoch_for_test(cur_epoch, prev_epoch);
if stop {
barrier = barrier.with_stop();
}
self.0.send(Message::Barrier(barrier)).unwrap();
}
#[allow(dead_code)]
pub fn push_watermark(&mut self, col_idx: usize, data_type: DataType, val: ScalarImpl) {
self.0
.send(Message::Watermark(Watermark {
col_idx,
data_type,
val,
}))
.unwrap();
}
#[allow(dead_code)]
pub fn push_int64_watermark(&mut self, col_idx: usize, val: i64) {
self.push_watermark(col_idx, DataType::Int64, ScalarImpl::Int64(val));
}
}
impl std::fmt::Debug for MockSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MockSource").finish()
}
}
impl MockSource {
#[allow(dead_code)]
pub fn channel() -> (MessageSender, Self) {
let (tx, rx) = mpsc::unbounded_channel();
let source = Self {
rx,
stop_on_finish: true,
};
(MessageSender(tx), source)
}
#[allow(dead_code)]
pub fn with_messages(msgs: Vec<Message>) -> Self {
let (tx, source) = Self::channel();
for msg in msgs {
tx.0.send(msg).unwrap();
}
source
}
pub fn with_chunks(chunks: Vec<StreamChunk>) -> Self {
let (tx, source) = Self::channel();
for chunk in chunks {
tx.0.send(Message::Chunk(chunk)).unwrap();
}
source
}
#[allow(dead_code)]
#[must_use]
pub fn stop_on_finish(self, stop_on_finish: bool) -> Self {
Self {
stop_on_finish,
..self
}
}
pub fn into_executor(self, schema: Schema, pk_indices: Vec<usize>) -> Executor {
Executor::new(
ExecutorInfo {
schema,
pk_indices,
identity: "MockSource".to_string(),
},
self.boxed(),
)
}
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(mut self: Box<Self>) {
        let mut epoch = test_epoch(1);

        // Forward every pushed message, bumping the tracked test epoch as we go.
        while let Some(msg) = self.rx.recv().await {
            epoch.inc_epoch();
            yield msg;
        }

        // Once the sender side is dropped, optionally emit a final `Stop` barrier.
        if self.stop_on_finish {
            yield Message::Barrier(Barrier::new_test_barrier(epoch).with_stop());
        }
    }
}
impl Execute for MockSource {
fn execute(self: Box<Self>) -> super::BoxedMessageStream {
self.execute_inner().boxed()
}
}
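/// Builds an `OwnedRow` in which every datum is non-null.
///
/// A small sketch of the expansion (the literal values are arbitrary):
///
/// ```ignore
/// let row = row_nonnull![1i64, 2i64, 3i64];
/// // expands to roughly:
/// // OwnedRow::new(vec![Some(1i64.into()), Some(2i64.into()), Some(3i64.into())])
/// ```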
#[macro_export]
macro_rules! row_nonnull {
[$( $value:expr ),*] => {
{
risingwave_common::row::OwnedRow::new(vec![$(Some($value.into()), )*])
}
};
}
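/// Extension helpers on message streams for concise executor tests: the `next_unwrap_*`
/// methods poll the stream without awaiting and panic when the expectation is not met,
/// while the `expect_*` methods await the next message and unwrap it into the expected
/// variant.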
#[async_trait]
pub trait StreamExecutorTestExt: MessageStream + Unpin {
fn next_unwrap_pending(&mut self) {
if let Some(r) = self.try_next().now_or_never() {
panic!("expect pending stream, but got `{:?}`", r);
}
}
fn next_unwrap_ready(&mut self) -> StreamExecutorResult<Message> {
match self.next().now_or_never() {
Some(Some(r)) => r,
Some(None) => panic!("expect ready stream, but got terminated"),
None => panic!("expect ready stream, but got pending"),
}
}
fn next_unwrap_ready_chunk(&mut self) -> StreamExecutorResult<StreamChunk> {
self.next_unwrap_ready()
.map(|msg| msg.into_chunk().expect("expect chunk"))
}
fn next_unwrap_ready_barrier(&mut self) -> StreamExecutorResult<Barrier> {
self.next_unwrap_ready()
.map(|msg| msg.into_barrier().expect("expect barrier"))
}
fn next_unwrap_ready_watermark(&mut self) -> StreamExecutorResult<Watermark> {
self.next_unwrap_ready()
.map(|msg| msg.into_watermark().expect("expect watermark"))
}
async fn expect_barrier(&mut self) -> Barrier {
let msg = self.next().await.unwrap().unwrap();
msg.into_barrier().unwrap()
}
async fn expect_chunk(&mut self) -> StreamChunk {
let msg = self.next().await.unwrap().unwrap();
msg.into_chunk().unwrap()
}
async fn expect_watermark(&mut self) -> Watermark {
let msg = self.next().await.unwrap().unwrap();
msg.into_watermark().unwrap()
}
}
impl StreamExecutorTestExt for BoxedMessageStream {}
pub mod expr {
use risingwave_expr::expr::NonStrictExpression;
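    /// Builds a [`NonStrictExpression`] from the "pretty" expression syntax used in tests.
    ///
    /// A hedged sketch of the syntax (the exact expression is illustrative; `$0` refers to
    /// input column 0 and literals carry their type after a colon):
    ///
    /// ```ignore
    /// let add_one = build_from_pretty("(add:int8 $0:int8 1:int8)");
    /// ```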
pub fn build_from_pretty(s: impl AsRef<str>) -> NonStrictExpression {
NonStrictExpression::for_test(risingwave_expr::expr::build_from_pretty(s))
}
}
pub mod agg_executor {
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use futures::future;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
use risingwave_common::hash::SerializedKey;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use risingwave_storage::StateStore;
use crate::common::table::state_table::StateTable;
use crate::common::table::test_utils::gen_pbtable;
use crate::common::StateTableColumnMapping;
use crate::executor::agg_common::{
AggExecutorArgs, HashAggExecutorExtraArgs, SimpleAggExecutorExtraArgs,
};
use crate::executor::aggregation::AggStateStorage;
use crate::executor::{
ActorContext, ActorContextRef, Executor, ExecutorInfo, HashAggExecutor, PkIndices,
SimpleAggExecutor,
};
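    /// Generates the output schema of an aggregation executor: the group key columns (if any)
    /// followed by one column per aggregate call, typed by the call's return type.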
pub fn generate_agg_schema(
input_ref: &Executor,
agg_calls: &[AggCall],
group_key_indices: Option<&[usize]>,
) -> Schema {
let aggs = agg_calls
.iter()
.map(|agg| Field::unnamed(agg.return_type.clone()));
let fields = if let Some(key_indices) = group_key_indices {
let keys = key_indices
.iter()
.map(|idx| input_ref.schema().fields[*idx].clone());
keys.chain(aggs).collect()
} else {
aggs.collect()
};
Schema { fields }
}
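    /// Creates the state storage for a single aggregate call. Non-append-only `min`/`max`
    /// calls get a materialized-input state table whose sort key is the group key, the
    /// aggregated column, and the upstream primary key; the remaining supported calls only
    /// need a value state.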
pub async fn create_agg_state_storage<S: StateStore>(
store: S,
table_id: TableId,
agg_call: &AggCall,
group_key_indices: &[usize],
pk_indices: &[usize],
input_fields: Vec<Field>,
is_append_only: bool,
) -> AggStateStorage<S> {
match agg_call.agg_type {
AggType::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => {
let mut column_descs = Vec::new();
let mut order_types = Vec::new();
let mut upstream_columns = Vec::new();
let mut order_columns = Vec::new();
let mut next_column_id = 0;
                // Register one upstream column in the state table, optionally making it part
                // of the sort key.
                let mut add_column =
                    |upstream_idx: usize, data_type: DataType, order_type: Option<OrderType>| {
                        upstream_columns.push(upstream_idx);
                        column_descs.push(ColumnDesc::unnamed(
                            ColumnId::new(next_column_id),
                            data_type,
                        ));
                        if let Some(order_type) = order_type {
                            order_columns.push(ColumnOrder::new(upstream_idx as _, order_type));
                            order_types.push(order_type);
                        }
                        next_column_id += 1;
                    };
for idx in group_key_indices {
add_column(*idx, input_fields[*idx].data_type(), None);
}
                // The aggregated column itself: descending for `max`, ascending for `min`.
                add_column(
                    agg_call.args.val_indices()[0],
                    agg_call.args.arg_types()[0].clone(),
                    if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Max)) {
                        Some(OrderType::descending())
                    } else {
                        Some(OrderType::ascending())
                    },
                );
for idx in pk_indices {
add_column(*idx, input_fields[*idx].data_type(), Some(OrderType::ascending()));
}
                let state_table = StateTable::from_table_catalog(
                    &gen_pbtable(
                        table_id,
                        column_descs,
                        order_types.clone(),
                        (0..order_types.len()).collect(),
                        0,
                    ),
                    store,
                    None,
                )
                .await;

                AggStateStorage::MaterializedInput {
                    table: state_table,
                    mapping: StateTableColumnMapping::new(upstream_columns, None),
                    order_columns,
                }
}
            AggType::Builtin(
                PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::Sum0
                | PbAggKind::Count
                | PbAggKind::Avg
                | PbAggKind::ApproxCountDistinct,
            ) => AggStateStorage::Value,
_ => {
panic!("no need to mock other agg kinds here");
}
}
}
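    /// Creates the intermediate state table shared by all aggregate calls: the group key
    /// columns (which also form the primary key) followed by one column per aggregate result.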
pub async fn create_intermediate_state_table<S: StateStore>(
store: S,
table_id: TableId,
agg_calls: &[AggCall],
group_key_indices: &[usize],
input_fields: Vec<Field>,
) -> StateTable<S> {
let mut column_descs = Vec::new();
let mut order_types = Vec::new();
let mut next_column_id = 0;
let mut add_column_desc = |data_type: DataType| {
column_descs.push(ColumnDesc::unnamed(
ColumnId::new(next_column_id),
data_type,
));
next_column_id += 1;
};
group_key_indices.iter().for_each(|idx| {
add_column_desc(input_fields[*idx].data_type());
order_types.push(OrderType::ascending());
});
agg_calls.iter().for_each(|agg_call| {
add_column_desc(agg_call.return_type.clone());
});
StateTable::from_table_catalog_inconsistent_op(
&gen_pbtable(
table_id,
column_descs,
order_types,
(0..group_key_indices.len()).collect(),
0,
),
store,
None,
)
.await
}
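    /// Builds a [`HashAggExecutor`] over `input`, creating one state storage per aggregate
    /// call plus the shared intermediate state table on `store`.
    ///
    /// A hedged usage sketch; the schema, indices, and the `AggCall::from_pretty` test
    /// constructor for the call are illustrative assumptions:
    ///
    /// ```ignore
    /// let store = MemoryStateStore::new();
    /// let schema = Schema::new(vec![
    ///     Field::unnamed(DataType::Int64), // group key
    ///     Field::unnamed(DataType::Int64), // value
    /// ]);
    /// let (mut tx, source) = MockSource::channel();
    /// let source = source.into_executor(schema, vec![1]);
    /// let agg_calls = vec![AggCall::from_pretty("(count:int8)")];
    /// let hash_agg = new_boxed_hash_agg_executor(
    ///     store, source, false, agg_calls, 0, vec![0], vec![1], 1024, false, 1,
    /// )
    /// .await;
    /// ```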
#[allow(clippy::too_many_arguments)]
pub async fn new_boxed_hash_agg_executor<S: StateStore>(
store: S,
input: Executor,
is_append_only: bool,
agg_calls: Vec<AggCall>,
row_count_index: usize,
group_key_indices: Vec<usize>,
pk_indices: PkIndices,
extreme_cache_size: usize,
emit_on_window_close: bool,
executor_id: u64,
) -> Executor {
        let mut storages = Vec::with_capacity(agg_calls.len());
for (idx, agg_call) in agg_calls.iter().enumerate() {
storages.push(
create_agg_state_storage(
store.clone(),
TableId::new(idx as u32),
agg_call,
&group_key_indices,
&pk_indices,
input.info.schema.fields.clone(),
is_append_only,
)
.await,
)
}
let intermediate_state_table = create_intermediate_state_table(
store,
TableId::new(agg_calls.len() as u32),
&agg_calls,
&group_key_indices,
input.info.schema.fields.clone(),
)
.await;
let schema = generate_agg_schema(&input, &agg_calls, Some(&group_key_indices));
let info = ExecutorInfo {
schema,
pk_indices,
identity: format!("HashAggExecutor {:X}", executor_id),
};
let exec = HashAggExecutor::<SerializedKey, S>::new(AggExecutorArgs {
version: PbAggNodeVersion::Max,
input,
actor_ctx: ActorContext::for_test(123),
info: info.clone(),
extreme_cache_size,
agg_calls,
row_count_index,
storages,
intermediate_state_table,
distinct_dedup_tables: Default::default(),
watermark_epoch: Arc::new(AtomicU64::new(0)),
extra: HashAggExecutorExtraArgs {
group_key_indices,
chunk_size: 1024,
max_dirty_groups_heap_size: 64 << 20,
emit_on_window_close,
},
})
.unwrap();
(info, exec).into()
}
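    /// Builds a [`SimpleAggExecutor`] (no group key) with one state storage per aggregate
    /// call and the shared intermediate state table on `store`.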
#[allow(clippy::too_many_arguments)]
pub async fn new_boxed_simple_agg_executor<S: StateStore>(
actor_ctx: ActorContextRef,
store: S,
input: Executor,
is_append_only: bool,
agg_calls: Vec<AggCall>,
row_count_index: usize,
pk_indices: PkIndices,
executor_id: u64,
must_output_per_barrier: bool,
) -> Executor {
let storages = future::join_all(agg_calls.iter().enumerate().map(|(idx, agg_call)| {
create_agg_state_storage(
store.clone(),
TableId::new(idx as u32),
agg_call,
&[],
&pk_indices,
input.info.schema.fields.clone(),
is_append_only,
)
}))
.await;
let intermediate_state_table = create_intermediate_state_table(
store,
TableId::new(agg_calls.len() as u32),
&agg_calls,
&[],
input.info.schema.fields.clone(),
)
.await;
let schema = generate_agg_schema(&input, &agg_calls, None);
let info = ExecutorInfo {
schema,
pk_indices,
identity: format!("SimpleAggExecutor {:X}", executor_id),
};
let exec = SimpleAggExecutor::new(AggExecutorArgs {
version: PbAggNodeVersion::Max,
input,
actor_ctx,
info: info.clone(),
extreme_cache_size: 1024,
agg_calls,
row_count_index,
storages,
intermediate_state_table,
distinct_dedup_tables: Default::default(),
watermark_epoch: Arc::new(AtomicU64::new(0)),
extra: SimpleAggExecutorExtraArgs {
must_output_per_barrier,
},
})
.unwrap();
(info, exec).into()
}
}
pub mod top_n_executor {
use itertools::Itertools;
use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId};
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_storage::memory::MemoryStateStore;
use crate::common::table::state_table::StateTable;
use crate::common::table::test_utils::gen_pbtable;
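    /// Creates a [`StateTable`] backed by a fresh [`MemoryStateStore`], with one unnamed
    /// column per data type and the given primary-key layout.
    ///
    /// A hedged sketch (column types and key layout are illustrative):
    ///
    /// ```ignore
    /// let state_table = create_in_memory_state_table(
    ///     &[DataType::Int64, DataType::Varchar],
    ///     &[OrderType::ascending()],
    ///     &[0],
    /// )
    /// .await;
    /// ```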
pub async fn create_in_memory_state_table(
data_types: &[DataType],
order_types: &[OrderType],
pk_indices: &[usize],
) -> StateTable<MemoryStateStore> {
create_in_memory_state_table_from_state_store(
data_types,
order_types,
pk_indices,
MemoryStateStore::new(),
)
.await
}
pub async fn create_in_memory_state_table_from_state_store(
data_types: &[DataType],
order_types: &[OrderType],
pk_indices: &[usize],
state_store: MemoryStateStore,
) -> StateTable<MemoryStateStore> {
let column_descs = data_types
.iter()
.enumerate()
.map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
.collect_vec();
StateTable::from_table_catalog(
&gen_pbtable(
TableId::new(0),
column_descs,
order_types.to_vec(),
pk_indices.to_vec(),
0,
),
state_store,
None,
)
.await
}
}