risingwave_batch/rpc/service/
task_service.rs1use std::sync::Arc;
16
17use risingwave_common::util::tracing::TracingContext;
18use risingwave_pb::batch_plan::TaskOutputId;
19use risingwave_pb::task_service::task_service_server::TaskService;
20use risingwave_pb::task_service::{
21 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
22 FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
23};
24use risingwave_storage::dispatch_state_store;
25use thiserror_ext::AsReport;
26use tokio_stream::wrappers::ReceiverStream;
27use tonic::{Request, Response, Status};
28
29use crate::error::BatchError;
30use crate::executor::FastInsertExecutor;
31use crate::rpc::service::exchange::GrpcExchangeWriter;
32use crate::task::{
33 BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
34 TASK_STATUS_BUFFER_SIZE,
35};
36
37#[derive(Clone)]
38pub struct BatchServiceImpl {
39 mgr: Arc<BatchManager>,
40 env: BatchEnvironment,
41}
42
43impl BatchServiceImpl {
44 pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
45 BatchServiceImpl { mgr, env }
46 }
47}
48
49pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
50pub type GetDataResponseResult = Result<GetDataResponse, Status>;
51
52#[async_trait::async_trait]
53impl TaskService for BatchServiceImpl {
54 type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
55 type ExecuteStream = ReceiverStream<GetDataResponseResult>;
56
57 async fn create_task(
58 &self,
59 request: Request<CreateTaskRequest>,
60 ) -> Result<Response<Self::CreateTaskStream>, Status> {
61 let CreateTaskRequest {
62 task_id,
63 plan,
64 tracing_context,
65 expr_context,
66 } = request.into_inner();
67
68 let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
69 let state_reporter = StateReporter::new_with_dist_sender(state_tx);
70 let res = self
71 .mgr
72 .fire_task(
73 task_id.as_ref().expect("no task id found"),
74 plan.expect("no plan found").clone(),
75 ComputeNodeContext::create(self.env.clone()),
76 state_reporter,
77 TracingContext::from_protobuf(&tracing_context),
78 expr_context.expect("no expression context found"),
79 )
80 .await;
81 match res {
82 Ok(_) => Ok(Response::new(ReceiverStream::new(
83 state_rx,
89 ))),
90 Err(e) => {
91 error!(error = %e.as_report(), "failed to fire task");
92 Err(e.into())
93 }
94 }
95 }
96
97 async fn cancel_task(
98 &self,
99 req: Request<CancelTaskRequest>,
100 ) -> Result<Response<CancelTaskResponse>, Status> {
101 let req = req.into_inner();
102 tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
103 self.mgr
104 .cancel_task(req.get_task_id().expect("no task id found"));
105 Ok(Response::new(CancelTaskResponse { status: None }))
106 }
107
108 async fn execute(
109 &self,
110 req: Request<ExecuteRequest>,
111 ) -> Result<Response<Self::ExecuteStream>, Status> {
112 let req = req.into_inner();
113 let env = self.env.clone();
114 let mgr = self.mgr.clone();
115 BatchServiceImpl::get_execute_stream(env, mgr, req).await
116 }
117
118 async fn fast_insert(
119 &self,
120 request: Request<FastInsertRequest>,
121 ) -> Result<Response<FastInsertResponse>, Status> {
122 let req = request.into_inner();
123 let res = self.do_fast_insert(req).await;
124 match res {
125 Ok(_) => Ok(Response::new(FastInsertResponse {
126 status: fast_insert_response::Status::Succeeded.into(),
127 error_message: "".to_owned(),
128 })),
129 Err(e) => match e {
130 BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
131 status: fast_insert_response::Status::DmlFailed.into(),
132 error_message: format!("{}", e.as_report()),
133 })),
134 _ => {
135 error!(error = %e.as_report(), "failed to fast insert");
136 Err(e.into())
137 }
138 },
139 }
140 }
141}
142
143impl BatchServiceImpl {
144 async fn get_execute_stream(
145 env: BatchEnvironment,
146 mgr: Arc<BatchManager>,
147 req: ExecuteRequest,
148 ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
149 let ExecuteRequest {
150 task_id,
151 plan,
152 tracing_context,
153 expr_context,
154 } = req;
155
156 let task_id = task_id.expect("no task id found");
157 let plan = plan.expect("no plan found").clone();
158 let tracing_context = TracingContext::from_protobuf(&tracing_context);
159 let expr_context = expr_context.expect("no expression context found");
160
161 let context = ComputeNodeContext::create(env.clone());
162 trace!(
163 "local execute request: plan:{:?} with task id:{:?}",
164 plan, task_id
165 );
166 let task = BatchTaskExecution::new(&task_id, plan, context, mgr.runtime())?;
167 let task = Arc::new(task);
168 let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
169 if let Err(e) = task
170 .clone()
171 .async_execute(None, tracing_context, expr_context)
172 .await
173 {
174 error!(
175 error = %e.as_report(),
176 ?task_id,
177 "failed to build executors and trigger execution"
178 );
179 return Err(e.into());
180 }
181
182 let pb_task_output_id = TaskOutputId {
183 task_id: Some(task_id.clone()),
184 output_id: 0,
187 };
188 let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
189 error!(
190 error = %e.as_report(),
191 ?task_id,
192 "failed to get task output in local execution mode",
193 );
194 })?;
195 let mut writer = GrpcExchangeWriter::new(tx.clone());
196 mgr.runtime().spawn(async move {
198 match output.take_data(&mut writer).await {
199 Ok(_) => Ok(()),
200 Err(e) => tx.send(Err(e.into())).await,
201 }
202 });
203 Ok(Response::new(ReceiverStream::new(rx)))
204 }
205
206 async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
207 let table_id = insert_req.table_id;
208 let wait_for_persistence = insert_req.wait_for_persistence;
209 let (executor, data_chunk) =
210 FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
211 let epoch = executor
212 .do_execute(data_chunk, wait_for_persistence)
213 .await?;
214 if wait_for_persistence {
215 dispatch_state_store!(self.env.state_store(), store, {
216 use risingwave_common::catalog::TableId;
217 use risingwave_hummock_sdk::HummockReadEpoch;
218 use risingwave_storage::StateStore;
219 use risingwave_storage::store::TryWaitEpochOptions;
220
221 store
222 .try_wait_epoch(
223 HummockReadEpoch::Committed(epoch.0),
224 TryWaitEpochOptions {
225 table_id: TableId::new(table_id),
226 },
227 )
228 .await
229 .map_err(BatchError::from)?;
230 });
231 }
232 Ok(())
233 }
234}