risingwave_batch/rpc/service/
exchange.rs1use anyhow::Context;
16use risingwave_pb::task_service::GetDataResponse;
17use tonic::Status;
18
19use crate::error::Result;
20
21pub type GetDataResponseResult = std::result::Result<GetDataResponse, Status>;
22
23type ExchangeDataSender = tokio::sync::mpsc::Sender<GetDataResponseResult>;
24
25pub trait ExchangeWriter: Send {
26 async fn write(&mut self, resp: GetDataResponseResult) -> Result<()>;
27}
28
29pub struct GrpcExchangeWriter {
30 sender: ExchangeDataSender,
31 written_chunks: usize,
32}
33
34impl GrpcExchangeWriter {
35 pub fn new(sender: ExchangeDataSender) -> Self {
36 Self {
37 sender,
38 written_chunks: 0,
39 }
40 }
41
42 pub fn written_chunks(&self) -> usize {
43 self.written_chunks
44 }
45}
46
47impl ExchangeWriter for GrpcExchangeWriter {
48 async fn write(&mut self, data: GetDataResponseResult) -> Result<()> {
49 self.sender
50 .send(data)
51 .await
52 .context("failed to write data to ExchangeWriter")?;
53 self.written_chunks += 1;
54
55 Ok(())
56 }
57}
58
#[cfg(test)]
mod tests {
    use risingwave_pb::task_service::GetDataResponse;

    use super::{ExchangeWriter, GrpcExchangeWriter};

    /// A successful write bumps the counter and delivers the item unchanged.
    #[tokio::test]
    async fn test_exchange_writer() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let mut writer = GrpcExchangeWriter::new(tx);
        writer.write(Ok(GetDataResponse::default())).await.unwrap();
        assert_eq!(writer.written_chunks(), 1);

        // The item must actually arrive on the receiving half, not merely be counted.
        let received = rx
            .recv()
            .await
            .expect("channel should yield the item that was written");
        assert_eq!(received.unwrap(), GetDataResponse::default());
    }

    /// Writing after the receiver is dropped fails and does not count the item.
    #[tokio::test]
    async fn test_write_to_closed_channel() {
        let (tx, rx) = tokio::sync::mpsc::channel(10);
        drop(rx);
        let mut writer = GrpcExchangeWriter::new(tx);
        let res = writer.write(Ok(GetDataResponse::default())).await;
        assert!(res.is_err());
        // A failed send must leave the counter unchanged.
        assert_eq!(writer.written_chunks(), 0);
    }
}