risingwave_stream/executor/
lookup_union.rs1use async_trait::async_trait;
16use futures::channel::mpsc;
17use futures::future::{Either, join_all, select};
18use futures::{FutureExt, SinkExt};
19use itertools::Itertools;
20
21use crate::executor::prelude::*;
22
23pub struct LookupUnionExecutor {
28 inputs: Vec<Executor>,
29 order: Vec<usize>,
30}
31
32impl std::fmt::Debug for LookupUnionExecutor {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("LookupUnionExecutor").finish()
35 }
36}
37
38impl LookupUnionExecutor {
39 pub fn new(inputs: Vec<Executor>, order: Vec<u32>) -> Self {
40 Self {
41 inputs,
42 order: order.iter().map(|x| *x as _).collect(),
43 }
44 }
45}
46
47#[async_trait]
48impl Execute for LookupUnionExecutor {
49 fn execute(self: Box<Self>) -> BoxedMessageStream {
50 self.execute_inner().boxed()
51 }
52}
53
54impl LookupUnionExecutor {
55 #[try_stream(ok = Message, error = StreamExecutorError)]
56 async fn execute_inner(self) {
57 let mut inputs = self.inputs.into_iter().map(Some).collect_vec();
58 let mut futures = vec![];
59 let mut rxs = vec![];
60 for idx in self.order {
61 let mut stream = inputs[idx].take().unwrap().execute();
62 let (mut tx, rx) = mpsc::channel(1024); rxs.push(rx);
64 futures.push(
65 async move {
68 while let Some(ret) = stream.next().await {
69 tx.send(ret).await.unwrap();
70 }
71 }
72 .boxed(),
73 );
74 }
75 let mut drive_inputs = join_all(futures).fuse();
77 let mut end = false;
78 while !end {
79 end = true; let mut this_barrier: Option<Barrier> = None;
81 for rx in &mut rxs {
82 loop {
83 let msg = match select(rx.next(), &mut drive_inputs).await {
84 Either::Left((Some(msg), _)) => msg?,
85 Either::Left((None, _)) => break, Either::Right(_) => continue,
87 };
88 end = false;
89 match msg {
90 Message::Watermark(_) => {
91 }
93
94 msg @ Message::Chunk(_) => yield msg,
95 Message::Barrier(barrier) => {
96 if let Some(this_barrier) = &this_barrier {
97 if this_barrier.epoch != barrier.epoch {
98 return Err(StreamExecutorError::align_barrier(
99 this_barrier.clone(),
100 barrier,
101 ));
102 }
103 } else {
104 this_barrier = Some(barrier);
105 }
106 break; }
108 }
109 }
110 }
111 if end {
112 break;
113 } else {
114 yield Message::Barrier(this_barrier.take().unwrap());
115 }
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use futures::TryStreamExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// Builds a chunk message from its pretty-printed form.
    fn chunk(pretty: &str) -> Message {
        Message::Chunk(StreamChunk::from_pretty(pretty))
    }

    /// Builds a test barrier message for the given test epoch.
    fn barrier(epoch: u64) -> Message {
        Message::Barrier(Barrier::new_test_barrier(test_epoch(epoch)))
    }

    /// Wraps `messages` in a `MockSource` executor with `stop_on_finish(false)`.
    fn mock_input(schema: &Schema, messages: Vec<Message>) -> Executor {
        MockSource::with_messages(messages)
            .stop_on_finish(false)
            .into_executor(schema.clone(), vec![0])
    }

    #[tokio::test]
    async fn lookup_union() {
        let schema = Schema {
            fields: vec![Field::unnamed(DataType::Int64)],
        };
        let inputs = vec![
            mock_input(
                &schema,
                vec![
                    chunk("I\n + 1"),
                    barrier(1),
                    chunk("I\n + 2"),
                    barrier(2),
                    chunk("I\n + 3"),
                    barrier(3),
                ],
            ),
            mock_input(
                &schema,
                vec![chunk("I\n + 11"), barrier(1), chunk("I\n + 12"), barrier(2)],
            ),
            mock_input(
                &schema,
                vec![chunk("I\n + 21"), barrier(1), chunk("I\n + 22"), barrier(2)],
            ),
        ];

        // Drain the inputs in reverse: input 2 first, then input 1, then input 0.
        let executor = LookupUnionExecutor::new(inputs, vec![2, 1, 0])
            .boxed()
            .execute();
        let outputs: Vec<_> = executor.try_collect().await.unwrap();

        let expected = vec![
            chunk("I\n + 21"),
            chunk("I\n + 11"),
            chunk("I\n + 1"),
            barrier(1),
            chunk("I\n + 22"),
            chunk("I\n + 12"),
            chunk("I\n + 2"),
            barrier(2),
            chunk("I\n + 3"),
            barrier(3),
        ];
        assert_eq!(outputs, expected);
    }
}