risingwave_batch/rpc/service/exchange.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Context;
16use risingwave_pb::task_service::GetDataResponse;
17use tonic::Status;
18
19use crate::error::Result;
20
21pub type GetDataResponseResult = std::result::Result<GetDataResponse, Status>;
22
23type ExchangeDataSender = tokio::sync::mpsc::Sender<GetDataResponseResult>;
24
25pub trait ExchangeWriter: Send {
26    async fn write(&mut self, resp: GetDataResponseResult) -> Result<()>;
27}
28
29pub struct GrpcExchangeWriter {
30    sender: ExchangeDataSender,
31    written_chunks: usize,
32}
33
34impl GrpcExchangeWriter {
35    pub fn new(sender: ExchangeDataSender) -> Self {
36        Self {
37            sender,
38            written_chunks: 0,
39        }
40    }
41
42    pub fn written_chunks(&self) -> usize {
43        self.written_chunks
44    }
45}
46
47impl ExchangeWriter for GrpcExchangeWriter {
48    async fn write(&mut self, data: GetDataResponseResult) -> Result<()> {
49        self.sender
50            .send(data)
51            .await
52            .context("failed to write data to ExchangeWriter")?;
53        self.written_chunks += 1;
54
55        Ok(())
56    }
57}
58
#[cfg(test)]
mod tests {
    use risingwave_pb::task_service::GetDataResponse;
    use tonic::Status;

    use super::{ExchangeWriter, GrpcExchangeWriter};

    /// A successful write is counted and the response reaches the receiver.
    #[tokio::test]
    async fn test_exchange_writer() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let mut writer = GrpcExchangeWriter::new(tx);
        writer.write(Ok(GetDataResponse::default())).await.unwrap();
        assert_eq!(writer.written_chunks(), 1);
        assert!(matches!(rx.recv().await, Some(Ok(_))));
    }

    /// Error statuses are forwarded like data and count towards `written_chunks`.
    #[tokio::test]
    async fn test_write_error_status() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let mut writer = GrpcExchangeWriter::new(tx);
        writer.write(Err(Status::internal("boom"))).await.unwrap();
        assert_eq!(writer.written_chunks(), 1);
        assert!(matches!(rx.recv().await, Some(Err(_))));
    }

    /// Writing after the receiver is dropped fails and does not bump the counter.
    #[tokio::test]
    async fn test_write_to_closed_channel() {
        let (tx, rx) = tokio::sync::mpsc::channel(10);
        drop(rx);
        let mut writer = GrpcExchangeWriter::new(tx);
        let res = writer.write(Ok(GetDataResponse::default())).await;
        assert!(res.is_err());
        assert_eq!(writer.written_chunks(), 0);
    }
}