risingwave_rpc_client/
compute_client.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use futures::StreamExt;
20use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, RpcClientConfig, STREAM_WINDOW_SIZE};
21use risingwave_common::id::{ActorId, FragmentId};
22use risingwave_common::monitor::{EndpointExt, TcpConfig};
23use risingwave_common::util::addr::HostAddr;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PlanFragment, TaskId, TaskOutputId};
26use risingwave_pb::compute::config_service_client::ConfigServiceClient;
27use risingwave_pb::compute::{
28    ResizeCacheRequest, ResizeCacheResponse, ShowConfigRequest, ShowConfigResponse,
29};
30use risingwave_pb::id::PartialGraphId;
31use risingwave_pb::plan_common::ExprContext;
32use risingwave_pb::task_service::batch_exchange_service_client::BatchExchangeServiceClient;
33use risingwave_pb::task_service::stream_exchange_service_client::StreamExchangeServiceClient;
34use risingwave_pb::task_service::task_service_client::TaskServiceClient;
35use risingwave_pb::task_service::{
36    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
37    FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
38    PbPermits, TaskInfoResponse, permits,
39};
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::UnboundedReceiverStream;
42use tonic::Streaming;
43use tonic::transport::{Channel, Endpoint};
44
45use crate::error::{Result, RpcError};
46use crate::{RpcClient, RpcClientPool};
47
48// TODO: this client has too many roles, e.g.
49// - batch MPP task query execution
50// - batch exchange
51// - streaming exchange
52// We should consider splitting them into different clients.
53#[derive(Clone)]
54pub struct ComputeClient {
55    pub batch_exchange_client: BatchExchangeServiceClient<Channel>,
56    pub stream_exchange_client: StreamExchangeServiceClient<Channel>,
57    pub task_client: TaskServiceClient<Channel>,
58    pub config_client: ConfigServiceClient<Channel>,
59    pub addr: HostAddr,
60}
61
62impl ComputeClient {
63    pub async fn new(addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
64        let channel = Endpoint::from_shared(format!("http://{}", &addr))?
65            .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
66            .initial_stream_window_size(STREAM_WINDOW_SIZE)
67            .connect_timeout(Duration::from_secs(opts.connect_timeout_secs))
68            .monitored_connect(
69                "grpc-compute-client",
70                TcpConfig {
71                    tcp_nodelay: true,
72                    ..Default::default()
73                },
74            )
75            .await?;
76        Ok(Self::with_channel(addr, channel))
77    }
78
79    pub fn with_channel(addr: HostAddr, channel: Channel) -> Self {
80        let batch_exchange_client =
81            BatchExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
82        let stream_exchange_client =
83            StreamExchangeServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
84        let task_client =
85            TaskServiceClient::new(channel.clone()).max_decoding_message_size(usize::MAX);
86        let config_client = ConfigServiceClient::new(channel);
87        Self {
88            batch_exchange_client,
89            stream_exchange_client,
90            task_client,
91            config_client,
92            addr,
93        }
94    }
95
96    pub async fn get_data(&self, output_id: TaskOutputId) -> Result<Streaming<GetDataResponse>> {
97        Ok(self
98            .batch_exchange_client
99            .clone()
100            .get_data(GetDataRequest {
101                task_output_id: Some(output_id),
102            })
103            .await
104            .map_err(RpcError::from_compute_status)?
105            .into_inner())
106    }
107
108    pub async fn get_stream(
109        &self,
110        up_actor_id: ActorId,
111        down_actor_id: ActorId,
112        up_fragment_id: FragmentId,
113        down_fragment_id: FragmentId,
114        up_partial_graph_id: PartialGraphId,
115        term_id: String,
116    ) -> Result<(
117        Streaming<GetStreamResponse>,
118        mpsc::UnboundedSender<permits::Value>,
119    )> {
120        use risingwave_pb::task_service::get_stream_request::*;
121
122        // Create channel used for the downstream to add back the permits to the upstream.
123        let (permits_tx, permits_rx) = mpsc::unbounded_channel();
124
125        let request_stream = futures::stream::once(futures::future::ready(
126            // `Get` as the first request.
127            GetStreamRequest {
128                value: Some(Value::Get(Get {
129                    up_actor_id,
130                    down_actor_id,
131                    up_fragment_id,
132                    down_fragment_id,
133                    up_partial_graph_id,
134                    term_id,
135                })),
136            },
137        ))
138        .chain(
139            // `AddPermits` as the followings.
140            UnboundedReceiverStream::new(permits_rx).map(|permits| GetStreamRequest {
141                value: Some(Value::AddPermits(PbPermits {
142                    value: Some(permits),
143                })),
144            }),
145        );
146
147        let response_stream = self
148            .stream_exchange_client
149            .clone()
150            .get_stream(request_stream)
151            .await
152            .inspect_err(|_| {
153                tracing::error!(
154                    "failed to create stream from remote_input {} from actor {} to actor {}",
155                    self.addr,
156                    up_actor_id,
157                    down_actor_id
158                )
159            })
160            .map_err(RpcError::from_compute_status)?
161            .into_inner();
162
163        Ok((response_stream, permits_tx))
164    }
165
166    pub async fn create_task(
167        &self,
168        task_id: TaskId,
169        plan: PlanFragment,
170        expr_context: ExprContext,
171    ) -> Result<Streaming<TaskInfoResponse>> {
172        Ok(self
173            .task_client
174            .clone()
175            .create_task(CreateTaskRequest {
176                task_id: Some(task_id),
177                plan: Some(plan),
178                tracing_context: TracingContext::from_current_span().to_protobuf(),
179                expr_context: Some(expr_context),
180            })
181            .await
182            .map_err(RpcError::from_compute_status)?
183            .into_inner())
184    }
185
186    pub async fn execute(&self, req: ExecuteRequest) -> Result<Streaming<GetDataResponse>> {
187        Ok(self
188            .task_client
189            .clone()
190            .execute(req)
191            .await
192            .map_err(RpcError::from_compute_status)?
193            .into_inner())
194    }
195
196    pub async fn cancel(&self, req: CancelTaskRequest) -> Result<CancelTaskResponse> {
197        Ok(self
198            .task_client
199            .clone()
200            .cancel_task(req)
201            .await
202            .map_err(RpcError::from_compute_status)?
203            .into_inner())
204    }
205
206    pub async fn fast_insert(&self, req: FastInsertRequest) -> Result<FastInsertResponse> {
207        Ok(self
208            .task_client
209            .clone()
210            .fast_insert(req)
211            .await
212            .map_err(RpcError::from_compute_status)?
213            .into_inner())
214    }
215
216    pub async fn show_config(&self) -> Result<ShowConfigResponse> {
217        Ok(self
218            .config_client
219            .clone()
220            .show_config(ShowConfigRequest {})
221            .await
222            .map_err(RpcError::from_compute_status)?
223            .into_inner())
224    }
225
226    pub async fn resize_cache(&self, request: ResizeCacheRequest) -> Result<ResizeCacheResponse> {
227        Ok(self
228            .config_client
229            .clone()
230            .resize_cache(request)
231            .await
232            .map_err(RpcError::from_compute_status)?
233            .into_inner())
234    }
235}
236
237#[async_trait]
238impl RpcClient for ComputeClient {
239    async fn new_client(host_addr: HostAddr, opts: &RpcClientConfig) -> Result<Self> {
240        Self::new(host_addr, opts).await
241    }
242}
243
244pub type ComputeClientPool = RpcClientPool<ComputeClient>;
245pub type ComputeClientPoolRef = Arc<ComputeClientPool>; // TODO: no need for `Arc` since clone is cheap and shared