risingwave_compute/rpc/service/
exchange_service.rs1use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_batch::task::BatchManager;
22use risingwave_pb::task_service::exchange_service_server::ExchangeService;
23use risingwave_pb::task_service::{
24 GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits,
25};
26use risingwave_stream::executor::DispatcherMessageBatch;
27use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
28use risingwave_stream::task::LocalStreamManager;
29use thiserror_ext::AsReport;
30use tokio_stream::wrappers::ReceiverStream;
31use tonic::{Request, Response, Status, Streaming};
32
33use crate::rpc::service::exchange_metrics::ExchangeServiceMetrics;
34
35#[derive(Clone)]
36pub struct ExchangeServiceImpl {
37 batch_mgr: Arc<BatchManager>,
38 stream_mgr: LocalStreamManager,
39 metrics: Arc<ExchangeServiceMetrics>,
40}
41
42pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
43pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
44
45#[async_trait::async_trait]
46impl ExchangeService for ExchangeServiceImpl {
47 type GetDataStream = BatchDataStream;
48 type GetStreamStream = StreamDataStream;
49
50 async fn get_data(
51 &self,
52 request: Request<GetDataRequest>,
53 ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
54 let peer_addr = request
55 .remote_addr()
56 .ok_or_else(|| Status::unavailable("connection unestablished"))?;
57 let pb_task_output_id = request
58 .into_inner()
59 .task_output_id
60 .expect("Failed to get task output id.");
61 let (tx, rx) =
62 tokio::sync::mpsc::channel(self.batch_mgr.config().developer.receiver_channel_size);
63 if let Err(e) = self.batch_mgr.get_data(tx, peer_addr, &pb_task_output_id) {
64 error!(
65 %peer_addr,
66 error = %e.as_report(),
67 "Failed to serve exchange RPC"
68 );
69 return Err(e.into());
70 }
71
72 Ok(Response::new(ReceiverStream::new(rx)))
73 }
74
75 #[define_opaque(StreamDataStream)]
76 async fn get_stream(
77 &self,
78 request: Request<Streaming<GetStreamRequest>>,
79 ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
80 use risingwave_pb::task_service::get_stream_request::*;
81
82 let peer_addr = request
83 .remote_addr()
84 .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
85
86 let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
87
88 let Get {
90 up_actor_id,
91 down_actor_id,
92 up_fragment_id,
93 down_fragment_id,
94 database_id,
95 term_id,
96 } = {
97 let req = request_stream
98 .next()
99 .await
100 .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
101 match req.value.unwrap() {
102 Value::Get(get) => get,
103 Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
104 }
105 };
106
107 let receiver = self
108 .stream_mgr
109 .take_receiver(database_id, term_id, (up_actor_id, down_actor_id))
110 .await?;
111
112 let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
114 Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
115 Value::AddPermits(add_permits) => add_permits.value.unwrap(),
116 });
117
118 Ok(Response::new(Self::get_stream_impl(
119 self.metrics.clone(),
120 peer_addr,
121 receiver,
122 add_permits_stream,
123 (up_fragment_id, down_fragment_id),
124 )))
125 }
126}
127
128impl ExchangeServiceImpl {
129 pub fn new(
130 mgr: Arc<BatchManager>,
131 stream_mgr: LocalStreamManager,
132 metrics: Arc<ExchangeServiceMetrics>,
133 ) -> Self {
134 ExchangeServiceImpl {
135 batch_mgr: mgr,
136 stream_mgr,
137 metrics,
138 }
139 }
140
141 #[try_stream(ok = GetStreamResponse, error = Status)]
142 async fn get_stream_impl(
143 metrics: Arc<ExchangeServiceMetrics>,
144 peer_addr: SocketAddr,
145 mut receiver: Receiver,
146 add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
147 up_down_fragment_ids: (u32, u32),
148 ) {
149 tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
150 let up_fragment_id = up_down_fragment_ids.0.to_string();
151 let down_fragment_id = up_down_fragment_ids.1.to_string();
152
153 let permits = receiver.permits();
154
155 let select_stream = futures::stream::select(
157 add_permits_stream.map_ok(Either::Left),
158 #[try_stream]
159 async move {
160 while let Some(m) = receiver.recv_raw().await {
161 yield Either::Right(m);
162 }
163 },
164 );
165 pin_mut!(select_stream);
166
167 while let Some(r) = select_stream.try_next().await? {
168 match r {
169 Either::Left(permits_to_add) => {
170 permits.add_permits(permits_to_add);
171 }
172 Either::Right(MessageWithPermits { message, permits }) => {
173 let message = match message {
174 DispatcherMessageBatch::Chunk(chunk) => {
175 DispatcherMessageBatch::Chunk(chunk.compact_vis())
176 }
177 msg @ (DispatcherMessageBatch::Watermark(_)
178 | DispatcherMessageBatch::BarrierBatch(_)) => msg,
179 };
180 let proto = message.to_protobuf();
181 let response = GetStreamResponse {
183 message: Some(proto),
184 permits: Some(PbPermits { value: permits }),
185 };
186 let bytes = DispatcherMessageBatch::get_encoded_len(&response);
187
188 yield response;
189
190 metrics
191 .stream_fragment_exchange_bytes
192 .with_label_values(&[&up_fragment_id, &down_fragment_id])
193 .inc_by(bytes as u64);
194 }
195 }
196 }
197 }
198}