// risingwave_batch/rpc/service/task_service.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::task_service::task_service_server::TaskService;
use risingwave_pb::task_service::{
    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
    FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
};
use thiserror_ext::AsReport;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};

use crate::error::BatchError;
use crate::executor::FastInsertExecutor;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::task::{
    BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
    TASK_STATUS_BUFFER_SIZE,
};
35
/// gRPC implementation of the batch `TaskService`.
///
/// Cheaply cloneable: holds a shared handle to the batch task manager and the
/// batch environment used to build per-task contexts.
#[derive(Clone)]
pub struct BatchServiceImpl {
    // Shared manager that fires, tracks and cancels batch tasks.
    mgr: Arc<BatchManager>,
    // Environment used to create compute-node contexts and to reach the DML
    // manager (see `do_fast_insert`).
    env: BatchEnvironment,
}
41
42impl BatchServiceImpl {
43    pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
44        BatchServiceImpl { mgr, env }
45    }
46}
47
/// Item type of the status-update stream returned by `create_task`.
pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
/// Item type of the data stream returned by `execute`.
pub type GetDataResponseResult = Result<GetDataResponse, Status>;
50
51#[async_trait::async_trait]
52impl TaskService for BatchServiceImpl {
53    type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
54    type ExecuteStream = ReceiverStream<GetDataResponseResult>;
55
56    async fn create_task(
57        &self,
58        request: Request<CreateTaskRequest>,
59    ) -> Result<Response<Self::CreateTaskStream>, Status> {
60        let CreateTaskRequest {
61            task_id,
62            plan,
63            tracing_context,
64            expr_context,
65        } = request.into_inner();
66
67        let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
68        let state_reporter = StateReporter::new_with_dist_sender(state_tx);
69        let res = self
70            .mgr
71            .fire_task(
72                task_id.as_ref().expect("no task id found"),
73                plan.expect("no plan found").clone(),
74                ComputeNodeContext::create(self.env.clone()),
75                state_reporter,
76                TracingContext::from_protobuf(&tracing_context),
77                expr_context.expect("no expression context found"),
78            )
79            .await;
80        match res {
81            Ok(_) => Ok(Response::new(ReceiverStream::new(
82                // Create receiver stream from state receiver.
83                // The state receiver is init in `.async_execute()`.
84                // Will be used for receive task status update.
85                // Note: we introduce this hack cuz `.execute()` do not produce a status stream,
86                // but still share `.async_execute()` and `.try_execute()`.
87                state_rx,
88            ))),
89            Err(e) => {
90                error!(error = %e.as_report(), "failed to fire task");
91                Err(e.into())
92            }
93        }
94    }
95
96    async fn cancel_task(
97        &self,
98        req: Request<CancelTaskRequest>,
99    ) -> Result<Response<CancelTaskResponse>, Status> {
100        let req = req.into_inner();
101        tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
102        self.mgr
103            .cancel_task(req.get_task_id().expect("no task id found"));
104        Ok(Response::new(CancelTaskResponse { status: None }))
105    }
106
107    async fn execute(
108        &self,
109        req: Request<ExecuteRequest>,
110    ) -> Result<Response<Self::ExecuteStream>, Status> {
111        let req = req.into_inner();
112        let env = self.env.clone();
113        let mgr = self.mgr.clone();
114        BatchServiceImpl::get_execute_stream(env, mgr, req).await
115    }
116
117    async fn fast_insert(
118        &self,
119        request: Request<FastInsertRequest>,
120    ) -> Result<Response<FastInsertResponse>, Status> {
121        let req = request.into_inner();
122        let res = self.do_fast_insert(req).await;
123        match res {
124            Ok(_) => Ok(Response::new(FastInsertResponse {
125                status: fast_insert_response::Status::Succeeded.into(),
126                error_message: "".to_owned(),
127            })),
128            Err(e) => match e {
129                BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
130                    status: fast_insert_response::Status::DmlFailed.into(),
131                    error_message: format!("{}", e.as_report()),
132                })),
133                _ => {
134                    error!(error = %e.as_report(), "failed to fast insert");
135                    Err(e.into())
136                }
137            },
138        }
139    }
140}
141
142impl BatchServiceImpl {
143    async fn get_execute_stream(
144        env: BatchEnvironment,
145        mgr: Arc<BatchManager>,
146        req: ExecuteRequest,
147    ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
148        let ExecuteRequest {
149            task_id,
150            plan,
151            tracing_context,
152            expr_context,
153        } = req;
154
155        let task_id = task_id.expect("no task id found");
156        let plan = plan.expect("no plan found").clone();
157        let tracing_context = TracingContext::from_protobuf(&tracing_context);
158        let expr_context = expr_context.expect("no expression context found");
159
160        let context = ComputeNodeContext::create(env.clone());
161        trace!(
162            "local execute request: plan:{:?} with task id:{:?}",
163            plan, task_id
164        );
165        let task = BatchTaskExecution::new(&task_id, plan, context, mgr.runtime())?;
166        let task = Arc::new(task);
167        let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
168        if let Err(e) = task
169            .clone()
170            .async_execute(None, tracing_context, expr_context)
171            .await
172        {
173            error!(
174                error = %e.as_report(),
175                ?task_id,
176                "failed to build executors and trigger execution"
177            );
178            return Err(e.into());
179        }
180
181        let pb_task_output_id = TaskOutputId {
182            task_id: Some(task_id.clone()),
183            // Since this is local execution path, the exchange would follow single distribution,
184            // therefore we would only have one data output.
185            output_id: 0,
186        };
187        let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
188            error!(
189                error = %e.as_report(),
190                ?task_id,
191                "failed to get task output in local execution mode",
192            );
193        })?;
194        let mut writer = GrpcExchangeWriter::new(tx.clone());
195        // Always spawn a task and do not block current function.
196        mgr.runtime().spawn(async move {
197            match output.take_data(&mut writer).await {
198                Ok(_) => Ok(()),
199                Err(e) => tx.send(Err(e.into())).await,
200            }
201        });
202        Ok(Response::new(ReceiverStream::new(rx)))
203    }
204
205    async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
206        let wait_for_persistence = insert_req.wait_for_persistence;
207        let (executor, data_chunk) =
208            FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
209        executor
210            .do_execute(data_chunk, wait_for_persistence)
211            .await?;
212        Ok(())
213    }
214}