risingwave_batch/rpc/service/
task_service.rs1use std::sync::Arc;
16
17use risingwave_common::util::tracing::TracingContext;
18use risingwave_pb::batch_plan::TaskOutputId;
19use risingwave_pb::task_service::task_service_server::TaskService;
20use risingwave_pb::task_service::{
21 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
22 FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
23};
24use risingwave_storage::dispatch_state_store;
25use thiserror_ext::AsReport;
26use tokio_stream::wrappers::ReceiverStream;
27use tonic::{Request, Response, Status};
28
29use crate::error::BatchError;
30use crate::executor::FastInsertExecutor;
31use crate::rpc::service::exchange::GrpcExchangeWriter;
32use crate::task::{
33 BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
34 TASK_STATUS_BUFFER_SIZE,
35};
36
37const LOCAL_EXECUTE_BUFFER_SIZE: usize = 64;
38
39#[derive(Clone)]
40pub struct BatchServiceImpl {
41 mgr: Arc<BatchManager>,
42 env: BatchEnvironment,
43}
44
45impl BatchServiceImpl {
46 pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
47 BatchServiceImpl { mgr, env }
48 }
49}
50
51pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
52pub type GetDataResponseResult = Result<GetDataResponse, Status>;
53
54#[async_trait::async_trait]
55impl TaskService for BatchServiceImpl {
56 type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
57 type ExecuteStream = ReceiverStream<GetDataResponseResult>;
58
59 #[cfg_attr(coverage, coverage(off))]
60 async fn create_task(
61 &self,
62 request: Request<CreateTaskRequest>,
63 ) -> Result<Response<Self::CreateTaskStream>, Status> {
64 let CreateTaskRequest {
65 task_id,
66 plan,
67 epoch,
68 tracing_context,
69 expr_context,
70 } = request.into_inner();
71
72 let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
73 let state_reporter = StateReporter::new_with_dist_sender(state_tx);
74 let res = self
75 .mgr
76 .fire_task(
77 task_id.as_ref().expect("no task id found"),
78 plan.expect("no plan found").clone(),
79 epoch.expect("no epoch found"),
80 ComputeNodeContext::create(self.env.clone()),
81 state_reporter,
82 TracingContext::from_protobuf(&tracing_context),
83 expr_context.expect("no expression context found"),
84 )
85 .await;
86 match res {
87 Ok(_) => Ok(Response::new(ReceiverStream::new(
88 state_rx,
94 ))),
95 Err(e) => {
96 error!(error = %e.as_report(), "failed to fire task");
97 Err(e.into())
98 }
99 }
100 }
101
102 #[cfg_attr(coverage, coverage(off))]
103 async fn cancel_task(
104 &self,
105 req: Request<CancelTaskRequest>,
106 ) -> Result<Response<CancelTaskResponse>, Status> {
107 let req = req.into_inner();
108 tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
109 self.mgr
110 .cancel_task(req.get_task_id().expect("no task id found"));
111 Ok(Response::new(CancelTaskResponse { status: None }))
112 }
113
114 #[cfg_attr(coverage, coverage(off))]
115 async fn execute(
116 &self,
117 req: Request<ExecuteRequest>,
118 ) -> Result<Response<Self::ExecuteStream>, Status> {
119 let req = req.into_inner();
120 let env = self.env.clone();
121 let mgr = self.mgr.clone();
122 BatchServiceImpl::get_execute_stream(env, mgr, req).await
123 }
124
125 #[cfg_attr(coverage, coverage(off))]
126 async fn fast_insert(
127 &self,
128 request: Request<FastInsertRequest>,
129 ) -> Result<Response<FastInsertResponse>, Status> {
130 let req = request.into_inner();
131 let res = self.do_fast_insert(req).await;
132 match res {
133 Ok(_) => Ok(Response::new(FastInsertResponse {
134 status: fast_insert_response::Status::Succeeded.into(),
135 error_message: "".to_owned(),
136 })),
137 Err(e) => match e {
138 BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
139 status: fast_insert_response::Status::DmlFailed.into(),
140 error_message: format!("{}", e.as_report()),
141 })),
142 _ => {
143 error!(error = %e.as_report(), "failed to fast insert");
144 Err(e.into())
145 }
146 },
147 }
148 }
149}
150
151impl BatchServiceImpl {
152 async fn get_execute_stream(
153 env: BatchEnvironment,
154 mgr: Arc<BatchManager>,
155 req: ExecuteRequest,
156 ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
157 let ExecuteRequest {
158 task_id,
159 plan,
160 epoch,
161 tracing_context,
162 expr_context,
163 } = req;
164
165 let task_id = task_id.expect("no task id found");
166 let plan = plan.expect("no plan found").clone();
167 let epoch = epoch.expect("no epoch found");
168 let tracing_context = TracingContext::from_protobuf(&tracing_context);
169 let expr_context = expr_context.expect("no expression context found");
170
171 let context = ComputeNodeContext::create(env.clone());
172 trace!(
173 "local execute request: plan:{:?} with task id:{:?}",
174 plan, task_id
175 );
176 let task = BatchTaskExecution::new(&task_id, plan, context, epoch, mgr.runtime())?;
177 let task = Arc::new(task);
178 let (tx, rx) = tokio::sync::mpsc::channel(LOCAL_EXECUTE_BUFFER_SIZE);
179 if let Err(e) = task
180 .clone()
181 .async_execute(None, tracing_context, expr_context)
182 .await
183 {
184 error!(
185 error = %e.as_report(),
186 ?task_id,
187 "failed to build executors and trigger execution"
188 );
189 return Err(e.into());
190 }
191
192 let pb_task_output_id = TaskOutputId {
193 task_id: Some(task_id.clone()),
194 output_id: 0,
197 };
198 let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
199 error!(
200 error = %e.as_report(),
201 ?task_id,
202 "failed to get task output in local execution mode",
203 );
204 })?;
205 let mut writer = GrpcExchangeWriter::new(tx.clone());
206 mgr.runtime().spawn(async move {
208 match output.take_data(&mut writer).await {
209 Ok(_) => Ok(()),
210 Err(e) => tx.send(Err(e.into())).await,
211 }
212 });
213 Ok(Response::new(ReceiverStream::new(rx)))
214 }
215
216 async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
217 let table_id = insert_req.table_id;
218 let wait_for_persistence = insert_req.wait_for_persistence;
219 let (executor, data_chunk) =
220 FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
221 let epoch = executor
222 .do_execute(data_chunk, wait_for_persistence)
223 .await?;
224 if wait_for_persistence {
225 dispatch_state_store!(self.env.state_store(), store, {
226 use risingwave_common::catalog::TableId;
227 use risingwave_hummock_sdk::HummockReadEpoch;
228 use risingwave_storage::StateStore;
229 use risingwave_storage::store::TryWaitEpochOptions;
230
231 store
232 .try_wait_epoch(
233 HummockReadEpoch::Committed(epoch.0),
234 TryWaitEpochOptions {
235 table_id: TableId::new(table_id),
236 },
237 )
238 .await
239 .map_err(BatchError::from)?;
240 });
241 }
242 Ok(())
243 }
244}