// risingwave_compute/rpc/service/batch_exchange_service.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_batch::task::BatchManager;
18use risingwave_pb::task_service::batch_exchange_service_server::BatchExchangeService;
19use risingwave_pb::task_service::{GetDataRequest, GetDataResponse};
20use thiserror_ext::AsReport;
21use tokio_stream::wrappers::ReceiverStream;
22use tonic::{Request, Response, Status};
23
24pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
25
26#[derive(Clone)]
27pub struct BatchExchangeServiceImpl {
28    batch_mgr: Arc<BatchManager>,
29}
30
31impl BatchExchangeServiceImpl {
32    pub fn new(batch_mgr: Arc<BatchManager>) -> Self {
33        Self { batch_mgr }
34    }
35}
36
37#[async_trait::async_trait]
38impl BatchExchangeService for BatchExchangeServiceImpl {
39    type GetDataStream = BatchDataStream;
40
41    async fn get_data(
42        &self,
43        request: Request<GetDataRequest>,
44    ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
45        let peer_addr = request
46            .remote_addr()
47            .ok_or_else(|| Status::unavailable("connection unestablished"))?;
48        let pb_task_output_id = request
49            .into_inner()
50            .task_output_id
51            .expect("Failed to get task output id.");
52        let (tx, rx) =
53            tokio::sync::mpsc::channel(self.batch_mgr.config().developer.receiver_channel_size);
54        if let Err(e) = self.batch_mgr.get_data(tx, peer_addr, &pb_task_output_id) {
55            error!(
56                %peer_addr,
57                error = %e.as_report(),
58                "Failed to serve exchange RPC"
59            );
60            return Err(e.into());
61        }
62
63        Ok(Response::new(ReceiverStream::new(rx)))
64    }
65}