risingwave_compute/rpc/service/stream_exchange_service.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_pb::id::FragmentId;
22use risingwave_pb::task_service::stream_exchange_service_server::StreamExchangeService;
23use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits, permits};
24use risingwave_stream::executor::DispatcherMessageBatch;
25use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
26use risingwave_stream::task::LocalStreamManager;
27use tonic::{Request, Response, Status, Streaming};
28
29pub mod metrics;
30pub use metrics::{GLOBAL_STREAM_EXCHANGE_SERVICE_METRICS, StreamExchangeServiceMetrics};
31
/// Response stream type of [`StreamExchangeService::get_stream`].
///
/// Its concrete type is the stream built by `StreamExchangeServiceImpl::get_stream_impl`; the
/// binding is established by the `#[define_opaque(StreamDataStream)]` attribute on `get_stream`.
pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
33
/// Implementation of the [`StreamExchangeService`] gRPC service.
///
/// Each `get_stream` call takes an exchange receiver from `stream_mgr` and streams the
/// receiver's messages back to the caller.
#[derive(Clone)]
pub struct StreamExchangeServiceImpl {
    /// Local stream manager from which `get_stream` takes exchange receivers.
    stream_mgr: LocalStreamManager,
    /// Metrics handle. Its `stream_fragment_exchange_bytes` counter is incremented with the
    /// encoded size of every response sent, labeled by (upstream, downstream) fragment id.
    metrics: Arc<StreamExchangeServiceMetrics>,
}
39
40#[async_trait::async_trait]
41impl StreamExchangeService for StreamExchangeServiceImpl {
42    type GetStreamStream = StreamDataStream;
43
44    #[define_opaque(StreamDataStream)]
45    async fn get_stream(
46        &self,
47        request: Request<Streaming<GetStreamRequest>>,
48    ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
49        use risingwave_pb::task_service::get_stream_request::*;
50
51        let peer_addr = request
52            .remote_addr()
53            .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
54
55        let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
56
57        // Extract the first `Get` request from the stream.
58        let Get {
59            up_actor_id,
60            down_actor_id,
61            up_fragment_id,
62            down_fragment_id,
63            up_partial_graph_id,
64            term_id,
65        } = {
66            let req = request_stream
67                .next()
68                .await
69                .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
70            match req.value.unwrap() {
71                Value::Get(get) => get,
72                Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
73            }
74        };
75
76        let receiver = self
77            .stream_mgr
78            .take_receiver(
79                up_partial_graph_id,
80                term_id,
81                (up_actor_id, down_actor_id),
82                up_fragment_id,
83            )
84            .await?;
85
86        // Map the remaining stream to add-permits.
87        let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
88            Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
89            Value::AddPermits(add_permits) => add_permits.value.unwrap(),
90        });
91
92        Ok(Response::new(Self::get_stream_impl(
93            self.metrics.clone(),
94            peer_addr,
95            receiver,
96            add_permits_stream,
97            (up_fragment_id, down_fragment_id),
98        )))
99    }
100}
101
102impl StreamExchangeServiceImpl {
103    pub fn new(stream_mgr: LocalStreamManager, metrics: Arc<StreamExchangeServiceMetrics>) -> Self {
104        Self {
105            stream_mgr,
106            metrics,
107        }
108    }
109
110    #[try_stream(ok = GetStreamResponse, error = Status)]
111    async fn get_stream_impl(
112        metrics: Arc<StreamExchangeServiceMetrics>,
113        peer_addr: SocketAddr,
114        mut receiver: Receiver,
115        add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
116        up_down_fragment_ids: (FragmentId, FragmentId),
117    ) {
118        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
119        let up_fragment_id = up_down_fragment_ids.0.to_string();
120        let down_fragment_id = up_down_fragment_ids.1.to_string();
121
122        let permits = receiver.permits();
123
124        // Select from the permits back from the downstream and the upstream receiver.
125        let select_stream = futures::stream::select(
126            add_permits_stream.map_ok(Either::Left),
127            #[try_stream]
128            async move {
129                while let Some(m) = receiver.recv_raw().await {
130                    yield Either::Right(m);
131                }
132            },
133        );
134        pin_mut!(select_stream);
135
136        let exchange_frag_send_size_metrics = metrics
137            .stream_fragment_exchange_bytes
138            .with_label_values(&[&up_fragment_id, &down_fragment_id]);
139
140        while let Some(r) = select_stream.try_next().await? {
141            match r {
142                Either::Left(permits_to_add) => {
143                    permits.add_permits(permits_to_add);
144                }
145                Either::Right(MessageWithPermits { message, permits }) => {
146                    let message = match message {
147                        DispatcherMessageBatch::Chunk(chunk) => {
148                            DispatcherMessageBatch::Chunk(chunk.compact_vis())
149                        }
150                        msg @ (DispatcherMessageBatch::Watermark(_)
151                        | DispatcherMessageBatch::BarrierBatch(_)) => msg,
152                    };
153                    let proto = message.to_protobuf();
154                    // forward the acquired permit to the downstream
155                    let response = GetStreamResponse {
156                        message: Some(proto),
157                        permits: Some(PbPermits { value: permits }),
158                    };
159                    let bytes = DispatcherMessageBatch::get_encoded_len(&response);
160
161                    yield response;
162
163                    exchange_frag_send_size_metrics.inc_by(bytes as u64);
164                }
165            }
166        }
167    }
168}