risingwave_batch/rpc/service/
task_service.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::util::tracing::TracingContext;
18use risingwave_pb::batch_plan::TaskOutputId;
19use risingwave_pb::task_service::task_service_server::TaskService;
20use risingwave_pb::task_service::{
21    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
22    FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
23};
24use risingwave_storage::dispatch_state_store;
25use thiserror_ext::AsReport;
26use tokio_stream::wrappers::ReceiverStream;
27use tonic::{Request, Response, Status};
28
29use crate::error::BatchError;
30use crate::executor::FastInsertExecutor;
31use crate::rpc::service::exchange::GrpcExchangeWriter;
32use crate::task::{
33    BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
34    TASK_STATUS_BUFFER_SIZE,
35};
36
37#[derive(Clone)]
38pub struct BatchServiceImpl {
39    mgr: Arc<BatchManager>,
40    env: BatchEnvironment,
41}
42
43impl BatchServiceImpl {
44    pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
45        BatchServiceImpl { mgr, env }
46    }
47}
48
49pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
50pub type GetDataResponseResult = Result<GetDataResponse, Status>;
51
52#[async_trait::async_trait]
53impl TaskService for BatchServiceImpl {
54    type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
55    type ExecuteStream = ReceiverStream<GetDataResponseResult>;
56
57    async fn create_task(
58        &self,
59        request: Request<CreateTaskRequest>,
60    ) -> Result<Response<Self::CreateTaskStream>, Status> {
61        let CreateTaskRequest {
62            task_id,
63            plan,
64            epoch,
65            tracing_context,
66            expr_context,
67        } = request.into_inner();
68
69        let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
70        let state_reporter = StateReporter::new_with_dist_sender(state_tx);
71        let res = self
72            .mgr
73            .fire_task(
74                task_id.as_ref().expect("no task id found"),
75                plan.expect("no plan found").clone(),
76                epoch.expect("no epoch found"),
77                ComputeNodeContext::create(self.env.clone()),
78                state_reporter,
79                TracingContext::from_protobuf(&tracing_context),
80                expr_context.expect("no expression context found"),
81            )
82            .await;
83        match res {
84            Ok(_) => Ok(Response::new(ReceiverStream::new(
85                // Create receiver stream from state receiver.
86                // The state receiver is init in `.async_execute()`.
87                // Will be used for receive task status update.
88                // Note: we introduce this hack cuz `.execute()` do not produce a status stream,
89                // but still share `.async_execute()` and `.try_execute()`.
90                state_rx,
91            ))),
92            Err(e) => {
93                error!(error = %e.as_report(), "failed to fire task");
94                Err(e.into())
95            }
96        }
97    }
98
99    async fn cancel_task(
100        &self,
101        req: Request<CancelTaskRequest>,
102    ) -> Result<Response<CancelTaskResponse>, Status> {
103        let req = req.into_inner();
104        tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
105        self.mgr
106            .cancel_task(req.get_task_id().expect("no task id found"));
107        Ok(Response::new(CancelTaskResponse { status: None }))
108    }
109
110    async fn execute(
111        &self,
112        req: Request<ExecuteRequest>,
113    ) -> Result<Response<Self::ExecuteStream>, Status> {
114        let req = req.into_inner();
115        let env = self.env.clone();
116        let mgr = self.mgr.clone();
117        BatchServiceImpl::get_execute_stream(env, mgr, req).await
118    }
119
120    async fn fast_insert(
121        &self,
122        request: Request<FastInsertRequest>,
123    ) -> Result<Response<FastInsertResponse>, Status> {
124        let req = request.into_inner();
125        let res = self.do_fast_insert(req).await;
126        match res {
127            Ok(_) => Ok(Response::new(FastInsertResponse {
128                status: fast_insert_response::Status::Succeeded.into(),
129                error_message: "".to_owned(),
130            })),
131            Err(e) => match e {
132                BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
133                    status: fast_insert_response::Status::DmlFailed.into(),
134                    error_message: format!("{}", e.as_report()),
135                })),
136                _ => {
137                    error!(error = %e.as_report(), "failed to fast insert");
138                    Err(e.into())
139                }
140            },
141        }
142    }
143}
144
145impl BatchServiceImpl {
146    async fn get_execute_stream(
147        env: BatchEnvironment,
148        mgr: Arc<BatchManager>,
149        req: ExecuteRequest,
150    ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
151        let ExecuteRequest {
152            task_id,
153            plan,
154            epoch,
155            tracing_context,
156            expr_context,
157        } = req;
158
159        let task_id = task_id.expect("no task id found");
160        let plan = plan.expect("no plan found").clone();
161        let epoch = epoch.expect("no epoch found");
162        let tracing_context = TracingContext::from_protobuf(&tracing_context);
163        let expr_context = expr_context.expect("no expression context found");
164
165        let context = ComputeNodeContext::create(env.clone());
166        trace!(
167            "local execute request: plan:{:?} with task id:{:?}",
168            plan, task_id
169        );
170        let task = BatchTaskExecution::new(&task_id, plan, context, epoch, mgr.runtime())?;
171        let task = Arc::new(task);
172        let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
173        if let Err(e) = task
174            .clone()
175            .async_execute(None, tracing_context, expr_context)
176            .await
177        {
178            error!(
179                error = %e.as_report(),
180                ?task_id,
181                "failed to build executors and trigger execution"
182            );
183            return Err(e.into());
184        }
185
186        let pb_task_output_id = TaskOutputId {
187            task_id: Some(task_id.clone()),
188            // Since this is local execution path, the exchange would follow single distribution,
189            // therefore we would only have one data output.
190            output_id: 0,
191        };
192        let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
193            error!(
194                error = %e.as_report(),
195                ?task_id,
196                "failed to get task output in local execution mode",
197            );
198        })?;
199        let mut writer = GrpcExchangeWriter::new(tx.clone());
200        // Always spawn a task and do not block current function.
201        mgr.runtime().spawn(async move {
202            match output.take_data(&mut writer).await {
203                Ok(_) => Ok(()),
204                Err(e) => tx.send(Err(e.into())).await,
205            }
206        });
207        Ok(Response::new(ReceiverStream::new(rx)))
208    }
209
210    async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
211        let table_id = insert_req.table_id;
212        let wait_for_persistence = insert_req.wait_for_persistence;
213        let (executor, data_chunk) =
214            FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
215        let epoch = executor
216            .do_execute(data_chunk, wait_for_persistence)
217            .await?;
218        if wait_for_persistence {
219            dispatch_state_store!(self.env.state_store(), store, {
220                use risingwave_common::catalog::TableId;
221                use risingwave_hummock_sdk::HummockReadEpoch;
222                use risingwave_storage::StateStore;
223                use risingwave_storage::store::TryWaitEpochOptions;
224
225                store
226                    .try_wait_epoch(
227                        HummockReadEpoch::Committed(epoch.0),
228                        TryWaitEpochOptions {
229                            table_id: TableId::new(table_id),
230                        },
231                    )
232                    .await
233                    .map_err(BatchError::from)?;
234            });
235        }
236        Ok(())
237    }
238}