1use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::catalog::DatabaseId;
21use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
22use risingwave_common::monitor::{EndpointExt, TcpConfig};
23use risingwave_common::util::addr::HostAddr;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
26use risingwave_pb::compute::config_service_client::ConfigServiceClient;
27use risingwave_pb::compute::{
28 ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
29};
30use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
31use risingwave_pb::monitor_service::{
32 AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
33 HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
34 ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
35 StackTraceResponse,
36};
37use risingwave_pb::plan_common::ExprContext;
38use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
39use risingwave_pb::task_service::task_service_client::TaskServiceClient;
40use risingwave_pb::task_service::{
41 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
42 FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
43 PbPermits, TaskInfoResponse, permits,
44};
45use tokio::sync::mpsc;
46use tokio_stream::wrappers::UnboundedReceiverStream;
47use tonic::Streaming;
48use tonic::transport::{Channel, Endpoint};
49
50use crate::error::{Result, RpcError};
51use crate::{RpcClient, RpcClientPool};
52
/// gRPC client for a single compute node.
///
/// Bundles one typed tonic client per service exposed by the node together
/// with the address the client was created for. When built through
/// [`ComputeClient::with_channel`], all service clients share the same
/// underlying [`Channel`].
#[derive(Clone)]
pub struct ComputeClient {
    /// Client for the exchange service (used by `get_data` and `get_stream`).
    pub exchange_client: ExchangeServiceClient<Channel>,
    /// Client for the task service (used by `create_task`, `execute`, `cancel` and
    /// `fast_insert`).
    pub task_client: TaskServiceClient<Channel>,
    /// Client for the monitor service (stack traces, streaming stats and profiling).
    pub monitor_client: MonitorServiceClient<Channel>,
    /// Client for the config service (used by `show_config` and `resize_cache`).
    pub config_client: ConfigServiceClient<Channel>,
    /// Address of the compute node this client was created for.
    pub addr: HostAddr,
}
67
68impl ComputeClient {
69 pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
70 let channel = Endpoint::from_shared(format!("http://{}", &addr))?
71 .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
72 .initial_stream_window_size(STREAM_WINDOW_SIZE)
73 .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
74 .monitored_connect(
75 "grpc-compute-client",
76 TcpConfig {
77 tcp_nodelay: true,
78 ..Default::default()
79 },
80 )
81 .await?;
82 Ok(Self::with_channel(addr, channel))
83 }
84
85 pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
86 let exchange_client =
87 ExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
88 let task_client =
89 TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
90 let monitor_client =
91 MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
92 let config_client = ConfigServiceClient::new(channel);
93 Self {
94 exchange_client,
95 task_client,
96 monitor_client,
97 config_client,
98 addr,
99 }
100 }
101
102 pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
103 Ok(self
104 .exchange_client
105 .clone()
106 .get_data(GetDataRequest {
107 task_output_id: Some(output_id),
108 })
109 .await
110 .map_err(RpcError::from_compute_status)?
111 .into_inner())
112 }
113
114 pub async fn get_stream(
115 &self,
116 up_actor_id: u32,
117 down_actor_id: u32,
118 up_fragment_id: u32,
119 down_fragment_id: u32,
120 database_id: DatabaseId,
121 term_id: String,
122 ) -> Result<(
123 Streaming<GetStreamResponse>,
124 mpsc::UnboundedSender<permits::Value>,
125 )> {
126 use risingwave_pb::task_service::get_stream_request::*;
127
128 let (permits_tx, permits_rx) = mpsc::unbounded_channel();
130
131 let request_stream = futures::stream::once(futures::future::ready(
132 GetStreamRequest {
134 value: Some(Value::Get(Get {
135 up_actor_id,
136 down_actor_id,
137 up_fragment_id,
138 down_fragment_id,
139 database_id: database_id.database_id,
140 term_id,
141 })),
142 },
143 ))
144 .chain(
145 UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
147 value: Some(Value::AddPermits(PbPermits {
148 value: Some(permits),
149 })),
150 }),
151 );
152
153 let response_stream = self
154 .exchange_client
155 .clone()
156 .get_stream(request_stream)
157 .await
158 .inspect_err(|_| {
159 tracing::error!(
160 "failed to create stream from remote_input {} from actor {} to actor {}",
161 self.addr,
162 up_actor_id,
163 down_actor_id
164 )
165 })
166 .map_err(RpcError::from_compute_status)?
167 .into_inner();
168
169 Ok((response_stream, permits_tx))
170 }
171
172 pub async fn create_task(
173 &self,
174 task_id: TaskId,
175 plan: PlanFragment,
176 expr_context: ExprContext,
177 ) -> Result<Streaming<TaskInfoResponse>> {
178 Ok(self
179 .task_client
180 .clone()
181 .create_task(CreateTaskRequest {
182 task_id: Some(task_id),
183 plan: Some(plan),
184 tracing_context: TracingContext::from_current_span().to_protobuf(),
185 expr_context: Some(expr_context),
186 })
187 .await
188 .map_err(RpcError::from_compute_status)?
189 .into_inner())
190 }
191
192 pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
193 Ok(self
194 .task_client
195 .clone()
196 .execute(req)
197 .await
198 .map_err(RpcError::from_compute_status)?
199 .into_inner())
200 }
201
202 pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
203 Ok(self
204 .task_client
205 .clone()
206 .cancel_task(req)
207 .await
208 .map_err(RpcError::from_compute_status)?
209 .into_inner())
210 }
211
212 pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
213 Ok(self
214 .task_client
215 .clone()
216 .fast_insert(req)
217 .await
218 .map_err(RpcError::from_compute_status)?
219 .into_inner())
220 }
221
222 pub async fn stack_trace(&self, req: StackTraceRequest) -> Result<StackTraceResponse> {
223 Ok(self
224 .monitor_client
225 .clone()
226 .stack_trace(req)
227 .await
228 .map_err(RpcError::from_compute_status)?
229 .into_inner())
230 }
231
232 pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
233 Ok(self
234 .monitor_client
235 .clone()
236 .get_streaming_stats(GetStreamingStatsRequest::default())
237 .await
238 .map_err(RpcError::from_compute_status)?
239 .into_inner())
240 }
241
242 pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
243 Ok(self
244 .monitor_client
245 .clone()
246 .profiling(ProfilingRequest { sleep_s })
247 .await
248 .map_err(RpcError::from_compute_status)?
249 .into_inner())
250 }
251
252 pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
253 Ok(self
254 .monitor_client
255 .clone()
256 .heap_profiling(HeapProfilingRequest { dir })
257 .await
258 .map_err(RpcError::from_compute_status)?
259 .into_inner())
260 }
261
262 pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
263 Ok(self
264 .monitor_client
265 .clone()
266 .list_heap_profiling(ListHeapProfilingRequest {})
267 .await
268 .map_err(RpcError::from_compute_status)?
269 .into_inner())
270 }
271
272 pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
273 Ok(self
274 .monitor_client
275 .clone()
276 .analyze_heap(AnalyzeHeapRequest { path })
277 .await
278 .map_err(RpcError::from_compute_status)?
279 .into_inner())
280 }
281
282 pub async fn show_config(&self) -> Result<ShowConfigResponse> {
283 Ok(self
284 .config_client
285 .clone()
286 .show_config(ShowConfigRequest {})
287 .await
288 .map_err(RpcError::from_compute_status)?
289 .into_inner())
290 }
291
292 pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
293 Ok(self
294 .config_client
295 .clone()
296 .resize_cache(request)
297 .await
298 .map_err(RpcError::from_compute_status)?
299 .into_inner())
300 }
301}
302
303#[async_trait]
304impl RpcClient for ComputeClient {
305 async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
306 Self::new(host_addr, opts).await
307 }
308}
309
/// An [`RpcClientPool`] whose clients are [`ComputeClient`]s.
pub type ComputeClientPool = RpcClientPool<ComputeClient>;
/// Shared, reference-counted handle to a [`ComputeClientPool`].
pub type ComputeClientPoolRef = Arc<ComputeClientPool>;