1use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::catalog::DatabaseId;
21use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
22use risingwave_common::id::{ActorId, FragmentId};
23use risingwave_common::monitor::{EndpointExt, TcpConfig};
24use risingwave_common::util::addr::HostAddr;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
27use risingwave_pb::compute::config_service_client::ConfigServiceClient;
28use risingwave_pb::compute::{
29 ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
30};
31use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
32use risingwave_pb::monitor_service::{
33 AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
34 HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
35 ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
36 StackTraceResponse,
37};
38use risingwave_pb::plan_common::ExprContext;
39use risingwave_pb::task_service::batch_exchange_service_client::BatchExchangeServiceClient;
40use risingwave_pb::task_service::stream_exchange_service_client::StreamExchangeServiceClient;
41use risingwave_pb::task_service::task_service_client::TaskServiceClient;
42use risingwave_pb::task_service::{
43 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
44 FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
45 PbPermits, TaskInfoResponse, permits,
46};
47use tokio::sync::mpsc;
48use tokio_stream::wrappers::UnboundedReceiverStream;
49use tonic::Streaming;
50use tonic::transport::{Channel, Endpoint};
51
52use crate::error::{Result, RpcError};
53use crate::{RpcClient, RpcClientPool};
54
55#[derive(Clone)]
62pub struct ComputeClient {
63 pub batch_exchange_client: BatchExchangeServiceClient<Channel>,
64 pub stream_exchange_client: StreamExchangeServiceClient<Channel>,
65 pub task_client: TaskServiceClient<Channel>,
66 pub monitor_client: MonitorServiceClient<Channel>,
67 pub config_client: ConfigServiceClient<Channel>,
68 pub addr: HostAddr,
69}
70
71impl ComputeClient {
72 pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
73 let channel = Endpoint::from_shared(format!("http://{}", &addr))?
74 .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
75 .initial_stream_window_size(STREAM_WINDOW_SIZE)
76 .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
77 .monitored_connect(
78 "grpc-compute-client",
79 TcpConfig {
80 tcp_nodelay: true,
81 ..Default::default()
82 },
83 )
84 .await?;
85 Ok(Self::with_channel(addr, channel))
86 }
87
88 pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
89 let batch_exchange_client =
90 BatchExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
91 let stream_exchange_client =
92 StreamExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
93 let task_client =
94 TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
95 let monitor_client =
96 MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
97 let config_client = ConfigServiceClient::new(channel);
98 Self {
99 batch_exchange_client,
100 stream_exchange_client,
101 task_client,
102 monitor_client,
103 config_client,
104 addr,
105 }
106 }
107
108 pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
109 Ok(self
110 .batch_exchange_client
111 .clone()
112 .get_data(GetDataRequest {
113 task_output_id: Some(output_id),
114 })
115 .await
116 .map_err(RpcError::from_compute_status)?
117 .into_inner())
118 }
119
120 pub async fn get_stream(
121 &self,
122 up_actor_id: ActorId,
123 down_actor_id: ActorId,
124 up_fragment_id: FragmentId,
125 down_fragment_id: FragmentId,
126 database_id: DatabaseId,
127 term_id: String,
128 ) -> Result<(
129 Streaming<GetStreamResponse>,
130 mpsc::UnboundedSender<permits::Value>,
131 )> {
132 use risingwave_pb::task_service::get_stream_request::*;
133
134 let (permits_tx, permits_rx) = mpsc::unbounded_channel();
136
137 let request_stream = futures::stream::once(futures::future::ready(
138 GetStreamRequest {
140 value: Some(Value::Get(Get {
141 up_actor_id,
142 down_actor_id,
143 up_fragment_id,
144 down_fragment_id,
145 database_id,
146 term_id,
147 })),
148 },
149 ))
150 .chain(
151 UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
153 value: Some(Value::AddPermits(PbPermits {
154 value: Some(permits),
155 })),
156 }),
157 );
158
159 let response_stream = self
160 .stream_exchange_client
161 .clone()
162 .get_stream(request_stream)
163 .await
164 .inspect_err(|_| {
165 tracing::error!(
166 "failed to create stream from remote_input {} from actor {} to actor {}",
167 self.addr,
168 up_actor_id,
169 down_actor_id
170 )
171 })
172 .map_err(RpcError::from_compute_status)?
173 .into_inner();
174
175 Ok((response_stream, permits_tx))
176 }
177
178 pub async fn create_task(
179 &self,
180 task_id: TaskId,
181 plan: PlanFragment,
182 expr_context: ExprContext,
183 ) -> Result<Streaming<TaskInfoResponse>> {
184 Ok(self
185 .task_client
186 .clone()
187 .create_task(CreateTaskRequest {
188 task_id: Some(task_id),
189 plan: Some(plan),
190 tracing_context: TracingContext::from_current_span().to_protobuf(),
191 expr_context: Some(expr_context),
192 })
193 .await
194 .map_err(RpcError::from_compute_status)?
195 .into_inner())
196 }
197
198 pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
199 Ok(self
200 .task_client
201 .clone()
202 .execute(req)
203 .await
204 .map_err(RpcError::from_compute_status)?
205 .into_inner())
206 }
207
208 pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
209 Ok(self
210 .task_client
211 .clone()
212 .cancel_task(req)
213 .await
214 .map_err(RpcError::from_compute_status)?
215 .into_inner())
216 }
217
218 pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
219 Ok(self
220 .task_client
221 .clone()
222 .fast_insert(req)
223 .await
224 .map_err(RpcError::from_compute_status)?
225 .into_inner())
226 }
227
228 pub async fn stack_trace(&self, req: StackTraceRequest) -> Result<StackTraceResponse> {
229 Ok(self
230 .monitor_client
231 .clone()
232 .stack_trace(req)
233 .await
234 .map_err(RpcError::from_compute_status)?
235 .into_inner())
236 }
237
238 pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
239 Ok(self
240 .monitor_client
241 .clone()
242 .get_streaming_stats(GetStreamingStatsRequest::default())
243 .await
244 .map_err(RpcError::from_compute_status)?
245 .into_inner())
246 }
247
248 pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
249 Ok(self
250 .monitor_client
251 .clone()
252 .profiling(ProfilingRequest { sleep_s })
253 .await
254 .map_err(RpcError::from_compute_status)?
255 .into_inner())
256 }
257
258 pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
259 Ok(self
260 .monitor_client
261 .clone()
262 .heap_profiling(HeapProfilingRequest { dir })
263 .await
264 .map_err(RpcError::from_compute_status)?
265 .into_inner())
266 }
267
268 pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
269 Ok(self
270 .monitor_client
271 .clone()
272 .list_heap_profiling(ListHeapProfilingRequest {})
273 .await
274 .map_err(RpcError::from_compute_status)?
275 .into_inner())
276 }
277
278 pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
279 Ok(self
280 .monitor_client
281 .clone()
282 .analyze_heap(AnalyzeHeapRequest { path })
283 .await
284 .map_err(RpcError::from_compute_status)?
285 .into_inner())
286 }
287
288 pub async fn show_config(&self) -> Result<ShowConfigResponse> {
289 Ok(self
290 .config_client
291 .clone()
292 .show_config(ShowConfigRequest {})
293 .await
294 .map_err(RpcError::from_compute_status)?
295 .into_inner())
296 }
297
298 pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
299 Ok(self
300 .config_client
301 .clone()
302 .resize_cache(request)
303 .await
304 .map_err(RpcError::from_compute_status)?
305 .into_inner())
306 }
307}
308
309#[async_trait]
310impl RpcClient for ComputeClient {
311 async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
312 Self::new(host_addr, opts).await
313 }
314}
315
316pub type ComputeClientPool = RpcClientPool<ComputeClient>;
317pub type ComputeClientPoolRef = Arc<ComputeClientPool>;