risingwave_compute/rpc/service/exchange_service.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::sync::Arc;
17
18use either::Either;
19use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use risingwave_batch::task::BatchManager;
22use risingwave_pb::task_service::exchange_service_server::ExchangeService;
23use risingwave_pb::task_service::{
24    GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits,
25};
26use risingwave_stream::executor::DispatcherMessageBatch;
27use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
28use risingwave_stream::task::LocalStreamManager;
29use thiserror_ext::AsReport;
30use tokio_stream::wrappers::ReceiverStream;
31use tonic::{Request, Response, Status, Streaming};
32
33use crate::rpc::service::exchange_metrics::ExchangeServiceMetrics;
34
35#[derive(Clone)]
36pub struct ExchangeServiceImpl {
37    batch_mgr: Arc<BatchManager>,
38    stream_mgr: LocalStreamManager,
39    metrics: Arc<ExchangeServiceMetrics>,
40}
41
42pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
43pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
44
45#[async_trait::async_trait]
46impl ExchangeService for ExchangeServiceImpl {
47    type GetDataStream = BatchDataStream;
48    type GetStreamStream = StreamDataStream;
49
50    async fn get_data(
51        &self,
52        request: Request<GetDataRequest>,
53    ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
54        let peer_addr = request
55            .remote_addr()
56            .ok_or_else(|| Status::unavailable("connection unestablished"))?;
57        let pb_task_output_id = request
58            .into_inner()
59            .task_output_id
60            .expect("Failed to get task output id.");
61        let (tx, rx) =
62            tokio::sync::mpsc::channel(self.batch_mgr.config().developer.receiver_channel_size);
63        if let Err(e) = self.batch_mgr.get_data(tx, peer_addr, &pb_task_output_id) {
64            error!(
65                %peer_addr,
66                error = %e.as_report(),
67                "Failed to serve exchange RPC"
68            );
69            return Err(e.into());
70        }
71
72        Ok(Response::new(ReceiverStream::new(rx)))
73    }
74
75    #[define_opaque(StreamDataStream)]
76    async fn get_stream(
77        &self,
78        request: Request<Streaming<GetStreamRequest>>,
79    ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
80        use risingwave_pb::task_service::get_stream_request::*;
81
82        let peer_addr = request
83            .remote_addr()
84            .ok_or_else(|| Status::unavailable("get_stream connection unestablished"))?;
85
86        let mut request_stream: Streaming<GetStreamRequest> = request.into_inner();
87
88        // Extract the first `Get` request from the stream.
89        let Get {
90            up_actor_id,
91            down_actor_id,
92            up_fragment_id,
93            down_fragment_id,
94            database_id,
95            term_id,
96        } = {
97            let req = request_stream
98                .next()
99                .await
100                .ok_or_else(|| Status::invalid_argument("get_stream request is empty"))??;
101            match req.value.unwrap() {
102                Value::Get(get) => get,
103                Value::AddPermits(_) => unreachable!("the first message must be `Get`"),
104            }
105        };
106
107        let receiver = self
108            .stream_mgr
109            .take_receiver(database_id, term_id, (up_actor_id, down_actor_id))
110            .await?;
111
112        // Map the remaining stream to add-permits.
113        let add_permits_stream = request_stream.map_ok(|req| match req.value.unwrap() {
114            Value::Get(_) => unreachable!("the following messages must be `AddPermits`"),
115            Value::AddPermits(add_permits) => add_permits.value.unwrap(),
116        });
117
118        Ok(Response::new(Self::get_stream_impl(
119            self.metrics.clone(),
120            peer_addr,
121            receiver,
122            add_permits_stream,
123            (up_fragment_id, down_fragment_id),
124        )))
125    }
126}
127
128impl ExchangeServiceImpl {
129    pub fn new(
130        mgr: Arc<BatchManager>,
131        stream_mgr: LocalStreamManager,
132        metrics: Arc<ExchangeServiceMetrics>,
133    ) -> Self {
134        ExchangeServiceImpl {
135            batch_mgr: mgr,
136            stream_mgr,
137            metrics,
138        }
139    }
140
141    #[try_stream(ok = GetStreamResponse, error = Status)]
142    async fn get_stream_impl(
143        metrics: Arc<ExchangeServiceMetrics>,
144        peer_addr: SocketAddr,
145        mut receiver: Receiver,
146        add_permits_stream: impl Stream<Item = std::result::Result<permits::Value, tonic::Status>>,
147        up_down_fragment_ids: (u32, u32),
148    ) {
149        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, "serve stream exchange RPC");
150        let up_fragment_id = up_down_fragment_ids.0.to_string();
151        let down_fragment_id = up_down_fragment_ids.1.to_string();
152
153        let permits = receiver.permits();
154
155        // Select from the permits back from the downstream and the upstream receiver.
156        let select_stream = futures::stream::select(
157            add_permits_stream.map_ok(Either::Left),
158            #[try_stream]
159            async move {
160                while let Some(m) = receiver.recv_raw().await {
161                    yield Either::Right(m);
162                }
163            },
164        );
165        pin_mut!(select_stream);
166
167        while let Some(r) = select_stream.try_next().await? {
168            match r {
169                Either::Left(permits_to_add) => {
170                    permits.add_permits(permits_to_add);
171                }
172                Either::Right(MessageWithPermits { message, permits }) => {
173                    let message = match message {
174                        DispatcherMessageBatch::Chunk(chunk) => {
175                            DispatcherMessageBatch::Chunk(chunk.compact_vis())
176                        }
177                        msg @ (DispatcherMessageBatch::Watermark(_)
178                        | DispatcherMessageBatch::BarrierBatch(_)) => msg,
179                    };
180                    let proto = message.to_protobuf();
181                    // forward the acquired permit to the downstream
182                    let response = GetStreamResponse {
183                        message: Some(proto),
184                        permits: Some(PbPermits { value: permits }),
185                    };
186                    let bytes = DispatcherMessageBatch::get_encoded_len(&response);
187
188                    yield response;
189
190                    metrics
191                        .stream_fragment_exchange_bytes
192                        .with_label_values(&[&up_fragment_id, &down_fragment_id])
193                        .inc_by(bytes as u64);
194                }
195            }
196        }
197    }
198}