risingwave_compute/rpc/service/
exchange_service.rs1use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_batch::task::BatchManager;
22use risingwave_common::catalog::DatabaseId;
23use risingwave_pb::task_service::exchange_service_server::ExchangeService;
24use risingwave_pb::task_service::{
25 GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits,
26};
27use risingwave_stream::executor::DispatcherMessageBatch;
28use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
29use risingwave_stream::task::LocalStreamManager;
30use thiserror_ext::AsReport;
31use tokio_stream::wrappers::ReceiverStream;
32use tonic::{Request, Response, Status, Streaming};
33
34use crate::rpc::service::exchange_metrics::ExchangeServiceMetrics;
35
36#[derive(Clone)]
37pub struct ExchangeServiceImpl {
38 batch_mgr: Arc<BatchManager>,
39 stream_mgr: LocalStreamManager,
40 metrics: Arc<ExchangeServiceMetrics>,
41}
42
43pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
44pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
45
46#[async_trait::async_trait]
47impl ExchangeService for ExchangeServiceImpl {
48 type GetDataStream = BatchDataStream;
49 type GetStreamStream = StreamDataStream;
50
51 async fn get_data(
52 &self,
53 request: Request<GetDataRequest>,
54 ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
55 let peer_addr = request
56 .remote_addr()
57 .ok_or_else(|| Status::unavailable("connection unestablished"))?;
58 let pb_task_output_id = request
59 .into_inner()
60 .task_output_id
61 .expect("Failed to get task output id.");
62 let (tx, rx) =
63 tokio::sync::mpsc::channel(self.batch_mgr.config().developer.receiver_channel_size);
64 if let Err(e) = self.batch_mgr.get_data(tx, peer_addr, &pb_task_output_id) {
65 error!(
66 %peer_addr,
67 error = %e.as_report(),
68 "Failed to serve exchange RPC"
69 );
70 return Err(e.into());
71 }
72
73 Ok(Response::new(ReceiverStream::new(rx)))
74 }
75
76 #[define_opaque(StreamDataStream)]
77 async fn get_stream(
78 &self,
79 request: Request<Streaming<GetStreamRequest>>,
80 ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
81 use risingwave_pb::task_service::get_stream_request::*;
82
83 let peer_addr = request
84 .remote_addr()
85 .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
86
87 let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
88
89 let Get {
91 up_actor_id,
92 down_actor_id,
93 up_fragment_id,
94 down_fragment_id,
95 database_id,
96 term_id,
97 } = {
98 let req = request_stream
99 .next()
100 .await
101 .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
102 match req.value.unwrap() {
103 Value::Get(get) => get,
104 Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
105 }
106 };
107
108 let receiver = self
109 .stream_mgr
110 .take_receiver(
111 DatabaseId::new(database_id),
112 term_id,
113 (up_actor_id, down_actor_id),
114 )
115 .await?;
116
117 let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
119 Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
120 Value::AddPermits(add_permits) => add_permits.value.unwrap(),
121 });
122
123 Ok(Response::new(Self::get_stream_impl(
124 self.metrics.clone(),
125 peer_addr,
126 receiver,
127 add_permits_stream,
128 (up_fragment_id, down_fragment_id),
129 )))
130 }
131}
132
133impl ExchangeServiceImpl {
134 pub fn new(
135 mgr: Arc<BatchManager>,
136 stream_mgr: LocalStreamManager,
137 metrics: Arc<ExchangeServiceMetrics>,
138 ) -> Self {
139 ExchangeServiceImpl {
140 batch_mgr: mgr,
141 stream_mgr,
142 metrics,
143 }
144 }
145
146 #[try_stream(ok = GetStreamResponse, error = Status)]
147 async fn get_stream_impl(
148 metrics: Arc<ExchangeServiceMetrics>,
149 peer_addr: SocketAddr,
150 mut receiver: Receiver,
151 add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
152 up_down_fragment_ids: (u32, u32),
153 ) {
154 tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
155 let up_fragment_id = up_down_fragment_ids.0.to_string();
156 let down_fragment_id = up_down_fragment_ids.1.to_string();
157
158 let permits = receiver.permits();
159
160 let select_stream = futures::stream::select(
162 add_permits_stream.map_ok(Either::Left),
163 #[try_stream]
164 async move {
165 while let Some(m) = receiver.recv_raw().await {
166 yield Either::Right(m);
167 }
168 },
169 );
170 pin_mut!(select_stream);
171
172 while let Some(r) = select_stream.try_next().await? {
173 match r {
174 Either::Left(permits_to_add) => {
175 permits.add_permits(permits_to_add);
176 }
177 Either::Right(MessageWithPermits { message, permits }) => {
178 let message = match message {
179 DispatcherMessageBatch::Chunk(chunk) => {
180 DispatcherMessageBatch::Chunk(chunk.compact())
181 }
182 msg @ (DispatcherMessageBatch::Watermark(_)
183 | DispatcherMessageBatch::BarrierBatch(_)) => msg,
184 };
185 let proto = message.to_protobuf();
186 let response = GetStreamResponse {
188 message: Some(proto),
189 permits: Some(PbPermits { value: permits }),
190 };
191 let bytes = DispatcherMessageBatch::get_encoded_len(&response);
192
193 yield response;
194
195 metrics
196 .stream_fragment_exchange_bytes
197 .with_label_values(&[&up_fragment_id, &down_fragment_id])
198 .inc_by(bytes as u64);
199 }
200 }
201 }
202 }
203}