risingwave_compute/rpc/service/exchange_service.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_batch::task::BatchManager;
22use risingwave_common::catalog::DatabaseId;
23use risingwave_pb::task_service::exchange_service_server::ExchangeService;
24use risingwave_pb::task_service::{
25    GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits,
26};
27use risingwave_stream::executor::DispatcherMessageBatch;
28use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
29use risingwave_stream::task::LocalStreamManager;
30use thiserror_ext::AsReport;
31use tokio_stream::wrappers::ReceiverStream;
32use tonic::{Request, Response, Status, Streaming};
33
34use crate::rpc::service::exchange_metrics::ExchangeServiceMetrics;
35
36#[derive(Clone)]
37pub struct ExchangeServiceImpl {
38    batch_mgr: Arc<BatchManager>,
39    stream_mgr: LocalStreamManager,
40    metrics: Arc<ExchangeServiceMetrics>,
41}
42
43pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
44pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
45
46#[async_trait::async_trait]
47impl ExchangeService for ExchangeServiceImpl {
48    type GetDataStream = BatchDataStream;
49    type GetStreamStream = StreamDataStream;
50
51    #[cfg_attr(coverage, coverage(off))]
52    async fn get_data(
53        &self,
54        request: Request<GetDataRequest>,
55    ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
56        let peer_addr = request
57            .remote_addr()
58            .ok_or_else(|| Status::unavailable("connection unestablished"))?;
59        let pb_task_output_id = request
60            .into_inner()
61            .task_output_id
62            .expect("Failed to get task output id.");
63        let (tx, rx) =
64            tokio::sync::mpsc::channel(self.batch_mgr.config().developer.receiver_channel_size);
65        if let Err(e) = self.batch_mgr.get_data(tx, peer_addr, &pb_task_output_id) {
66            error!(
67                %peer_addr,
68                error = %e.as_report(),
69                "Failed to serve exchange RPC"
70            );
71            return Err(e.into());
72        }
73
74        Ok(Response::new(ReceiverStream::new(rx)))
75    }
76
77    async fn get_stream(
78        &self,
79        request: Request<Streaming<GetStreamRequest>>,
80    ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
81        use risingwave_pb::task_service::get_stream_request::*;
82
83        let peer_addr = request
84            .remote_addr()
85            .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
86
87        let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
88
89        // Extract the first `Get` request from the stream.
90        let Get {
91            up_actor_id,
92            down_actor_id,
93            up_fragment_id,
94            down_fragment_id,
95            database_id,
96            term_id,
97        } = {
98            let req = request_stream
99                .next()
100                .await
101                .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
102            match req.value.unwrap() {
103                Value::Get(get) => get,
104                Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
105            }
106        };
107
108        let receiver = self
109            .stream_mgr
110            .take_receiver(
111                DatabaseId::new(database_id),
112                term_id,
113                (up_actor_id, down_actor_id),
114            )
115            .await?;
116
117        // Map the remaining stream to add-permits.
118        let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
119            Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
120            Value::AddPermits(add_permits) => add_permits.value.unwrap(),
121        });
122
123        Ok(Response::new(Self::get_stream_impl(
124            self.metrics.clone(),
125            peer_addr,
126            receiver,
127            add_permits_stream,
128            (up_fragment_id, down_fragment_id),
129        )))
130    }
131}
132
133impl ExchangeServiceImpl {
134    pub fn new(
135        mgr: Arc<BatchManager>,
136        stream_mgr: LocalStreamManager,
137        metrics: Arc<ExchangeServiceMetrics>,
138    ) -> Self {
139        ExchangeServiceImpl {
140            batch_mgr: mgr,
141            stream_mgr,
142            metrics,
143        }
144    }
145
146    #[try_stream(ok = GetStreamResponse, error = Status)]
147    async fn get_stream_impl(
148        metrics: Arc<ExchangeServiceMetrics>,
149        peer_addr: SocketAddr,
150        mut receiver: Receiver,
151        add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
152        up_down_fragment_ids: (u32, u32),
153    ) {
154        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
155        let up_fragment_id = up_down_fragment_ids.0.to_string();
156        let down_fragment_id = up_down_fragment_ids.1.to_string();
157
158        let permits = receiver.permits();
159
160        // Select from the permits back from the downstream and the upstream receiver.
161        let select_stream = futures::stream::select(
162            add_permits_stream.map_ok(Either::Left),
163            #[try_stream]
164            async move {
165                while let Some(m) = receiver.recv_raw().await {
166                    yield Either::Right(m);
167                }
168            },
169        );
170        pin_mut!(select_stream);
171
172        while let Some(r) = select_stream.try_next().await? {
173            match r {
174                Either::Left(permits_to_add) => {
175                    permits.add_permits(permits_to_add);
176                }
177                Either::Right(MessageWithPermits { message, permits }) => {
178                    let proto = message.to_protobuf();
179                    // forward the acquired permit to the downstream
180                    let response = GetStreamResponse {
181                        message: Some(proto),
182                        permits: Some(PbPermits { value: permits }),
183                    };
184                    let bytes = DispatcherMessageBatch::get_encoded_len(&response);
185
186                    yield response;
187
188                    metrics
189                        .stream_fragment_exchange_bytes
190                        .with_label_values(&[&up_fragment_id, &down_fragment_id])
191                        .inc_by(bytes as u64);
192                }
193            }
194        }
195    }
196}