risingwave_batch/rpc/service/
task_service.rs1use std::sync::Arc;
16
17use risingwave_common::util::tracing::TracingContext;
18use risingwave_pb::batch_plan::TaskOutputId;
19use risingwave_pb::task_service::task_service_server::TaskService;
20use risingwave_pb::task_service::{
21 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
22 FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
23};
24use risingwave_storage::dispatch_state_store;
25use thiserror_ext::AsReport;
26use tokio_stream::wrappers::ReceiverStream;
27use tonic::{Request, Response, Status};
28
29use crate::error::BatchError;
30use crate::executor::FastInsertExecutor;
31use crate::rpc::service::exchange::GrpcExchangeWriter;
32use crate::task::{
33 BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
34 TASK_STATUS_BUFFER_SIZE,
35};
36
37#[derive(Clone)]
38pub struct BatchServiceImpl {
39 mgr: Arc<BatchManager>,
40 env: BatchEnvironment,
41}
42
43impl BatchServiceImpl {
44 pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
45 BatchServiceImpl { mgr, env }
46 }
47}
48
49pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
50pub type GetDataResponseResult = Result<GetDataResponse, Status>;
51
52#[async_trait::async_trait]
53impl TaskService for BatchServiceImpl {
54 type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
55 type ExecuteStream = ReceiverStream<GetDataResponseResult>;
56
57 #[cfg_attr(coverage, coverage(off))]
58 async fn create_task(
59 &self,
60 request: Request<CreateTaskRequest>,
61 ) -> Result<Response<Self::CreateTaskStream>, Status> {
62 let CreateTaskRequest {
63 task_id,
64 plan,
65 epoch,
66 tracing_context,
67 expr_context,
68 } = request.into_inner();
69
70 let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
71 let state_reporter = StateReporter::new_with_dist_sender(state_tx);
72 let res = self
73 .mgr
74 .fire_task(
75 task_id.as_ref().expect("no task id found"),
76 plan.expect("no plan found").clone(),
77 epoch.expect("no epoch found"),
78 ComputeNodeContext::create(self.env.clone()),
79 state_reporter,
80 TracingContext::from_protobuf(&tracing_context),
81 expr_context.expect("no expression context found"),
82 )
83 .await;
84 match res {
85 Ok(_) => Ok(Response::new(ReceiverStream::new(
86 state_rx,
92 ))),
93 Err(e) => {
94 error!(error = %e.as_report(), "failed to fire task");
95 Err(e.into())
96 }
97 }
98 }
99
100 #[cfg_attr(coverage, coverage(off))]
101 async fn cancel_task(
102 &self,
103 req: Request<CancelTaskRequest>,
104 ) -> Result<Response<CancelTaskResponse>, Status> {
105 let req = req.into_inner();
106 tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
107 self.mgr
108 .cancel_task(req.get_task_id().expect("no task id found"));
109 Ok(Response::new(CancelTaskResponse { status: None }))
110 }
111
112 #[cfg_attr(coverage, coverage(off))]
113 async fn execute(
114 &self,
115 req: Request<ExecuteRequest>,
116 ) -> Result<Response<Self::ExecuteStream>, Status> {
117 let req = req.into_inner();
118 let env = self.env.clone();
119 let mgr = self.mgr.clone();
120 BatchServiceImpl::get_execute_stream(env, mgr, req).await
121 }
122
123 #[cfg_attr(coverage, coverage(off))]
124 async fn fast_insert(
125 &self,
126 request: Request<FastInsertRequest>,
127 ) -> Result<Response<FastInsertResponse>, Status> {
128 let req = request.into_inner();
129 let res = self.do_fast_insert(req).await;
130 match res {
131 Ok(_) => Ok(Response::new(FastInsertResponse {
132 status: fast_insert_response::Status::Succeeded.into(),
133 error_message: "".to_owned(),
134 })),
135 Err(e) => match e {
136 BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
137 status: fast_insert_response::Status::DmlFailed.into(),
138 error_message: format!("{}", e.as_report()),
139 })),
140 _ => {
141 error!(error = %e.as_report(), "failed to fast insert");
142 Err(e.into())
143 }
144 },
145 }
146 }
147}
148
149impl BatchServiceImpl {
150 async fn get_execute_stream(
151 env: BatchEnvironment,
152 mgr: Arc<BatchManager>,
153 req: ExecuteRequest,
154 ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
155 let ExecuteRequest {
156 task_id,
157 plan,
158 epoch,
159 tracing_context,
160 expr_context,
161 } = req;
162
163 let task_id = task_id.expect("no task id found");
164 let plan = plan.expect("no plan found").clone();
165 let epoch = epoch.expect("no epoch found");
166 let tracing_context = TracingContext::from_protobuf(&tracing_context);
167 let expr_context = expr_context.expect("no expression context found");
168
169 let context = ComputeNodeContext::create(env.clone());
170 trace!(
171 "local execute request: plan:{:?} with task id:{:?}",
172 plan, task_id
173 );
174 let task = BatchTaskExecution::new(&task_id, plan, context, epoch, mgr.runtime())?;
175 let task = Arc::new(task);
176 let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
177 if let Err(e) = task
178 .clone()
179 .async_execute(None, tracing_context, expr_context)
180 .await
181 {
182 error!(
183 error = %e.as_report(),
184 ?task_id,
185 "failed to build executors and trigger execution"
186 );
187 return Err(e.into());
188 }
189
190 let pb_task_output_id = TaskOutputId {
191 task_id: Some(task_id.clone()),
192 output_id: 0,
195 };
196 let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
197 error!(
198 error = %e.as_report(),
199 ?task_id,
200 "failed to get task output in local execution mode",
201 );
202 })?;
203 let mut writer = GrpcExchangeWriter::new(tx.clone());
204 mgr.runtime().spawn(async move {
206 match output.take_data(&mut writer).await {
207 Ok(_) => Ok(()),
208 Err(e) => tx.send(Err(e.into())).await,
209 }
210 });
211 Ok(Response::new(ReceiverStream::new(rx)))
212 }
213
214 async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
215 let table_id = insert_req.table_id;
216 let wait_for_persistence = insert_req.wait_for_persistence;
217 let (executor, data_chunk) =
218 FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
219 let epoch = executor
220 .do_execute(data_chunk, wait_for_persistence)
221 .await?;
222 if wait_for_persistence {
223 dispatch_state_store!(self.env.state_store(), store, {
224 use risingwave_common::catalog::TableId;
225 use risingwave_hummock_sdk::HummockReadEpoch;
226 use risingwave_storage::StateStore;
227 use risingwave_storage::store::TryWaitEpochOptions;
228
229 store
230 .try_wait_epoch(
231 HummockReadEpoch::Committed(epoch.0),
232 TryWaitEpochOptions {
233 table_id: TableId::new(table_id),
234 },
235 )
236 .await
237 .map_err(BatchError::from)?;
238 });
239 }
240 Ok(())
241 }
242}