risingwave_batch/execution/
local_exchange.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::{Debug, Formatter};
16
17use risingwave_common::array::DataChunk;
18
19use crate::error::Result;
20use crate::exchange_source::ExchangeSource;
21use crate::task::{BatchTaskContext, TaskId, TaskOutput, TaskOutputId};
22
23/// Exchange data from a local task execution.
24pub struct LocalExchangeSource {
25    task_output: TaskOutput,
26
27    /// Id of task which contains the `ExchangeExecutor` of this source.
28    task_id: TaskId,
29}
30
31impl LocalExchangeSource {
32    pub fn create(
33        output_id: TaskOutputId,
34        context: &dyn BatchTaskContext,
35        task_id: TaskId,
36    ) -> Result<Self> {
37        let task_output = context.get_task_output(output_id)?;
38        Ok(Self {
39            task_output,
40            task_id,
41        })
42    }
43}
44
45impl Debug for LocalExchangeSource {
46    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("LocalExchangeSource")
48            .field("task_output_id", self.task_output.id())
49            .finish()
50    }
51}
52
53impl ExchangeSource for LocalExchangeSource {
54    async fn take_data(&mut self) -> Result<Option<DataChunk>> {
55        let ret = self.task_output.direct_take_data().await?;
56        if let Some(data) = ret {
57            let data = data.compact();
58            trace!(
59                "Receiver task: {:?}, source task output: {:?}, data: {:?}",
60                self.task_id,
61                self.task_output.id(),
62                data
63            );
64            Ok(Some(data))
65        } else {
66            Ok(None)
67        }
68    }
69
70    fn get_task_id(&self) -> TaskId {
71        self.task_id.clone()
72    }
73}
74
#[cfg(test)]
mod tests {
    use std::net::{SocketAddr, TcpListener};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use risingwave_common::config::RpcClientConfig;
    use risingwave_pb::batch_plan::{TaskId, TaskOutputId};
    use risingwave_pb::data::DataChunk;
    use risingwave_pb::task_service::exchange_service_server::{
        ExchangeService, ExchangeServiceServer,
    };
    use risingwave_pb::task_service::{
        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
    };
    use risingwave_rpc_client::ComputeClient;
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use crate::exchange_source::ExchangeSource;
    use crate::execution::grpc_exchange::GrpcExchangeSource;

    /// Minimal `ExchangeService` that answers every `get_data` call with three empty
    /// chunks and records that it was invoked.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<Result<GetStreamResponse, Status>>;

        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> Result<Response<Self::GetDataStream>, Status> {
            // Capacity (10) exceeds the number of messages (3), so these sends complete
            // even though the receiver is only handed out after the loop.
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            for _ in 0..3 {
                tx.send(Ok(GetDataResponse {
                    record_batch: Some(DataChunk::default()),
                }))
                .await
                .unwrap();
            }
            // `tx` is dropped on return, so the stream ends after the three chunks.
            Ok(Response::new(ReceiverStream::new(rx)))
        }

        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> Result<Response<Self::GetStreamStream>, Status> {
            unimplemented!()
        }
    }

    /// Returns a loopback address whose port was free at the time of the call.
    ///
    /// The OS picks the port via `bind("127.0.0.1:0")`; the listener is dropped right away
    /// so the tonic server can bind it. Another process could in principle take the port in
    /// between, but that is far less likely than colliding on a hard-coded port when test
    /// binaries run concurrently.
    fn pick_free_local_addr() -> SocketAddr {
        let listener =
            TcpListener::bind("127.0.0.1:0").expect("failed to bind an ephemeral loopback port");
        listener
            .local_addr()
            .expect("failed to read the ephemeral port")
    }

    /// Streams three chunks from a real tonic server through `GrpcExchangeSource`.
    ///
    /// NOTE(review): although it lives in `local_exchange.rs`, this test exercises
    /// `GrpcExchangeSource`, not `LocalExchangeSource`; consider moving it next to the
    /// gRPC implementation.
    #[tokio::test]
    async fn test_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // Use an OS-assigned port instead of a fixed one to avoid collisions.
        let addr = pick_free_local_addr();

        // Start a server.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        // Give the server time to start listening. `server_run` only proves the spawned
        // task began running, not that the socket is already bound.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let client = ComputeClient::new(addr.into(), &RpcClientConfig::default())
            .await
            .unwrap();
        let task_output_id = TaskOutputId {
            task_id: Some(TaskId::default()),
            ..Default::default()
        };
        let mut src = GrpcExchangeSource::create(client, task_output_id, None)
            .await
            .unwrap();
        for _ in 0..3 {
            assert!(src.take_data().await.unwrap().is_some());
        }
        assert!(src.take_data().await.unwrap().is_none());
        assert!(rpc_called.load(Ordering::SeqCst));

        // Gracefully terminate the server.
        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}