risingwave_batch/execution/
grpc_exchange.rs1use std::fmt::{Debug, Formatter};
16
17use futures::StreamExt;
18use risingwave_common::array::DataChunk;
19use risingwave_expr::expr_context::capture_expr_context;
20use risingwave_pb::batch_plan::TaskOutputId;
21use risingwave_pb::batch_plan::exchange_source::LocalExecutePlan::{self, Plan};
22use risingwave_pb::task_service::{ExecuteRequest, GetDataResponse};
23use risingwave_rpc_client::ComputeClient;
24use risingwave_rpc_client::error::RpcError;
25use tonic::Streaming;
26
27use crate::error::Result;
28use crate::exchange_source::ExchangeSource;
29use crate::task::TaskId;
30
31pub struct GrpcExchangeSource {
33 stream: Streaming<GetDataResponse>,
34
35 task_output_id: TaskOutputId,
36}
37
38impl GrpcExchangeSource {
39 pub async fn create(
40 client: ComputeClient,
41 task_output_id: TaskOutputId,
42 local_execute_plan: Option<LocalExecutePlan>,
43 ) -> Result<Self> {
44 let task_id = task_output_id.get_task_id()?.clone();
45 let stream = match local_execute_plan {
46 Some(local_execute_plan) => {
49 let plan = try_match_expand!(local_execute_plan, Plan)?;
50 let execute_request = ExecuteRequest {
51 task_id: Some(task_id),
52 plan: plan.plan,
53 epoch: plan.epoch,
54 tracing_context: plan.tracing_context,
55 expr_context: Some(capture_expr_context()?),
56 };
57 client.execute(execute_request).await?
58 }
59 None => client.get_data(task_output_id.clone()).await?,
60 };
61 let source = Self {
62 stream,
63 task_output_id,
64 };
65 Ok(source)
66 }
67}
68
69impl Debug for GrpcExchangeSource {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("GrpcExchangeSource")
72 .field("task_output_id", &self.task_output_id)
73 .finish()
74 }
75}
76
77impl ExchangeSource for GrpcExchangeSource {
78 async fn take_data(&mut self) -> Result<Option<DataChunk>> {
79 let res = match self.stream.next().await {
80 None => {
81 return Ok(None);
82 }
83 Some(r) => r,
84 };
85 let task_data = res.map_err(RpcError::from_batch_status)?;
86 let data = DataChunk::from_protobuf(task_data.get_record_batch()?)?.compact();
87 trace!(
88 "Receiver taskOutput = {:?}, data = {:?}",
89 self.task_output_id, data
90 );
91
92 Ok(Some(data))
93 }
94
95 fn get_task_id(&self) -> TaskId {
96 TaskId::from(self.task_output_id.get_task_id().unwrap())
97 }
98}