risingwave_compute/rpc/service/
stream_exchange_service.rs1use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_pb::id::FragmentId;
22use risingwave_pb::task_service::stream_exchange_service_server::StreamExchangeService;
23use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits, permits};
24use risingwave_stream::executor::DispatcherMessageBatch;
25use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
26use risingwave_stream::task::LocalStreamManager;
27use tonic::{Request, Response, Status, Streaming};
28
29pub mod metrics;
30pub use metrics::{GLOBAL_STREAM_EXCHANGE_SERVICE_METRICS, StreamExchangeServiceMetrics};
31
32pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
33
34#[derive(Clone)]
35pub struct StreamExchangeServiceImpl {
36 stream_mgr: LocalStreamManager,
37 metrics: Arc<StreamExchangeServiceMetrics>,
38}
39
40#[async_trait::async_trait]
41impl StreamExchangeService for StreamExchangeServiceImpl {
42 type GetStreamStream = StreamDataStream;
43
44 #[define_opaque(StreamDataStream)]
45 async fn get_stream(
46 &self,
47 request: Request<Streaming<GetStreamRequest>>,
48 ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
49 use risingwave_pb::task_service::get_stream_request::*;
50
51 let peer_addr = request
52 .remote_addr()
53 .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
54
55 let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
56
57 let Get {
59 up_actor_id,
60 down_actor_id,
61 up_fragment_id,
62 down_fragment_id,
63 up_partial_graph_id,
64 term_id,
65 } = {
66 let req = request_stream
67 .next()
68 .await
69 .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
70 match req.value.unwrap() {
71 Value::Get(get) => get,
72 Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
73 }
74 };
75
76 let receiver = self
77 .stream_mgr
78 .take_receiver(
79 up_partial_graph_id,
80 term_id,
81 (up_actor_id, down_actor_id),
82 up_fragment_id,
83 )
84 .await?;
85
86 let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
88 Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
89 Value::AddPermits(add_permits) => add_permits.value.unwrap(),
90 });
91
92 Ok(Response::new(Self::get_stream_impl(
93 self.metrics.clone(),
94 peer_addr,
95 receiver,
96 add_permits_stream,
97 (up_fragment_id, down_fragment_id),
98 )))
99 }
100}
101
102impl StreamExchangeServiceImpl {
103 pub fn new(stream_mgr: LocalStreamManager, metrics: Arc<StreamExchangeServiceMetrics>) -> Self {
104 Self {
105 stream_mgr,
106 metrics,
107 }
108 }
109
110 #[try_stream(ok = GetStreamResponse, error = Status)]
111 async fn get_stream_impl(
112 metrics: Arc<StreamExchangeServiceMetrics>,
113 peer_addr: SocketAddr,
114 mut receiver: Receiver,
115 add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
116 up_down_fragment_ids: (FragmentId, FragmentId),
117 ) {
118 tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
119 let up_fragment_id = up_down_fragment_ids.0.to_string();
120 let down_fragment_id = up_down_fragment_ids.1.to_string();
121
122 let permits = receiver.permits();
123
124 let select_stream = futures::stream::select(
126 add_permits_stream.map_ok(Either::Left),
127 #[try_stream]
128 async move {
129 while let Some(m) = receiver.recv_raw().await {
130 yield Either::Right(m);
131 }
132 },
133 );
134 pin_mut!(select_stream);
135
136 let exchange_frag_send_size_metrics = metrics
137 .stream_fragment_exchange_bytes
138 .with_label_values(&[&up_fragment_id, &down_fragment_id]);
139
140 while let Some(r) = select_stream.try_next().await? {
141 match r {
142 Either::Left(permits_to_add) => {
143 permits.add_permits(permits_to_add);
144 }
145 Either::Right(MessageWithPermits { message, permits }) => {
146 let message = match message {
147 DispatcherMessageBatch::Chunk(chunk) => {
148 DispatcherMessageBatch::Chunk(chunk.compact_vis())
149 }
150 msg @ (DispatcherMessageBatch::Watermark(_)
151 | DispatcherMessageBatch::BarrierBatch(_)) => msg,
152 };
153 let proto = message.to_protobuf();
154 let response = GetStreamResponse {
156 message: Some(proto),
157 permits: Some(PbPermits { value: permits }),
158 };
159 let bytes = DispatcherMessageBatch::get_encoded_len(&response);
160
161 yield response;
162
163 exchange_frag_send_size_metrics.inc_by(bytes as u64);
164 }
165 }
166 }
167 }
168}