risingwave_batch/
exchange_source.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::future::Future;
17
18use futures_async_stream::try_stream;
19use risingwave_common::array::DataChunk;
20
21use crate::error::{BatchError, Result};
22use crate::execution::grpc_exchange::GrpcExchangeSource;
23use crate::execution::local_exchange::LocalExchangeSource;
24use crate::executor::test_utils::FakeExchangeSource;
25use crate::task::TaskId;
26
27/// Each `ExchangeSource` maps to one task, it takes the execution result from task chunk by chunk.
28pub trait ExchangeSource: Send + Debug {
29    fn take_data(&mut self) -> impl Future<Output = Result<Option<DataChunk>>> + '_;
30
31    /// Get upstream task id.
32    fn get_task_id(&self) -> TaskId;
33}
34
35#[derive(Debug)]
36pub enum ExchangeSourceImpl {
37    Grpc(GrpcExchangeSource),
38    Local(LocalExchangeSource),
39    Fake(FakeExchangeSource),
40}
41
42impl ExchangeSourceImpl {
43    pub async fn take_data(&mut self) -> Result<Option<DataChunk>> {
44        match self {
45            ExchangeSourceImpl::Grpc(grpc) => grpc.take_data().await,
46            ExchangeSourceImpl::Local(local) => local.take_data().await,
47            ExchangeSourceImpl::Fake(fake) => fake.take_data().await,
48        }
49    }
50
51    pub fn get_task_id(&self) -> TaskId {
52        match self {
53            ExchangeSourceImpl::Grpc(grpc) => grpc.get_task_id(),
54            ExchangeSourceImpl::Local(local) => local.get_task_id(),
55            ExchangeSourceImpl::Fake(fake) => fake.get_task_id(),
56        }
57    }
58
59    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
60    pub async fn take_data_stream(self) {
61        let mut source = self;
62        loop {
63            match source.take_data().await {
64                Ok(Some(chunk)) => yield chunk,
65                Ok(None) => break,
66                Err(e) => return Err(e),
67            }
68        }
69    }
70}