risingwave_batch/execution/
local_exchange.rs1use std::fmt::{Debug, Formatter};
16
17use risingwave_common::array::DataChunk;
18
19use crate::error::Result;
20use crate::exchange_source::ExchangeSource;
21use crate::task::{BatchTaskContext, TaskId, TaskOutput, TaskOutputId};
22
23pub struct LocalExchangeSource {
25 task_output: TaskOutput,
26
27 task_id: TaskId,
29}
30
31impl LocalExchangeSource {
32 pub fn create(
33 output_id: TaskOutputId,
34 context: &dyn BatchTaskContext,
35 task_id: TaskId,
36 ) -> Result<Self> {
37 let task_output = context.get_task_output(output_id)?;
38 Ok(Self {
39 task_output,
40 task_id,
41 })
42 }
43}
44
45impl Debug for LocalExchangeSource {
46 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("LocalExchangeSource")
48 .field("task_output_id", self.task_output.id())
49 .finish()
50 }
51}
52
53impl ExchangeSource for LocalExchangeSource {
54 async fn take_data(&mut self) -> Result<Option<DataChunk>> {
55 let ret = self.task_output.direct_take_data().await?;
56 if let Some(data) = ret {
57 let data = data.compact();
58 trace!(
59 "Receiver task: {:?}, source task output: {:?}, data: {:?}",
60 self.task_id,
61 self.task_output.id(),
62 data
63 );
64 Ok(Some(data))
65 } else {
66 Ok(None)
67 }
68 }
69
70 fn get_task_id(&self) -> TaskId {
71 self.task_id.clone()
72 }
73}
74
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use risingwave_common::config::RpcClientConfig;
    use risingwave_pb::batch_plan::{TaskId, TaskOutputId};
    use risingwave_pb::data::DataChunk;
    use risingwave_pb::task_service::exchange_service_server::{
        ExchangeService, ExchangeServiceServer,
    };
    use risingwave_pb::task_service::{
        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
    };
    use risingwave_rpc_client::ComputeClient;
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use crate::exchange_source::ExchangeSource;
    use crate::execution::grpc_exchange::GrpcExchangeSource;

    /// Number of chunks the fake service emits and the test expects to read.
    const FAKE_CHUNK_COUNT: usize = 3;

    /// Minimal `ExchangeService` that records whether `get_data` was invoked.
    struct FakeExchangeService {
        /// Set to `true` the first time `get_data` is called.
        rpc_called: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<Result<GetStreamResponse, Status>>;

        /// Replies with `FAKE_CHUNK_COUNT` default chunks and then ends the
        /// stream (the sender is dropped when this function returns).
        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> Result<Response<Self::GetDataStream>, Status> {
            let (sender, receiver) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            for _ in 0..FAKE_CHUNK_COUNT {
                let response = GetDataResponse {
                    record_batch: Some(DataChunk::default()),
                };
                sender.send(Ok(response)).await.unwrap();
            }
            Ok(Response::new(ReceiverStream::new(receiver)))
        }

        /// Not exercised by this test.
        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> Result<Response<Self::GetStreamStream>, Status> {
            unimplemented!()
        }
    }

    /// Runs a `GrpcExchangeSource` against an in-process fake exchange
    /// service and checks that it yields every chunk and then `None`.
    #[tokio::test]
    async fn test_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): fixed port; the test fails if it is already in use.
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();

        // Start the fake server on a background task.
        let service = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let server_started = server_run.clone();
        let server_task = tokio::spawn(async move {
            // Flag is raised before the server binds, so it only proves that
            // this task began running.
            server_started.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(service)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        // Presumably gives the server time to start listening before the
        // client connects.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let client = ComputeClient::new(addr.into(), &RpcClientConfig::default())
            .await
            .unwrap();
        let output_id = TaskOutputId {
            task_id: Some(TaskId::default()),
            ..Default::default()
        };
        let mut source = GrpcExchangeSource::create(client, output_id, None)
            .await
            .unwrap();

        // Every chunk sent by the fake service must arrive, then the stream
        // must report exhaustion.
        for _ in 0..FAKE_CHUNK_COUNT {
            assert!(source.take_data().await.unwrap().is_some());
        }
        assert!(source.take_data().await.unwrap().is_none());
        assert!(rpc_called.load(Ordering::SeqCst));

        // Gracefully stop the server and wait for its task to finish.
        shutdown_send.send(()).unwrap();
        server_task.await.unwrap();
    }
}