// risingwave_rpc_client/compute_client.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use futures::StreamExt;
use risingwave_common::catalog::DatabaseId;
use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
use risingwave_common::id::{ActorId, FragmentId};
use risingwave_common::monitor::{EndpointExt, TcpConfig};
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
use risingwave_pb::compute::config_service_client::ConfigServiceClient;
use risingwave_pb::compute::{
    ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
};
use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
use risingwave_pb::monitor_service::{
    AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
    HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
    ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
    StackTraceResponse,
};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::batch_exchange_service_client::BatchExchangeServiceClient;
use risingwave_pb::task_service::stream_exchange_service_client::StreamExchangeServiceClient;
use risingwave_pb::task_service::task_service_client::TaskServiceClient;
use risingwave_pb::task_service::{
    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
    FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
    PbPermits, TaskInfoResponse, permits,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
use tonic::transport::{Channel, Endpoint};

use crate::error::{Result, RpcError};
use crate::{RpcClient, RpcClientPool};

// TODO: this client has too many roles, e.g.
// - batch MPP task query execution
// - batch exchange
// - streaming exchange
// - general services specific to the compute node, like monitoring, profiling, debugging, etc.
// We should consider splitting them into different clients.
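/// Client for the gRPC services exposed by a single compute node: batch exchange, stream
/// exchange, task, monitor, and config services, all multiplexed over one shared [`Channel`].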
#[derive(Clone)]
pub struct ComputeClient {
    pub batch_exchange_client: BatchExchangeServiceClient<Channel>,
    pub stream_exchange_client: StreamExchangeServiceClient<Channel>,
    pub task_client: TaskServiceClient<Channel>,
    pub monitor_client: MonitorServiceClient<Channel>,
    pub config_client: ConfigServiceClient<Channel>,
    pub addr: HostAddr,
}

impl ComputeClient {
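    /// Connects to the compute node at `addr`, establishing a monitored gRPC channel with the
    /// configured window sizes and connect timeout.
    ///
    /// A minimal usage sketch (the address and config binding are illustrative, not taken from
    /// this crate's tests):
    ///
    /// ```ignore
    /// let addr: HostAddr = "127.0.0.1:5688".parse()?;
    /// let client = ComputeClient::new(addr, &rpc_client_config).await?;
    /// let config = client.show_config().await?;
    /// ```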
    pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
        let channel = Endpoint::from_shared(format!("http://{}", &addr))?
            .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
            .initial_stream_window_size(STREAM_WINDOW_SIZE)
            .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
            .monitored_connect(
                "grpc-compute-client",
                TcpConfig {
                    tcp_nodelay: true,
                    ..Default::default()
                },
            )
            .await?;
        Ok(Self::with_channel(addr, channel))
    }

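    /// Builds the per-service clients from an already established `channel`, lifting the gRPC
    /// decoding message size limit for the data-intensive services.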
    pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
        let batch_exchange_client =
            BatchExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let stream_exchange_client =
            StreamExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let task_client =
            TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let monitor_client =
            MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let config_client = ConfigServiceClient::new(channel);
        Self {
            batch_exchange_client,
            stream_exchange_client,
            task_client,
            monitor_client,
            config_client,
            addr,
        }
    }

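    /// Fetches the output of the batch task identified by `output_id` as a stream of data
    /// responses (batch exchange).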
    pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
        Ok(self
            .batch_exchange_client
            .clone()
            .get_data(GetDataRequest {
                task_output_id: Some(output_id),
            })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

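    /// Establishes a streaming exchange from the upstream actor to the downstream actor.
    ///
    /// Returns the response stream carrying the exchanged data, together with a sender through
    /// which the downstream returns consumed permits to the upstream for flow control: the first
    /// request on the wire is `Get`, and every value sent on the returned sender is forwarded as
    /// an `AddPermits` request.
    ///
    /// A usage sketch (the identifiers are illustrative):
    ///
    /// ```ignore
    /// let (mut stream, permits_tx) = client
    ///     .get_stream(up_actor, down_actor, up_fragment, down_fragment, database_id, term_id)
    ///     .await?;
    /// while let Some(resp) = stream.next().await {
    ///     let msg = resp?;
    ///     // ... consume the message, then add back the permits it occupied ...
    ///     permits_tx.send(permits)?;
    /// }
    /// ```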
    pub async fn get_stream(
        &self,
        up_actor_id: ActorId,
        down_actor_id: ActorId,
        up_fragment_id: FragmentId,
        down_fragment_id: FragmentId,
        database_id: DatabaseId,
        term_id: String,
    ) -> Result<(
        Streaming<GetStreamResponse>,
        mpsc::UnboundedSender<permits::Value>,
    )> {
        use risingwave_pb::task_service::get_stream_request::*;

        // Create a channel for the downstream to add back the permits to the upstream.
        let (permits_tx, permits_rx) = mpsc::unbounded_channel();

        let request_stream = futures::stream::once(futures::future::ready(
            // `Get` as the first request.
            GetStreamRequest {
                value: Some(Value::Get(Get {
                    up_actor_id,
                    down_actor_id,
                    up_fragment_id,
                    down_fragment_id,
                    database_id,
                    term_id,
                })),
            },
        ))
        .chain(
            // `AddPermits` as the following requests.
            UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
                value: Some(Value::AddPermits(PbPermits {
                    value: Some(permits),
                })),
            }),
        );

        let response_stream = self
            .stream_exchange_client
            .clone()
            .get_stream(request_stream)
            .await
            .inspect_err(|_| {
                tracing::error!(
                    "failed to create stream from remote_input {} from actor {} to actor {}",
                    self.addr,
                    up_actor_id,
                    down_actor_id
                )
            })
            .map_err(RpcError::from_compute_status)?
            .into_inner();

        Ok((response_stream, permits_tx))
    }

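    /// Creates a batch task on the compute node from the given plan fragment and returns a stream
    /// of task status updates. The current tracing context is propagated along with the request.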
    pub async fn create_task(
        &self,
        task_id: TaskId,
        plan: PlanFragment,
        expr_context: ExprContext,
    ) -> Result<Streaming<TaskInfoResponse>> {
        Ok(self
            .task_client
            .clone()
            .create_task(CreateTaskRequest {
                task_id: Some(task_id),
                plan: Some(plan),
                tracing_context: TracingContext::from_current_span().to_protobuf(),
                expr_context: Some(expr_context),
            })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

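    /// Executes a batch request on the compute node and streams back the resulting data.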
    pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
        Ok(self
            .task_client
            .clone()
            .execute(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

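    /// Cancels a running batch task on the compute node.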
    pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
        Ok(self
            .task_client
            .clone()
            .cancel_task(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

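    /// Issues a fast-insert request to the compute node.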
    pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
        Ok(self
            .task_client
            .clone()
            .fast_insert(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

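    /// Fetches stack traces from the compute node for debugging purposes.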
    pub async fn stack_trace(&self, req: StackTraceRequest) -> Result<StackTraceResponse> {
        Ok(self
            .monitor_client
            .clone()
            .stack_trace(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

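    /// Fetches streaming statistics from the compute node.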
    pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
        Ok(self
            .monitor_client
            .clone()
            .get_streaming_stats(GetStreamingStatsRequest::default())
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

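    /// Triggers CPU profiling on the compute node for `sleep_s` seconds and returns the result.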
    pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .profiling(ProfilingRequest { sleep_s })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

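    /// Triggers a heap profiling dump on the compute node, written under `dir`.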
    pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .heap_profiling(HeapProfilingRequest { dir })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

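    /// Lists the heap profiling dumps available on the compute node.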
    pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .list_heap_profiling(ListHeapProfilingRequest {})
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

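    /// Analyzes the heap profiling dump at `path` on the compute node and returns the result.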
    pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
        Ok(self
            .monitor_client
            .clone()
            .analyze_heap(AnalyzeHeapRequest { path })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

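    /// Returns the current configuration of the compute node.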
    pub async fn show_config(&self) -> Result<ShowConfigResponse> {
        Ok(self
            .config_client
            .clone()
            .show_config(ShowConfigRequest {})
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

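    /// Requests the compute node to resize its in-memory caches.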
    pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
        Ok(self
            .config_client
            .clone()
            .resize_cache(request)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }
}

#[async_trait]
impl RpcClient for ComputeClient {
    async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
        Self::new(host_addr, opts).await
    }
}

pub type ComputeClientPool = RpcClientPool<ComputeClient>;
pub type ComputeClientPoolRef = Arc<ComputeClientPool>; // TODO: no need for `Arc` since the pool is cheap to clone and already shared