risingwave_batch/execution/
local_exchange.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{Debug, Formatter};
16
17use risingwave_common::array::DataChunk;
18
19use crate::error::Result;
20use crate::exchange_source::ExchangeSource;
21use crate::task::{BatchTaskContext, TaskId, TaskOutput, TaskOutputId};
22
23/// Exchange data from a local task execution.
24pub struct LocalExchangeSource {
25    task_output: TaskOutput,
26
27    /// Id of task which contains the `ExchangeExecutor` of this source.
28    task_id: TaskId,
29}
30
31impl LocalExchangeSource {
32    pub fn create(
33        output_id: TaskOutputId,
34        context: &dyn BatchTaskContext,
35        task_id: TaskId,
36    ) -> Result<Self> {
37        let task_output = context.get_task_output(output_id)?;
38        Ok(Self {
39            task_output,
40            task_id,
41        })
42    }
43}
44
45impl Debug for LocalExchangeSource {
46    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("LocalExchangeSource")
48            .field("task_output_id", self.task_output.id())
49            .finish()
50    }
51}
52
53impl ExchangeSource for LocalExchangeSource {
54    async fn take_data(&mut self) -> Result<Option<DataChunk>> {
55        let ret = self.task_output.direct_take_data().await?;
56        if let Some(data) = ret {
57            let data = data.compact_vis();
58            trace!(
59                "Receiver task: {:?}, source task output: {:?}, data: {:?}",
60                self.task_id,
61                self.task_output.id(),
62                data
63            );
64            Ok(Some(data))
65        } else {
66            Ok(None)
67        }
68    }
69
70    fn get_task_id(&self) -> TaskId {
71        self.task_id.clone()
72    }
73}
74
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use risingwave_common::config::RpcClientConfig;
    use risingwave_pb::batch_plan::{TaskId, TaskOutputId};
    use risingwave_pb::data::DataChunk;
    use risingwave_pb::task_service::batch_exchange_service_server::{
        BatchExchangeService, BatchExchangeServiceServer,
    };
    use risingwave_pb::task_service::{GetDataRequest, GetDataResponse};
    use risingwave_rpc_client::ComputeClient;
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status};

    use crate::exchange_source::ExchangeSource;
    use crate::execution::grpc_exchange::GrpcExchangeSource;

    /// Number of (empty) chunks the fake service streams back per `get_data` call.
    const CHUNKS_PER_CALL: usize = 3;

    /// Stub exchange service: records that it was invoked and replies with
    /// `CHUNKS_PER_CALL` default chunks.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl BatchExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<Result<GetDataResponse, Status>>;

        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> Result<Response<Self::GetDataStream>, Status> {
            let (sender, receiver) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);

            // The channel capacity (10) exceeds CHUNKS_PER_CALL, so these
            // sends complete without a reader.
            for _ in 0..CHUNKS_PER_CALL {
                let response = GetDataResponse {
                    record_batch: Some(DataChunk::default()),
                };
                sender.send(Ok(response)).await.unwrap();
            }

            // `sender` is dropped on return, so the stream ends after the
            // buffered responses have been consumed.
            Ok(Response::new(ReceiverStream::new(receiver)))
        }
    }

    #[tokio::test]
    async fn test_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_started = Arc::new(AtomicBool::new(false));
        // NOTE(review): a fixed port may collide with other processes or
        // concurrently running tests; binding to port 0 would be more robust.
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // Run the fake service in the background until `stop_tx` fires.
        let (stop_tx, stop_rx) = tokio::sync::oneshot::channel();
        let service = BatchExchangeServiceServer::new(FakeExchangeService {
            rpc_called: Arc::clone(&rpc_called),
        });
        let started_flag = Arc::clone(&server_started);
        let server = tokio::spawn(async move {
            started_flag.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(service)
                .serve_with_shutdown(addr, async move {
                    stop_rx.await.unwrap();
                })
                .await
                .unwrap();
        });

        // NOTE(review): the flag is set before `serve_with_shutdown` binds, so
        // it only proves the task started; this sleep is what gives the
        // listener time to come up.
        sleep(Duration::from_secs(1)).await;
        assert!(server_started.load(Ordering::SeqCst));

        let client = ComputeClient::new(addr.into(), &RpcClientConfig::default())
            .await
            .unwrap();
        let output_id = TaskOutputId {
            task_id: Some(TaskId::default()),
            ..Default::default()
        };
        let mut source = GrpcExchangeSource::create(client, output_id, None)
            .await
            .unwrap();

        // Exactly CHUNKS_PER_CALL chunks arrive, then the stream is exhausted.
        for _ in 0..CHUNKS_PER_CALL {
            assert!(source.take_data().await.unwrap().is_some());
        }
        assert!(source.take_data().await.unwrap().is_none());
        assert!(rpc_called.load(Ordering::SeqCst));

        // Ask the server to stop gracefully and wait for the task to finish.
        stop_tx.send(()).unwrap();
        server.await.unwrap();
    }
}