risingwave_rpc_client/
compute_client.rs1use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
21use risingwave_common::id::{ActorId, FragmentId};
22use risingwave_common::monitor::{EndpointExt, TcpConfig};
23use risingwave_common::util::addr::HostAddr;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
26use risingwave_pb::compute::config_service_client::ConfigServiceClient;
27use risingwave_pb::compute::{
28 ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
29};
30use risingwave_pb::id::PartialGraphId;
31use risingwave_pb::plan_common::ExprContext;
32use risingwave_pb::task_service::batch_exchange_service_client::BatchExchangeServiceClient;
33use risingwave_pb::task_service::stream_exchange_service_client::StreamExchangeServiceClient;
34use risingwave_pb::task_service::task_service_client::TaskServiceClient;
35use risingwave_pb::task_service::{
36 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
37 FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
38 PbPermits, TaskInfoResponse, permits,
39};
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::UnboundedReceiverStream;
42use tonic::Streaming;
43use tonic::transport::{Channel, Endpoint};
44
45use crate::error::{Result, RpcError};
46use crate::{RpcClient, RpcClientPool};
47
48#[derive(Clone)]
54pub struct ComputeClient {
55 pub batch_exchange_client: BatchExchangeServiceClient<Channel>,
56 pub stream_exchange_client: StreamExchangeServiceClient<Channel>,
57 pub task_client: TaskServiceClient<Channel>,
58 pub config_client: ConfigServiceClient<Channel>,
59 pub addr: HostAddr,
60}
61
62impl ComputeClient {
63 pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
64 let channel = Endpoint::from_shared(format!("http://{}", &addr))?
65 .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
66 .initial_stream_window_size(STREAM_WINDOW_SIZE)
67 .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
68 .monitored_connect(
69 "grpc-compute-client",
70 TcpConfig {
71 tcp_nodelay: true,
72 ..Default::default()
73 },
74 )
75 .await?;
76 Ok(Self::with_channel(addr, channel))
77 }
78
79 pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
80 let batch_exchange_client =
81 BatchExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
82 let stream_exchange_client =
83 StreamExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
84 let task_client =
85 TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
86 let config_client = ConfigServiceClient::new(channel);
87 Self {
88 batch_exchange_client,
89 stream_exchange_client,
90 task_client,
91 config_client,
92 addr,
93 }
94 }
95
96 pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
97 Ok(self
98 .batch_exchange_client
99 .clone()
100 .get_data(GetDataRequest {
101 task_output_id: Some(output_id),
102 })
103 .await
104 .map_err(RpcError::from_compute_status)?
105 .into_inner())
106 }
107
108 pub async fn get_stream(
109 &self,
110 up_actor_id: ActorId,
111 down_actor_id: ActorId,
112 up_fragment_id: FragmentId,
113 down_fragment_id: FragmentId,
114 up_partial_graph_id: PartialGraphId,
115 term_id: String,
116 ) -> Result<(
117 Streaming<GetStreamResponse>,
118 mpsc::UnboundedSender<permits::Value>,
119 )> {
120 use risingwave_pb::task_service::get_stream_request::*;
121
122 let (permits_tx, permits_rx) = mpsc::unbounded_channel();
124
125 let request_stream = futures::stream::once(futures::future::ready(
126 GetStreamRequest {
128 value: Some(Value::Get(Get {
129 up_actor_id,
130 down_actor_id,
131 up_fragment_id,
132 down_fragment_id,
133 up_partial_graph_id,
134 term_id,
135 })),
136 },
137 ))
138 .chain(
139 UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
141 value: Some(Value::AddPermits(PbPermits {
142 value: Some(permits),
143 })),
144 }),
145 );
146
147 let response_stream = self
148 .stream_exchange_client
149 .clone()
150 .get_stream(request_stream)
151 .await
152 .inspect_err(|_| {
153 tracing::error!(
154 "failed to create stream from remote_input {} from actor {} to actor {}",
155 self.addr,
156 up_actor_id,
157 down_actor_id
158 )
159 })
160 .map_err(RpcError::from_compute_status)?
161 .into_inner();
162
163 Ok((response_stream, permits_tx))
164 }
165
166 pub async fn create_task(
167 &self,
168 task_id: TaskId,
169 plan: PlanFragment,
170 expr_context: ExprContext,
171 ) -> Result<Streaming<TaskInfoResponse>> {
172 Ok(self
173 .task_client
174 .clone()
175 .create_task(CreateTaskRequest {
176 task_id: Some(task_id),
177 plan: Some(plan),
178 tracing_context: TracingContext::from_current_span().to_protobuf(),
179 expr_context: Some(expr_context),
180 })
181 .await
182 .map_err(RpcError::from_compute_status)?
183 .into_inner())
184 }
185
186 pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
187 Ok(self
188 .task_client
189 .clone()
190 .execute(req)
191 .await
192 .map_err(RpcError::from_compute_status)?
193 .into_inner())
194 }
195
196 pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
197 Ok(self
198 .task_client
199 .clone()
200 .cancel_task(req)
201 .await
202 .map_err(RpcError::from_compute_status)?
203 .into_inner())
204 }
205
206 pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
207 Ok(self
208 .task_client
209 .clone()
210 .fast_insert(req)
211 .await
212 .map_err(RpcError::from_compute_status)?
213 .into_inner())
214 }
215
216 pub async fn show_config(&self) -> Result<ShowConfigResponse> {
217 Ok(self
218 .config_client
219 .clone()
220 .show_config(ShowConfigRequest {})
221 .await
222 .map_err(RpcError::from_compute_status)?
223 .into_inner())
224 }
225
226 pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
227 Ok(self
228 .config_client
229 .clone()
230 .resize_cache(request)
231 .await
232 .map_err(RpcError::from_compute_status)?
233 .into_inner())
234 }
235}
236
237#[async_trait]
238impl RpcClient for ComputeClient {
239 async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
240 Self::new(host_addr, opts).await
241 }
242}
243
244pub type ComputeClientPool = RpcClientPool<ComputeClient>;
245pub type ComputeClientPoolRef = Arc<ComputeClientPool>;