// risingwave_batch/execution/local_exchange.rs
use std::fmt::{Debug, Formatter};

use risingwave_common::array::DataChunk;

use crate::error::Result;
use crate::exchange_source::ExchangeSource;
use crate::task::{BatchTaskContext, TaskId, TaskOutput, TaskOutputId};

/// Exchange source that takes data chunks from a [`TaskOutput`] obtained through the
/// [`BatchTaskContext`] passed to [`LocalExchangeSource::create`].
pub struct LocalExchangeSource {
    /// The task output that chunks are pulled from (via `direct_take_data`).
    task_output: TaskOutput,

    /// Id of the task receiving data through this source; returned by `get_task_id`
    /// and included in the `take_data` trace log.
    task_id: TaskId,
}

31impl LocalExchangeSource {
32 pub fn create(
33 output_id: TaskOutputId,
34 context: &dyn BatchTaskContext,
35 task_id: TaskId,
36 ) -> Result<Self> {
37 let task_output = context.get_task_output(output_id)?;
38 Ok(Self {
39 task_output,
40 task_id,
41 })
42 }
43}
44
45impl Debug for LocalExchangeSource {
46 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("LocalExchangeSource")
48 .field("task_output_id", self.task_output.id())
49 .finish()
50 }
51}
52
53impl ExchangeSource for LocalExchangeSource {
54 async fn take_data(&mut self) -> Result<Option<DataChunk>> {
55 let ret = self.task_output.direct_take_data().await?;
56 if let Some(data) = ret {
57 let data = data.compact_vis();
58 trace!(
59 "Receiver task: {:?}, source task output: {:?}, data: {:?}",
60 self.task_id,
61 self.task_output.id(),
62 data
63 );
64 Ok(Some(data))
65 } else {
66 Ok(None)
67 }
68 }
69
70 fn get_task_id(&self) -> TaskId {
71 self.task_id.clone()
72 }
73}
74
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use risingwave_common::config::RpcClientConfig;
    use risingwave_pb::batch_plan::{TaskId, TaskOutputId};
    use risingwave_pb::data::DataChunk;
    use risingwave_pb::task_service::batch_exchange_service_server::{
        BatchExchangeService, BatchExchangeServiceServer,
    };
    use risingwave_pb::task_service::{GetDataRequest, GetDataResponse};
    use risingwave_rpc_client::ComputeClient;
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status};

    use crate::exchange_source::ExchangeSource;
    use crate::execution::grpc_exchange::GrpcExchangeSource;

    /// Number of default chunks the fake service streams per `get_data` call.
    const CHUNK_COUNT: usize = 3;

    /// Stand-in exchange service: records that `get_data` was called and replies
    /// with `CHUNK_COUNT` default chunks.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl BatchExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<Result<GetDataResponse, Status>>;

        async fn get_data(
            &self,
            _request: Request<GetDataRequest>,
        ) -> Result<Response<Self::GetDataStream>, Status> {
            let (sender, receiver) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // The channel capacity (10) exceeds CHUNK_COUNT, so every send
            // completes before anyone reads from the receiver.
            for _ in 0..CHUNK_COUNT {
                let response = GetDataResponse {
                    record_batch: Some(DataChunk::default()),
                };
                sender.send(Ok(response)).await.unwrap();
            }
            Ok(Response::new(ReceiverStream::new(receiver)))
        }
    }

    #[tokio::test]
    async fn test_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): a fixed port can collide with other processes or
        // concurrently running tests; binding to port 0 would avoid that.
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = BatchExchangeServiceServer::new(FakeExchangeService {
            rpc_called: Arc::clone(&rpc_called),
        });
        let server_flag = Arc::clone(&server_run);
        let join_handle = tokio::spawn(async move {
            server_flag.store(true, Ordering::SeqCst);
            let shutdown_signal = async move {
                shutdown_recv.await.unwrap();
            };
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, shutdown_signal)
                .await
                .unwrap();
        });

        // Give the server time to come up. The flag is stored before the server
        // starts serving, so it only shows the spawned task began running.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let client = ComputeClient::new(addr.into(), &RpcClientConfig::default())
            .await
            .unwrap();
        let task_output_id = TaskOutputId {
            task_id: Some(TaskId::default()),
            ..Default::default()
        };
        let mut source = GrpcExchangeSource::create(client, task_output_id, None)
            .await
            .unwrap();

        // Exactly CHUNK_COUNT chunks arrive, then the source reports no more data.
        for _ in 0..CHUNK_COUNT {
            assert!(source.take_data().await.unwrap().is_some());
        }
        assert!(source.take_data().await.unwrap().is_none());
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        // Wait for the server task to finish after the shutdown signal.
        join_handle.await.unwrap();
    }
}