// risingwave_batch/rpc/service/task_service.rs
use std::sync::Arc;
16
17use risingwave_common::util::tracing::TracingContext;
18use risingwave_pb::batch_plan::TaskOutputId;
19use risingwave_pb::task_service::task_service_server::TaskService;
20use risingwave_pb::task_service::{
21 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
22 FastInsertResponse, GetDataResponse, TaskInfoResponse, fast_insert_response,
23};
24use risingwave_storage::dispatch_state_store;
25use thiserror_ext::AsReport;
26use tokio_stream::wrappers::ReceiverStream;
27use tonic::{Request, Response, Status};
28
29use crate::error::BatchError;
30use crate::executor::FastInsertExecutor;
31use crate::rpc::service::exchange::GrpcExchangeWriter;
32use crate::task::{
33 BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
34 TASK_STATUS_BUFFER_SIZE,
35};
36
37#[derive(Clone)]
38pub struct BatchServiceImpl {
39 mgr: Arc<BatchManager>,
40 env: BatchEnvironment,
41}
42
43impl BatchServiceImpl {
44 pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
45 BatchServiceImpl { mgr, env }
46 }
47}
48
49pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
50pub type GetDataResponseResult = Result<GetDataResponse, Status>;
51
52#[async_trait::async_trait]
53impl TaskService for BatchServiceImpl {
54 type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
55 type ExecuteStream = ReceiverStream<GetDataResponseResult>;
56
57 async fn create_task(
58 &self,
59 request: Request<CreateTaskRequest>,
60 ) -> Result<Response<Self::CreateTaskStream>, Status> {
61 let CreateTaskRequest {
62 task_id,
63 plan,
64 epoch,
65 tracing_context,
66 expr_context,
67 } = request.into_inner();
68
69 let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
70 let state_reporter = StateReporter::new_with_dist_sender(state_tx);
71 let res = self
72 .mgr
73 .fire_task(
74 task_id.as_ref().expect("no task id found"),
75 plan.expect("no plan found").clone(),
76 epoch.expect("no epoch found"),
77 ComputeNodeContext::create(self.env.clone()),
78 state_reporter,
79 TracingContext::from_protobuf(&tracing_context),
80 expr_context.expect("no expression context found"),
81 )
82 .await;
83 match res {
84 Ok(_) => Ok(Response::new(ReceiverStream::new(
85 state_rx,
91 ))),
92 Err(e) => {
93 error!(error = %e.as_report(), "failed to fire task");
94 Err(e.into())
95 }
96 }
97 }
98
99 async fn cancel_task(
100 &self,
101 req: Request<CancelTaskRequest>,
102 ) -> Result<Response<CancelTaskResponse>, Status> {
103 let req = req.into_inner();
104 tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
105 self.mgr
106 .cancel_task(req.get_task_id().expect("no task id found"));
107 Ok(Response::new(CancelTaskResponse { status: None }))
108 }
109
110 async fn execute(
111 &self,
112 req: Request<ExecuteRequest>,
113 ) -> Result<Response<Self::ExecuteStream>, Status> {
114 let req = req.into_inner();
115 let env = self.env.clone();
116 let mgr = self.mgr.clone();
117 BatchServiceImpl::get_execute_stream(env, mgr, req).await
118 }
119
120 async fn fast_insert(
121 &self,
122 request: Request<FastInsertRequest>,
123 ) -> Result<Response<FastInsertResponse>, Status> {
124 let req = request.into_inner();
125 let res = self.do_fast_insert(req).await;
126 match res {
127 Ok(_) => Ok(Response::new(FastInsertResponse {
128 status: fast_insert_response::Status::Succeeded.into(),
129 error_message: "".to_owned(),
130 })),
131 Err(e) => match e {
132 BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
133 status: fast_insert_response::Status::DmlFailed.into(),
134 error_message: format!("{}", e.as_report()),
135 })),
136 _ => {
137 error!(error = %e.as_report(), "failed to fast insert");
138 Err(e.into())
139 }
140 },
141 }
142 }
143}
144
145impl BatchServiceImpl {
146 async fn get_execute_stream(
147 env: BatchEnvironment,
148 mgr: Arc<BatchManager>,
149 req: ExecuteRequest,
150 ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
151 let ExecuteRequest {
152 task_id,
153 plan,
154 epoch,
155 tracing_context,
156 expr_context,
157 } = req;
158
159 let task_id = task_id.expect("no task id found");
160 let plan = plan.expect("no plan found").clone();
161 let epoch = epoch.expect("no epoch found");
162 let tracing_context = TracingContext::from_protobuf(&tracing_context);
163 let expr_context = expr_context.expect("no expression context found");
164
165 let context = ComputeNodeContext::create(env.clone());
166 trace!(
167 "local execute request: plan:{:?} with task id:{:?}",
168 plan, task_id
169 );
170 let task = BatchTaskExecution::new(&task_id, plan, context, epoch, mgr.runtime())?;
171 let task = Arc::new(task);
172 let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
173 if let Err(e) = task
174 .clone()
175 .async_execute(None, tracing_context, expr_context)
176 .await
177 {
178 error!(
179 error = %e.as_report(),
180 ?task_id,
181 "failed to build executors and trigger execution"
182 );
183 return Err(e.into());
184 }
185
186 let pb_task_output_id = TaskOutputId {
187 task_id: Some(task_id.clone()),
188 output_id: 0,
191 };
192 let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
193 error!(
194 error = %e.as_report(),
195 ?task_id,
196 "failed to get task output in local execution mode",
197 );
198 })?;
199 let mut writer = GrpcExchangeWriter::new(tx.clone());
200 mgr.runtime().spawn(async move {
202 match output.take_data(&mut writer).await {
203 Ok(_) => Ok(()),
204 Err(e) => tx.send(Err(e.into())).await,
205 }
206 });
207 Ok(Response::new(ReceiverStream::new(rx)))
208 }
209
210 async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
211 let table_id = insert_req.table_id;
212 let wait_for_persistence = insert_req.wait_for_persistence;
213 let (executor, data_chunk) =
214 FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
215 let epoch = executor
216 .do_execute(data_chunk, wait_for_persistence)
217 .await?;
218 if wait_for_persistence {
219 dispatch_state_store!(self.env.state_store(), store, {
220 use risingwave_common::catalog::TableId;
221 use risingwave_hummock_sdk::HummockReadEpoch;
222 use risingwave_storage::StateStore;
223 use risingwave_storage::store::TryWaitEpochOptions;
224
225 store
226 .try_wait_epoch(
227 HummockReadEpoch::Committed(epoch.0),
228 TryWaitEpochOptions {
229 table_id: TableId::new(table_id),
230 },
231 )
232 .await
233 .map_err(BatchError::from)?;
234 });
235 }
236 Ok(())
237 }
238}