1use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::catalog::DatabaseId;
21use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
22use risingwave_common::id::{ActorId, FragmentId};
23use risingwave_common::monitor::{EndpointExt, TcpConfig};
24use risingwave_common::util::addr::HostAddr;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
27use risingwave_pb::compute::config_service_client::ConfigServiceClient;
28use risingwave_pb::compute::{
29 ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
30};
31use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
32use risingwave_pb::monitor_service::{
33 AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
34 HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
35 ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
36 StackTraceResponse,
37};
38use risingwave_pb::plan_common::ExprContext;
39use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
40use risingwave_pb::task_service::task_service_client::TaskServiceClient;
41use risingwave_pb::task_service::{
42 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
43 FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
44 PbPermits, TaskInfoResponse, permits,
45};
46use tokio::sync::mpsc;
47use tokio_stream::wrappers::UnboundedReceiverStream;
48use tonic::Streaming;
49use tonic::transport::{Channel, Endpoint};
50
51use crate::error::{Result, RpcError};
52use crate::{RpcClient, RpcClientPool};
53
54#[derive(Clone)]
61pub struct ComputeClient {
62 pub exchange_client: ExchangeServiceClient<Channel>,
63 pub task_client: TaskServiceClient<Channel>,
64 pub monitor_client: MonitorServiceClient<Channel>,
65 pub config_client: ConfigServiceClient<Channel>,
66 pub addr: HostAddr,
67}
68
69impl ComputeClient {
70 pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
71 let channel = Endpoint::from_shared(format!("http://{}", &addr))?
72 .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
73 .initial_stream_window_size(STREAM_WINDOW_SIZE)
74 .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
75 .monitored_connect(
76 "grpc-compute-client",
77 TcpConfig {
78 tcp_nodelay: true,
79 ..Default::default()
80 },
81 )
82 .await?;
83 Ok(Self::with_channel(addr, channel))
84 }
85
86 pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
87 let exchange_client =
88 ExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
89 let task_client =
90 TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
91 let monitor_client =
92 MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
93 let config_client = ConfigServiceClient::new(channel);
94 Self {
95 exchange_client,
96 task_client,
97 monitor_client,
98 config_client,
99 addr,
100 }
101 }
102
103 pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
104 Ok(self
105 .exchange_client
106 .clone()
107 .get_data(GetDataRequest {
108 task_output_id: Some(output_id),
109 })
110 .await
111 .map_err(RpcError::from_compute_status)?
112 .into_inner())
113 }
114
115 pub async fn get_stream(
116 &self,
117 up_actor_id: ActorId,
118 down_actor_id: ActorId,
119 up_fragment_id: FragmentId,
120 down_fragment_id: FragmentId,
121 database_id: DatabaseId,
122 term_id: String,
123 ) -> Result<(
124 Streaming<GetStreamResponse>,
125 mpsc::UnboundedSender<permits::Value>,
126 )> {
127 use risingwave_pb::task_service::get_stream_request::*;
128
129 let (permits_tx, permits_rx) = mpsc::unbounded_channel();
131
132 let request_stream = futures::stream::once(futures::future::ready(
133 GetStreamRequest {
135 value: Some(Value::Get(Get {
136 up_actor_id,
137 down_actor_id,
138 up_fragment_id,
139 down_fragment_id,
140 database_id,
141 term_id,
142 })),
143 },
144 ))
145 .chain(
146 UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
148 value: Some(Value::AddPermits(PbPermits {
149 value: Some(permits),
150 })),
151 }),
152 );
153
154 let response_stream = self
155 .exchange_client
156 .clone()
157 .get_stream(request_stream)
158 .await
159 .inspect_err(|_| {
160 tracing::error!(
161 "failed to create stream from remote_input {} from actor {} to actor {}",
162 self.addr,
163 up_actor_id,
164 down_actor_id
165 )
166 })
167 .map_err(RpcError::from_compute_status)?
168 .into_inner();
169
170 Ok((response_stream, permits_tx))
171 }
172
173 pub async fn create_task(
174 &self,
175 task_id: TaskId,
176 plan: PlanFragment,
177 expr_context: ExprContext,
178 ) -> Result<Streaming<TaskInfoResponse>> {
179 Ok(self
180 .task_client
181 .clone()
182 .create_task(CreateTaskRequest {
183 task_id: Some(task_id),
184 plan: Some(plan),
185 tracing_context: TracingContext::from_current_span().to_protobuf(),
186 expr_context: Some(expr_context),
187 })
188 .await
189 .map_err(RpcError::from_compute_status)?
190 .into_inner())
191 }
192
193 pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
194 Ok(self
195 .task_client
196 .clone()
197 .execute(req)
198 .await
199 .map_err(RpcError::from_compute_status)?
200 .into_inner())
201 }
202
203 pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
204 Ok(self
205 .task_client
206 .clone()
207 .cancel_task(req)
208 .await
209 .map_err(RpcError::from_compute_status)?
210 .into_inner())
211 }
212
213 pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
214 Ok(self
215 .task_client
216 .clone()
217 .fast_insert(req)
218 .await
219 .map_err(RpcError::from_compute_status)?
220 .into_inner())
221 }
222
223 pub async fn stack_trace(&self, req: StackTraceRequest) -> Result<StackTraceResponse> {
224 Ok(self
225 .monitor_client
226 .clone()
227 .stack_trace(req)
228 .await
229 .map_err(RpcError::from_compute_status)?
230 .into_inner())
231 }
232
233 pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
234 Ok(self
235 .monitor_client
236 .clone()
237 .get_streaming_stats(GetStreamingStatsRequest::default())
238 .await
239 .map_err(RpcError::from_compute_status)?
240 .into_inner())
241 }
242
243 pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
244 Ok(self
245 .monitor_client
246 .clone()
247 .profiling(ProfilingRequest { sleep_s })
248 .await
249 .map_err(RpcError::from_compute_status)?
250 .into_inner())
251 }
252
253 pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
254 Ok(self
255 .monitor_client
256 .clone()
257 .heap_profiling(HeapProfilingRequest { dir })
258 .await
259 .map_err(RpcError::from_compute_status)?
260 .into_inner())
261 }
262
263 pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
264 Ok(self
265 .monitor_client
266 .clone()
267 .list_heap_profiling(ListHeapProfilingRequest {})
268 .await
269 .map_err(RpcError::from_compute_status)?
270 .into_inner())
271 }
272
273 pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
274 Ok(self
275 .monitor_client
276 .clone()
277 .analyze_heap(AnalyzeHeapRequest { path })
278 .await
279 .map_err(RpcError::from_compute_status)?
280 .into_inner())
281 }
282
283 pub async fn show_config(&self) -> Result<ShowConfigResponse> {
284 Ok(self
285 .config_client
286 .clone()
287 .show_config(ShowConfigRequest {})
288 .await
289 .map_err(RpcError::from_compute_status)?
290 .into_inner())
291 }
292
293 pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
294 Ok(self
295 .config_client
296 .clone()
297 .resize_cache(request)
298 .await
299 .map_err(RpcError::from_compute_status)?
300 .into_inner())
301 }
302}
303
304#[async_trait]
305impl RpcClient for ComputeClient {
306 async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
307 Self::new(host_addr, opts).await
308 }
309}
310
311pub type ComputeClientPool = RpcClientPool<ComputeClient>;
312pub type ComputeClientPoolRef = Arc<ComputeClientPool>;