1use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::catalog::DatabaseId;
21use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
22use risingwave_common::monitor::{EndpointExt, TcpConfig};
23use risingwave_common::util::addr::HostAddr;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
26use risingwave_pb::common::BatchQueryEpoch;
27use risingwave_pb::compute::config_service_client::ConfigServiceClient;
28use risingwave_pb::compute::{
29 ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
30};
31use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
32use risingwave_pb::monitor_service::{
33 AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
34 HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
35 ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
36 StackTraceResponse,
37};
38use risingwave_pb::plan_common::ExprContext;
39use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
40use risingwave_pb::task_service::task_service_client::TaskServiceClient;
41use risingwave_pb::task_service::{
42 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
43 FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
44 PbPermits, TaskInfoResponse, permits,
45};
46use tokio::sync::mpsc;
47use tokio_stream::wrappers::UnboundedReceiverStream;
48use tonic::Streaming;
49use tonic::transport::{Channel, Endpoint};
50
51use crate::error::{Result, RpcError};
52use crate::{RpcClient, RpcClientPool};
53
/// gRPC client for a single compute node.
///
/// Bundles one stub per compute-node service. All stubs are built over the same
/// [`Channel`] (see [`ComputeClient::with_channel`]).
#[derive(Clone)]
pub struct ComputeClient {
    /// Exchange service stub, used by `get_data` and `get_stream`.
    pub exchange_client: ExchangeServiceClient<Channel>,
    /// Task service stub, used by `create_task`, `execute`, `cancel` and `fast_insert`.
    pub task_client: TaskServiceClient<Channel>,
    /// Monitor service stub: stack traces, streaming stats, CPU/heap profiling.
    pub monitor_client: MonitorServiceClient<Channel>,
    /// Config service stub, used by `show_config` and `resize_cache`.
    pub config_client: ConfigServiceClient<Channel>,
    /// Address of the compute node this client is connected to.
    pub addr: HostAddr,
}
68
69impl ComputeClient {
70 pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
71 let channel = Endpoint::from_shared(format!("http://{}", &addr))?
72 .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
73 .initial_stream_window_size(STREAM_WINDOW_SIZE)
74 .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
75 .monitored_connect(
76 "grpc-compute-client",
77 TcpConfig {
78 tcp_nodelay: true,
79 ..Default::default()
80 },
81 )
82 .await?;
83 Ok(Self::with_channel(addr, channel))
84 }
85
86 pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
87 let exchange_client =
88 ExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
89 let task_client =
90 TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
91 let monitor_client =
92 MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
93 let config_client = ConfigServiceClient::new(channel);
94 Self {
95 exchange_client,
96 task_client,
97 monitor_client,
98 config_client,
99 addr,
100 }
101 }
102
103 pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
104 Ok(self
105 .exchange_client
106 .to_owned()
107 .get_data(GetDataRequest {
108 task_output_id: Some(output_id),
109 })
110 .await
111 .map_err(RpcError::from_compute_status)?
112 .into_inner())
113 }
114
115 pub async fn get_stream(
116 &self,
117 up_actor_id: u32,
118 down_actor_id: u32,
119 up_fragment_id: u32,
120 down_fragment_id: u32,
121 database_id: DatabaseId,
122 term_id: String,
123 ) -> Result<(
124 Streaming<GetStreamResponse>,
125 mpsc::UnboundedSender<permits::Value>,
126 )> {
127 use risingwave_pb::task_service::get_stream_request::*;
128
129 let (permits_tx, permits_rx) = mpsc::unbounded_channel();
131
132 let request_stream = futures::stream::once(futures::future::ready(
133 GetStreamRequest {
135 value: Some(Value::Get(Get {
136 up_actor_id,
137 down_actor_id,
138 up_fragment_id,
139 down_fragment_id,
140 database_id: database_id.database_id,
141 term_id,
142 })),
143 },
144 ))
145 .chain(
146 UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
148 value: Some(Value::AddPermits(PbPermits {
149 value: Some(permits),
150 })),
151 }),
152 );
153
154 let response_stream = self
155 .exchange_client
156 .to_owned()
157 .get_stream(request_stream)
158 .await
159 .inspect_err(|_| {
160 tracing::error!(
161 "failed to create stream from remote_input {} from actor {} to actor {}",
162 self.addr,
163 up_actor_id,
164 down_actor_id
165 )
166 })
167 .map_err(RpcError::from_compute_status)?
168 .into_inner();
169
170 Ok((response_stream, permits_tx))
171 }
172
173 pub async fn create_task(
174 &self,
175 task_id: TaskId,
176 plan: PlanFragment,
177 epoch: BatchQueryEpoch,
178 expr_context: ExprContext,
179 ) -> Result<Streaming<TaskInfoResponse>> {
180 Ok(self
181 .task_client
182 .to_owned()
183 .create_task(CreateTaskRequest {
184 task_id: Some(task_id),
185 plan: Some(plan),
186 epoch: Some(epoch),
187 tracing_context: TracingContext::from_current_span().to_protobuf(),
188 expr_context: Some(expr_context),
189 })
190 .await
191 .map_err(RpcError::from_compute_status)?
192 .into_inner())
193 }
194
195 pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
196 Ok(self
197 .task_client
198 .to_owned()
199 .execute(req)
200 .await
201 .map_err(RpcError::from_compute_status)?
202 .into_inner())
203 }
204
205 pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
206 Ok(self
207 .task_client
208 .to_owned()
209 .cancel_task(req)
210 .await
211 .map_err(RpcError::from_compute_status)?
212 .into_inner())
213 }
214
215 pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
216 Ok(self
217 .task_client
218 .to_owned()
219 .fast_insert(req)
220 .await
221 .map_err(RpcError::from_compute_status)?
222 .into_inner())
223 }
224
225 pub async fn stack_trace(&self) -> Result<StackTraceResponse> {
226 Ok(self
227 .monitor_client
228 .to_owned()
229 .stack_trace(StackTraceRequest::default())
230 .await
231 .map_err(RpcError::from_compute_status)?
232 .into_inner())
233 }
234
235 pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
236 Ok(self
237 .monitor_client
238 .to_owned()
239 .get_streaming_stats(GetStreamingStatsRequest::default())
240 .await
241 .map_err(RpcError::from_compute_status)?
242 .into_inner())
243 }
244
245 pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
246 Ok(self
247 .monitor_client
248 .to_owned()
249 .profiling(ProfilingRequest { sleep_s })
250 .await
251 .map_err(RpcError::from_compute_status)?
252 .into_inner())
253 }
254
255 pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
256 Ok(self
257 .monitor_client
258 .to_owned()
259 .heap_profiling(HeapProfilingRequest { dir })
260 .await
261 .map_err(RpcError::from_compute_status)?
262 .into_inner())
263 }
264
265 pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
266 Ok(self
267 .monitor_client
268 .to_owned()
269 .list_heap_profiling(ListHeapProfilingRequest {})
270 .await
271 .map_err(RpcError::from_compute_status)?
272 .into_inner())
273 }
274
275 pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
276 Ok(self
277 .monitor_client
278 .to_owned()
279 .analyze_heap(AnalyzeHeapRequest { path })
280 .await
281 .map_err(RpcError::from_compute_status)?
282 .into_inner())
283 }
284
285 pub async fn show_config(&self) -> Result<ShowConfigResponse> {
286 Ok(self
287 .config_client
288 .to_owned()
289 .show_config(ShowConfigRequest {})
290 .await
291 .map_err(RpcError::from_compute_status)?
292 .into_inner())
293 }
294
295 pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
296 Ok(self
297 .config_client
298 .to_owned()
299 .resize_cache(request)
300 .await
301 .map_err(RpcError::from_compute_status)?
302 .into_inner())
303 }
304}
305
306#[async_trait]
307impl RpcClient for ComputeClient {
308 async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
309 Self::new(host_addr, opts).await
310 }
311}
312
313pub type ComputeClientPool = RpcClientPool<ComputeClient>;
314pub type ComputeClientPoolRef = Arc<ComputeClientPool>;