risingwave_batch/rpc/service/task_service.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::task_service::task_service_server::TaskService;
use risingwave_pb::task_service::{
    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
    FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
};
use risingwave_storage::dispatch_state_store;
use thiserror_ext::AsReport;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};

use crate::error::BatchError;
use crate::executor::FastInsertExecutor;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::task::{
    BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
    TASK_STATUS_BUFFER_SIZE,
};

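/// Buffer size of the channel used to stream data chunks back to the client in local
/// execution mode (see `get_execute_stream`).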
const LOCAL_EXECUTE_BUFFER_SIZE: usize = 64;

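/// gRPC implementation of the batch `TaskService`, backed by the node's `BatchManager`
/// and `BatchEnvironment`.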
#[derive(Clone)]
pub struct BatchServiceImpl {
    mgr: Arc<BatchManager>,
    env: BatchEnvironment,
}

impl BatchServiceImpl {
    pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
        BatchServiceImpl { mgr, env }
    }
}

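/// Item type of the task status stream returned by `create_task`.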
pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
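/// Item type of the data stream returned by `execute`.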
pub type GetDataResponseResult = Result<GetDataResponse, Status>;

#[async_trait::async_trait]
impl TaskService for BatchServiceImpl {
    type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
    type ExecuteStream = ReceiverStream<GetDataResponseResult>;

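    /// Fires a batch task for distributed execution via the `BatchManager` and returns a
    /// stream of `TaskInfoResponse` status updates reported through the `StateReporter`.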
    #[cfg_attr(coverage, coverage(off))]
    async fn create_task(
        &self,
        request: Request<CreateTaskRequest>,
    ) -> Result<Response<Self::CreateTaskStream>, Status> {
        let CreateTaskRequest {
            task_id,
            plan,
            epoch,
            tracing_context,
            expr_context,
        } = request.into_inner();

        let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
        let state_reporter = StateReporter::new_with_dist_sender(state_tx);
        let res = self
            .mgr
            .fire_task(
                task_id.as_ref().expect("no task id found"),
                plan.expect("no plan found").clone(),
                epoch.expect("no epoch found"),
                ComputeNodeContext::create(self.env.clone()),
                state_reporter,
                TracingContext::from_protobuf(&tracing_context),
                expr_context.expect("no expression context found"),
            )
            .await;
        match res {
            Ok(_) => Ok(Response::new(ReceiverStream::new(
                // Create a receiver stream from the state receiver.
                // The state receiver is initialized in `.async_execute()` and is used to
                // receive task status updates.
                // Note: this is a workaround because `.execute()` does not produce a status
                // stream, yet it still shares `.async_execute()` and `.try_execute()`.
                state_rx,
            ))),
            Err(e) => {
                error!(error = %e.as_report(), "failed to fire task");
                Err(e.into())
            }
        }
    }

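    /// Cancels a running batch task identified by the task id in the request.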
    #[cfg_attr(coverage, coverage(off))]
    async fn cancel_task(
        &self,
        req: Request<CancelTaskRequest>,
    ) -> Result<Response<CancelTaskResponse>, Status> {
        let req = req.into_inner();
        tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
        self.mgr
            .cancel_task(req.get_task_id().expect("no task id found"));
        Ok(Response::new(CancelTaskResponse { status: None }))
    }

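    /// Executes the plan in the request locally and streams the resulting data back as
    /// `GetDataResponse`s; see `get_execute_stream` for the implementation.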
    #[cfg_attr(coverage, coverage(off))]
    async fn execute(
        &self,
        req: Request<ExecuteRequest>,
    ) -> Result<Response<Self::ExecuteStream>, Status> {
        let req = req.into_inner();
        let env = self.env.clone();
        let mgr = self.mgr.clone();
        BatchServiceImpl::get_execute_stream(env, mgr, req).await
    }

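    /// Performs a fast insert via `FastInsertExecutor`. DML failures are reported in the
    /// response status rather than as an RPC error; other errors become a gRPC `Status`.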
    #[cfg_attr(coverage, coverage(off))]
    async fn fast_insert(
        &self,
        request: Request<FastInsertRequest>,
    ) -> Result<Response<FastInsertResponse>, Status> {
        let req = request.into_inner();
        let res = self.do_fast_insert(req).await;
        match res {
            Ok(_) => Ok(Response::new(FastInsertResponse {
                status: fast_insert_response::Status::Succeeded.into(),
                error_message: "".to_owned(),
            })),
            Err(e) => match e {
                BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
                    status: fast_insert_response::Status::DmlFailed.into(),
                    error_message: format!("{}", e.as_report()),
                })),
                _ => {
                    error!(error = %e.as_report(), "failed to fast insert");
                    Err(e.into())
                }
            },
        }
    }
}

impl BatchServiceImpl {
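    /// Builds a `BatchTaskExecution` from the request, triggers it on the batch runtime,
    /// and forwards its single task output into a bounded channel whose receiver is
    /// returned as the response stream.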
    async fn get_execute_stream(
        env: BatchEnvironment,
        mgr: Arc<BatchManager>,
        req: ExecuteRequest,
    ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
        let ExecuteRequest {
            task_id,
            plan,
            epoch,
            tracing_context,
            expr_context,
        } = req;

        let task_id = task_id.expect("no task id found");
        let plan = plan.expect("no plan found").clone();
        let epoch = epoch.expect("no epoch found");
        let tracing_context = TracingContext::from_protobuf(&tracing_context);
        let expr_context = expr_context.expect("no expression context found");

        let context = ComputeNodeContext::create(env.clone());
        trace!(
            "local execute request: plan:{:?} with task id:{:?}",
            plan, task_id
        );
        let task = BatchTaskExecution::new(&task_id, plan, context, epoch, mgr.runtime())?;
        let task = Arc::new(task);
        let (tx, rx) = tokio::sync::mpsc::channel(LOCAL_EXECUTE_BUFFER_SIZE);
        if let Err(e) = task
            .clone()
            .async_execute(None, tracing_context, expr_context)
            .await
        {
            error!(
                error = %e.as_report(),
                ?task_id,
                "failed to build executors and trigger execution"
            );
            return Err(e.into());
        }

        let pb_task_output_id = TaskOutputId {
            task_id: Some(task_id.clone()),
            // Since this is the local execution path, the exchange follows the single
            // distribution, so there is only one data output.
            output_id: 0,
        };
        let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
            error!(
                error = %e.as_report(),
                ?task_id,
                "failed to get task output in local execution mode",
            );
        })?;
        let mut writer = GrpcExchangeWriter::new(tx.clone());
        // Always spawn a task and do not block the current function.
        mgr.runtime().spawn(async move {
            match output.take_data(&mut writer).await {
                Ok(_) => Ok(()),
                Err(e) => tx.send(Err(e.into())).await,
            }
        });
        Ok(Response::new(ReceiverStream::new(rx)))
    }

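    /// Builds and runs a `FastInsertExecutor` for the request. If `wait_for_persistence`
    /// is set, also waits until the written epoch is committed in the state store.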
    async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
        let table_id = insert_req.table_id;
        let wait_for_persistence = insert_req.wait_for_persistence;
        let (executor, data_chunk) =
            FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
        let epoch = executor
            .do_execute(data_chunk, wait_for_persistence)
            .await?;
        if wait_for_persistence {
            dispatch_state_store!(self.env.state_store(), store, {
                use risingwave_common::catalog::TableId;
                use risingwave_hummock_sdk::HummockReadEpoch;
                use risingwave_storage::StateStore;
                use risingwave_storage::store::TryWaitEpochOptions;

                store
                    .try_wait_epoch(
                        HummockReadEpoch::Committed(epoch.0),
                        TryWaitEpochOptions {
                            table_id: TableId::new(table_id),
                        },
                    )
                    .await
                    .map_err(BatchError::from)?;
            });
        }
        Ok(())
    }
}