risingwave_rpc_client/
compute_client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::catalog::DatabaseId;
21use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
22use risingwave_common::monitor::{EndpointExt, TcpConfig};
23use risingwave_common::util::addr::HostAddr;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
26use risingwave_pb::common::BatchQueryEpoch;
27use risingwave_pb::compute::config_service_client::ConfigServiceClient;
28use risingwave_pb::compute::{
29    ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
30};
31use risingwave_pb::monitor_service::monitor_service_client::MonitorServiceClient;
32use risingwave_pb::monitor_service::{
33    AnalyzeHeapRequest, AnalyzeHeapResponse, GetStreamingStatsRequest, GetStreamingStatsResponse,
34    HeapProfilingRequest, HeapProfilingResponse, ListHeapProfilingRequest,
35    ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest,
36    StackTraceResponse,
37};
38use risingwave_pb::plan_common::ExprContext;
39use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
40use risingwave_pb::task_service::task_service_client::TaskServiceClient;
41use risingwave_pb::task_service::{
42    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
43    FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
44    PbPermits, TaskInfoResponse, permits,
45};
46use tokio::sync::mpsc;
47use tokio_stream::wrappers::UnboundedReceiverStream;
48use tonic::Streaming;
49use tonic::transport::{Channel, Endpoint};
50
51use crate::error::{Result, RpcError};
52use crate::{RpcClient, RpcClientPool};
53
54// TODO: this client has too many roles, e.g.
55// - batch MPP task query execution
56// - batch exchange
57// - streaming exchange
58// - general services specific to compute node, like monitoring, profiling, debugging, etc.
59// We should consider splitting them into different clients.
/// gRPC client bundle for a single compute node.
///
/// All four service clients are built from the same `Channel` (see `with_channel`), and
/// every request method works on a clone of the relevant client.
60#[derive(Clone)]
61pub struct ComputeClient {
    /// Exchange service client; used by `get_data` and `get_stream`.
62    pub exchange_client: ExchangeServiceClient<Channel>,
    /// Task service client; used by `create_task`, `execute`, `cancel` and `fast_insert`.
63    pub task_client: TaskServiceClient<Channel>,
    /// Monitor service client; used by the stack-trace, streaming-stats and profiling calls.
64    pub monitor_client: MonitorServiceClient<Channel>,
    /// Config service client; used by `show_config` and `resize_cache`.
65    pub config_client: ConfigServiceClient<Channel>,
    /// Address this client was created for; also reported in the `get_stream` error log.
66    pub addr: HostAddr,
67}
68
69impl ComputeClient {
70    pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
71        let channel = Endpoint::from_shared(format!("http://{}", &addr))?
72            .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
73            .initial_stream_window_size(STREAM_WINDOW_SIZE)
74            .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
75            .monitored_connect(
76                "grpc-compute-client",
77                TcpConfig {
78                    tcp_nodelay: true,
79                    ..Default::default()
80                },
81            )
82            .await?;
83        Ok(Self::with_channel(addr, channel))
84    }
85
86    pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
87        let exchange_client =
88            ExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
89        let task_client =
90            TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
91        let monitor_client =
92            MonitorServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
93        let config_client = ConfigServiceClient::new(channel);
94        Self {
95            exchange_client,
96            task_client,
97            monitor_client,
98            config_client,
99            addr,
100        }
101    }
102
103    pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
104        Ok(self
105            .exchange_client
106            .to_owned()
107            .get_data(GetDataRequest {
108                task_output_id: Some(output_id),
109            })
110            .await
111            .map_err(RpcError::from_compute_status)?
112            .into_inner())
113    }
114
115    pub async fn get_stream(
116        &self,
117        up_actor_id: u32,
118        down_actor_id: u32,
119        up_fragment_id: u32,
120        down_fragment_id: u32,
121        database_id: DatabaseId,
122        term_id: String,
123    ) -> Result<(
124        Streaming<GetStreamResponse>,
125        mpsc::UnboundedSender<permits::Value>,
126    )> {
127        use risingwave_pb::task_service::get_stream_request::*;
128
129        // Create channel used for the downstream to add back the permits to the upstream.
130        let (permits_tx, permits_rx) = mpsc::unbounded_channel();
131
132        let request_stream = futures::stream::once(futures::future::ready(
133            // `Get` as the first request.
134            GetStreamRequest {
135                value: Some(Value::Get(Get {
136                    up_actor_id,
137                    down_actor_id,
138                    up_fragment_id,
139                    down_fragment_id,
140                    database_id: database_id.database_id,
141                    term_id,
142                })),
143            },
144        ))
145        .chain(
146            // `AddPermits` as the followings.
147            UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
148                value: Some(Value::AddPermits(PbPermits {
149                    value: Some(permits),
150                })),
151            }),
152        );
153
154        let response_stream = self
155            .exchange_client
156            .to_owned()
157            .get_stream(request_stream)
158            .await
159            .inspect_err(|_| {
160                tracing::error!(
161                    "failed to create stream from remote_input {} from actor {} to actor {}",
162                    self.addr,
163                    up_actor_id,
164                    down_actor_id
165                )
166            })
167            .map_err(RpcError::from_compute_status)?
168            .into_inner();
169
170        Ok((response_stream, permits_tx))
171    }
172
173    pub async fn create_task(
174        &self,
175        task_id: TaskId,
176        plan: PlanFragment,
177        epoch: BatchQueryEpoch,
178        expr_context: ExprContext,
179    ) -> Result<Streaming<TaskInfoResponse>> {
180        Ok(self
181            .task_client
182            .to_owned()
183            .create_task(CreateTaskRequest {
184                task_id: Some(task_id),
185                plan: Some(plan),
186                epoch: Some(epoch),
187                tracing_context: TracingContext::from_current_span().to_protobuf(),
188                expr_context: Some(expr_context),
189            })
190            .await
191            .map_err(RpcError::from_compute_status)?
192            .into_inner())
193    }
194
195    pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
196        Ok(self
197            .task_client
198            .to_owned()
199            .execute(req)
200            .await
201            .map_err(RpcError::from_compute_status)?
202            .into_inner())
203    }
204
205    pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
206        Ok(self
207            .task_client
208            .to_owned()
209            .cancel_task(req)
210            .await
211            .map_err(RpcError::from_compute_status)?
212            .into_inner())
213    }
214
215    pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
216        Ok(self
217            .task_client
218            .to_owned()
219            .fast_insert(req)
220            .await
221            .map_err(RpcError::from_compute_status)?
222            .into_inner())
223    }
224
225    pub async fn stack_trace(&self) -> Result<StackTraceResponse> {
226        Ok(self
227            .monitor_client
228            .to_owned()
229            .stack_trace(StackTraceRequest::default())
230            .await
231            .map_err(RpcError::from_compute_status)?
232            .into_inner())
233    }
234
235    pub async fn get_streaming_stats(&self) -> Result<GetStreamingStatsResponse> {
236        Ok(self
237            .monitor_client
238            .to_owned()
239            .get_streaming_stats(GetStreamingStatsRequest::default())
240            .await
241            .map_err(RpcError::from_compute_status)?
242            .into_inner())
243    }
244
245    pub async fn profile(&self, sleep_s: u64) -> Result<ProfilingResponse> {
246        Ok(self
247            .monitor_client
248            .to_owned()
249            .profiling(ProfilingRequest { sleep_s })
250            .await
251            .map_err(RpcError::from_compute_status)?
252            .into_inner())
253    }
254
255    pub async fn heap_profile(&self, dir: String) -> Result<HeapProfilingResponse> {
256        Ok(self
257            .monitor_client
258            .to_owned()
259            .heap_profiling(HeapProfilingRequest { dir })
260            .await
261            .map_err(RpcError::from_compute_status)?
262            .into_inner())
263    }
264
265    pub async fn list_heap_profile(&self) -> Result<ListHeapProfilingResponse> {
266        Ok(self
267            .monitor_client
268            .to_owned()
269            .list_heap_profiling(ListHeapProfilingRequest {})
270            .await
271            .map_err(RpcError::from_compute_status)?
272            .into_inner())
273    }
274
275    pub async fn analyze_heap(&self, path: String) -> Result<AnalyzeHeapResponse> {
276        Ok(self
277            .monitor_client
278            .to_owned()
279            .analyze_heap(AnalyzeHeapRequest { path })
280            .await
281            .map_err(RpcError::from_compute_status)?
282            .into_inner())
283    }
284
285    pub async fn show_config(&self) -> Result<ShowConfigResponse> {
286        Ok(self
287            .config_client
288            .to_owned()
289            .show_config(ShowConfigRequest {})
290            .await
291            .map_err(RpcError::from_compute_status)?
292            .into_inner())
293    }
294
295    pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
296        Ok(self
297            .config_client
298            .to_owned()
299            .resize_cache(request)
300            .await
301            .map_err(RpcError::from_compute_status)?
302            .into_inner())
303    }
304}
305
306#[async_trait]
307impl RpcClient for ComputeClient {
308    async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
309        Self::new(host_addr, opts).await
310    }
311}
312
/// `RpcClientPool` specialized to [`ComputeClient`].
313pub type ComputeClientPool = RpcClientPool<ComputeClient>;
/// Reference-counted handle to a [`ComputeClientPool`].
314pub type ComputeClientPoolRef = Arc<ComputeClientPool>; // TODO: no need for `Arc` since clone is cheap and shared