risingwave_rpc_client/compute_client.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use futures::StreamExt;
use risingwave_common::catalog::DatabaseId;
use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
use risingwave_common::monitor::{EndpointExt, TcpConfig};
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
use risingwave_pb::compute::config_service_client::ConfigServiceClient;
use risingwave_pb::compute::{
    ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
};
use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
use risingwave_pb::monitor_service::{
    AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
    HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
    ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
    StackTraceResponse,
};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
use risingwave_pb::task_service::task_service_client::TaskServiceClient;
use risingwave_pb::task_service::{
    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
    FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
    PbPermits, TaskInfoResponse, permits,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
use tonic::transport::{Channel, Endpoint};

use crate::error::{Result, RpcError};
use crate::{RpcClient, RpcClientPool};

// TODO: this client has too many roles, e.g.
// - batch MPP task query execution
// - batch exchange
// - streaming exchange
// - general services specific to the compute node, like monitoring, profiling, debugging, etc.
// We should consider splitting them into different clients.
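/// RPC client for a single compute node, bundling the exchange, task, monitor and config
/// service stubs over one shared gRPC channel.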
#[derive(Clone)]
pub struct ComputeClient {
    pub exchange_client: ExchangeServiceClient<Channel>,
    pub task_client: TaskServiceClient<Channel>,
    pub monitor_client: MonitorServiceClient<Channel>,
    pub config_client: ConfigServiceClient<Channel>,
    pub addr: HostAddr,
}

impl ComputeClient {
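    /// Connects to the compute node at `addr` and builds a client over the resulting channel,
    /// applying the configured connection/stream window sizes and connect timeout.
    ///
    /// A minimal usage sketch (not compiled; `addr` and `opts` are assumed to be provided by
    /// the caller):
    ///
    /// ```ignore
    /// let client = ComputeClient::new(addr, opts).await?;
    /// let config = client.show_config().await?;
    /// ```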
    pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
        let channel = Endpoint::from_shared(format!("http://{}", &addr))?
            .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
            .initial_stream_window_size(STREAM_WINDOW_SIZE)
            .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
            .monitored_connect(
                "grpc-compute-client",
                TcpConfig {
                    tcp_nodelay: true,
                    ..Default::default()
                },
            )
            .await?;
        Ok(Self::with_channel(addr, channel))
    }

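    /// Builds a client from an existing gRPC `channel`, lifting the decoding message size
    /// limit for the exchange, task and monitor services.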
    pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
        let exchange_client =
            ExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let task_client =
            TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let monitor_client =
            MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
        let config_client = ConfigServiceClient::new(channel);
        Self {
            exchange_client,
            task_client,
            monitor_client,
            config_client,
            addr,
        }
    }

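    /// Opens a server stream that pulls the output data of the batch task output identified by
    /// `output_id`.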
    pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
        Ok(self
            .exchange_client
            .clone()
            .get_data(GetDataRequest {
                task_output_id: Some(output_id),
            })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

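    /// Establishes the streaming exchange from an upstream actor on this compute node to a
    /// local downstream actor.
    ///
    /// The returned stream yields the exchanged messages; the returned sender is used by the
    /// downstream to hand flow-control permits back to the upstream (sent as `AddPermits`
    /// requests, see the body below).
    ///
    /// A minimal usage sketch (not compiled; all identifiers are placeholders supplied by the
    /// caller):
    ///
    /// ```ignore
    /// let (stream, permits_tx) = client
    ///     .get_stream(up_actor_id, down_actor_id, up_fragment_id, down_fragment_id, database_id, term_id)
    ///     .await?;
    /// // After consuming messages, return permits so the upstream may continue sending.
    /// let _ = permits_tx.send(added_permits); // `added_permits: permits::Value`
    /// ```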
    pub async fn get_stream(
        &self,
        up_actor_id: u32,
        down_actor_id: u32,
        up_fragment_id: u32,
        down_fragment_id: u32,
        database_id: DatabaseId,
        term_id: String,
    ) -> Result<(
        Streaming<GetStreamResponse>,
        mpsc::UnboundedSender<permits::Value>,
    )> {
        use risingwave_pb::task_service::get_stream_request::*;

        // Create the channel used by the downstream to return permits to the upstream.
        let (permits_tx, permits_rx) = mpsc::unbounded_channel();

        let request_stream = futures::stream::once(futures::future::ready(
            // `Get` as the first request.
            GetStreamRequest {
                value: Some(Value::Get(Get {
                    up_actor_id,
                    down_actor_id,
                    up_fragment_id,
                    down_fragment_id,
                    database_id: database_id.database_id,
                    term_id,
                })),
            },
        ))
        .chain(
            // `AddPermits` as the subsequent requests.
            UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
                value: Some(Value::AddPermits(PbPermits {
                    value: Some(permits),
                })),
            }),
        );

        let response_stream = self
            .exchange_client
            .clone()
            .get_stream(request_stream)
            .await
            .inspect_err(|_| {
                tracing::error!(
                    "failed to create stream from remote_input {} from actor {} to actor {}",
                    self.addr,
                    up_actor_id,
                    down_actor_id
                )
            })
            .map_err(RpcError::from_compute_status)?
            .into_inner();

        Ok((response_stream, permits_tx))
    }

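    /// Creates a batch task on the compute node from the given plan fragment and returns a
    /// stream of task status updates.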
    pub async fn create_task(
        &self,
        task_id: TaskId,
        plan: PlanFragment,
        expr_context: ExprContext,
    ) -> Result<Streaming<TaskInfoResponse>> {
        Ok(self
            .task_client
            .clone()
            .create_task(CreateTaskRequest {
                task_id: Some(task_id),
                plan: Some(plan),
                tracing_context: TracingContext::from_current_span().to_protobuf(),
                expr_context: Some(expr_context),
            })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

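    /// Executes a batch request on the compute node and streams back the resulting data.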
    pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
        Ok(self
            .task_client
            .clone()
            .execute(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

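    /// Cancels a running batch task on the compute node.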
    pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
        Ok(self
            .task_client
            .clone()
            .cancel_task(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

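    /// Sends a fast-insert request to the compute node's task service.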
    pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
        Ok(self
            .task_client
            .clone()
            .fast_insert(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

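    /// Collects stack traces from the compute node's monitor service.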
    pub async fn stack_trace(&self, req: StackTraceRequest) -> Result<StackTraceResponse> {
        Ok(self
            .monitor_client
            .clone()
            .stack_trace(req)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

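    /// Fetches streaming runtime statistics from the compute node's monitor service.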
    pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
        Ok(self
            .monitor_client
            .clone()
            .get_streaming_stats(GetStreamingStatsRequest::default())
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

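    /// Triggers CPU profiling on the compute node; `sleep_s` is presumably the sampling
    /// duration in seconds.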
    pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .profiling(ProfilingRequest { sleep_s })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

    pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .heap_profiling(HeapProfilingRequest { dir })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

    pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
        Ok(self
            .monitor_client
            .clone()
            .list_heap_profiling(ListHeapProfilingRequest {})
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

    pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
        Ok(self
            .monitor_client
            .clone()
            .analyze_heap(AnalyzeHeapRequest { path })
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

    pub async fn show_config(&self) -> Result<ShowConfigResponse> {
        Ok(self
            .config_client
            .clone()
            .show_config(ShowConfigRequest {})
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }

    pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
        Ok(self
            .config_client
            .clone()
            .resize_cache(request)
            .await
            .map_err(RpcError::from_compute_status)?
            .into_inner())
    }
}

#[async_trait]
impl RpcClient for ComputeClient {
    async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
        Self::new(host_addr, opts).await
    }
}

pub type ComputeClientPool = RpcClientPool<ComputeClient>;
pub type ComputeClientPoolRef = Arc<ComputeClientPool>; // TODO: no need for `Arc` since cloning the pool is cheap and clones share state