// risingwave_stream/executor/lookup_union.rs
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::future::{join_all, select, Either};
use futures::{FutureExt, SinkExt};
use itertools::Itertools;
use crate::executor::prelude::*;
/// Union executor that merges several inputs, draining them one after another in a fixed
/// order within each epoch.
pub struct LookupUnionExecutor {
    /// Upstream executors; each one referenced by `order` is executed when the stream starts.
    inputs: Vec<Executor>,
    /// Indices into `inputs`, listed in the order in which the inputs are drained.
    order: Vec<usize>,
}
impl std::fmt::Debug for LookupUnionExecutor {
    /// Writes only the type name; neither the inputs nor the drain order are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LookupUnionExecutor")
    }
}
impl LookupUnionExecutor {
    /// Creates a lookup-union executor over `inputs`.
    ///
    /// `order` lists input indices in the sequence in which the inputs are drained. Each
    /// entry is widened from `u32` to `usize` so it can index `inputs` directly.
    pub fn new(inputs: Vec<Executor>, order: Vec<u32>) -> Self {
        let order = order
            .into_iter()
            .map(|position| position as usize)
            .collect();
        Self { inputs, order }
    }
}
#[async_trait]
impl Execute for LookupUnionExecutor {
    /// Consumes the boxed executor and returns its merged output as a pinned, boxed stream.
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        let merged = self.execute_inner();
        Box::pin(merged)
    }
}
impl LookupUnionExecutor {
    /// Drives all inputs concurrently and merges their output epoch by epoch.
    ///
    /// Each input listed in `self.order` is executed, and a forwarding future pushes its
    /// messages into a dedicated bounded channel. All forwarding futures are joined into
    /// `drive_inputs`, which is polled alongside whichever channel is currently being read.
    ///
    /// The merge loop runs in rounds:
    /// - Within one round the channels are visited in `self.order`.
    /// - Each channel is drained, yielding its chunks, until it delivers a barrier or
    ///   reaches end of stream.
    /// - The barriers collected in a round must all carry the same epoch. A single copy of
    ///   that barrier is yielded at the end of the round.
    ///
    /// The stream ends once a round collects no barrier. That can only happen after every
    /// input has reached end of stream.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by an input. Returns the error built by
    /// `StreamExecutorError::align_barrier` if two inputs deliver barriers with different
    /// epochs in the same round.
    ///
    /// # Panics
    ///
    /// Panics if `self.order` contains an index that is out of range for `self.inputs`, or
    /// an index that appears more than once.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(self) {
        let mut inputs = self.inputs.into_iter().map(Some).collect_vec();
        let mut futures = vec![];
        let mut rxs = vec![];
        for idx in self.order {
            let mut stream = inputs[idx].take().unwrap().execute();
            // The bounded buffer applies back pressure to the forwarding future.
            let (mut tx, rx) = mpsc::channel(1024);
            rxs.push(rx);
            futures.push(
                // Forward every item of the input stream into its channel until the input
                // is exhausted. Dropping `tx` afterwards closes the channel.
                async move {
                    while let Some(ret) = stream.next().await {
                        tx.send(ret).await.unwrap();
                    }
                }
                .boxed(),
            );
        }
        // Fused so that it can keep being passed to `select` after it has completed; a
        // fused future stays pending once it is done.
        let mut drive_inputs = join_all(futures).fuse();
        loop {
            let mut this_barrier: Option<Barrier> = None;
            for rx in &mut rxs {
                loop {
                    // Poll the forwarders together with the receiver so the channel keeps
                    // being fed while we wait on it.
                    let msg = match select(rx.next(), &mut drive_inputs).await {
                        Either::Left((Some(msg), _)) => msg?,
                        // This input reached end of stream: move to the next one.
                        Either::Left((None, _)) => break,
                        // All forwarders finished: keep draining what is buffered.
                        Either::Right(_) => continue,
                    };
                    match msg {
                        Message::Watermark(_) => {
                            // Watermarks are currently dropped.
                        }
                        msg @ Message::Chunk(_) => yield msg,
                        Message::Barrier(barrier) => {
                            if let Some(this_barrier) = &this_barrier {
                                if this_barrier.epoch != barrier.epoch {
                                    return Err(StreamExecutorError::align_barrier(
                                        this_barrier.clone(),
                                        barrier,
                                    ));
                                }
                            } else {
                                this_barrier = Some(barrier);
                            }
                            // This input has finished the current epoch.
                            break;
                        }
                    }
                }
            }
            // The inner loop only exits on a barrier or at end of stream. A round without any
            // barrier therefore means every input is exhausted, even if some of them emitted
            // trailing chunks or watermarks. End the stream instead of panicking.
            let Some(barrier) = this_barrier else {
                break;
            };
            yield Message::Barrier(barrier);
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::TryStreamExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// Three mock inputs are drained in the order `[2, 1, 0]`.
    ///
    /// Within every epoch the chunks must come out in that order, followed by one aligned
    /// barrier. Input 0 carries an extra epoch after inputs 1 and 2 have finished.
    #[tokio::test]
    async fn lookup_union() {
        let chunk = |pretty: &str| Message::Chunk(StreamChunk::from_pretty(pretty));
        let barrier = |epoch| Message::Barrier(Barrier::new_test_barrier(test_epoch(epoch)));

        let schema = Schema {
            fields: vec![Field::unnamed(DataType::Int64)],
        };

        let per_input_messages = vec![
            vec![
                chunk("I\n + 1"),
                barrier(1),
                chunk("I\n + 2"),
                barrier(2),
                chunk("I\n + 3"),
                barrier(3),
            ],
            vec![
                chunk("I\n + 11"),
                barrier(1),
                chunk("I\n + 12"),
                barrier(2),
            ],
            vec![
                chunk("I\n + 21"),
                barrier(1),
                chunk("I\n + 22"),
                barrier(2),
            ],
        ];
        let sources: Vec<_> = per_input_messages
            .into_iter()
            .map(|messages| {
                MockSource::with_messages(messages)
                    .stop_on_finish(false)
                    .into_executor(schema.clone(), vec![0])
            })
            .collect();

        let outputs: Vec<_> = LookupUnionExecutor::new(sources, vec![2, 1, 0])
            .boxed()
            .execute()
            .try_collect()
            .await
            .unwrap();

        let expected = vec![
            chunk("I\n + 21"),
            chunk("I\n + 11"),
            chunk("I\n + 1"),
            barrier(1),
            chunk("I\n + 22"),
            chunk("I\n + 12"),
            chunk("I\n + 2"),
            barrier(2),
            chunk("I\n + 3"),
            barrier(3),
        ];
        assert_eq!(outputs, expected);
    }
}