risingwave_batch/
exchange_source.rs1use std::fmt::Debug;
16use std::future::Future;
17
18use futures_async_stream::try_stream;
19use risingwave_common::array::DataChunk;
20
21use crate::error::{BatchError, Result};
22use crate::execution::grpc_exchange::GrpcExchangeSource;
23use crate::execution::local_exchange::LocalExchangeSource;
24use crate::executor::test_utils::FakeExchangeSource;
25use crate::task::TaskId;
26
27pub trait ExchangeSource: Send + Debug {
29 fn take_data(&mut self) -> impl Future<Output = Result<Option<DataChunk>>> + '_;
30
31 fn get_task_id(&self) -> TaskId;
33}
34
35#[derive(Debug)]
36pub enum ExchangeSourceImpl {
37 Grpc(GrpcExchangeSource),
38 Local(LocalExchangeSource),
39 Fake(FakeExchangeSource),
40}
41
42impl ExchangeSourceImpl {
43 pub async fn take_data(&mut self) -> Result<Option<DataChunk>> {
44 match self {
45 ExchangeSourceImpl::Grpc(grpc) => grpc.take_data().await,
46 ExchangeSourceImpl::Local(local) => local.take_data().await,
47 ExchangeSourceImpl::Fake(fake) => fake.take_data().await,
48 }
49 }
50
51 pub fn get_task_id(&self) -> TaskId {
52 match self {
53 ExchangeSourceImpl::Grpc(grpc) => grpc.get_task_id(),
54 ExchangeSourceImpl::Local(local) => local.get_task_id(),
55 ExchangeSourceImpl::Fake(fake) => fake.get_task_id(),
56 }
57 }
58
59 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
60 pub async fn take_data_stream(self) {
61 let mut source = self;
62 loop {
63 match source.take_data().await {
64 Ok(Some(chunk)) => yield chunk,
65 Ok(None) => break,
66 Err(e) => return Err(e),
67 }
68 }
69 }
70}