risingwave_batch/rpc/service/
task_service.rs1use std::future::Future;
16use std::sync::Arc;
17
18use anyhow::Context;
19use futures::stream::{FuturesOrdered, StreamExt};
20use risingwave_common::array::StreamChunk;
21use risingwave_common::util::tracing::TracingContext;
22use risingwave_dml::TableDmlHandleRef;
23use risingwave_dml::dml_manager::DmlManagerRef;
24use risingwave_dml::error::DmlError;
25use risingwave_pb::batch_plan::TaskOutputId;
26use risingwave_pb::task_service::task_service_server::TaskService;
27use risingwave_pb::task_service::{
28 CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
29 FastInsertResponse, GetDataResponse, IngestDmlAckResponse, IngestDmlInitRequest,
30 IngestDmlInitResponse, IngestDmlPayloadRequest, IngestDmlRequest, IngestDmlResponse,
31 TaskInfoResponse, fast_insert_response, ingest_dml_request, ingest_dml_response,
32};
33use thiserror_ext::AsReport;
34use tokio_stream::wrappers::ReceiverStream;
35use tonic::{Request, Response, Status};
36
37use crate::error::BatchError;
38use crate::executor::{FastInsertExecutor, inject_optional_row_id_column};
39use crate::rpc::service::exchange::GrpcExchangeWriter;
40use crate::task::{
41 BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
42 TASK_STATUS_BUFFER_SIZE,
43};
44
45#[derive(Clone)]
46pub struct BatchServiceImpl {
47 mgr: Arc<BatchManager>,
48 env: BatchEnvironment,
49}
50
51impl BatchServiceImpl {
52 pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
53 BatchServiceImpl { mgr, env }
54 }
55}
56
57pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
58pub type GetDataResponseResult = Result<GetDataResponse, Status>;
59pub type IngestDmlResponseResult = Result<IngestDmlResponse, Status>;
60
61#[async_trait::async_trait]
62impl TaskService for BatchServiceImpl {
63 type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
64 type ExecuteStream = ReceiverStream<GetDataResponseResult>;
65 type IngestDmlStream = ReceiverStream<IngestDmlResponseResult>;
66
67 async fn create_task(
68 &self,
69 request: Request<CreateTaskRequest>,
70 ) -> Result<Response<Self::CreateTaskStream>, Status> {
71 let CreateTaskRequest {
72 task_id,
73 plan,
74 tracing_context,
75 expr_context,
76 } = request.into_inner();
77
78 let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
79 let state_reporter = StateReporter::new_with_dist_sender(state_tx);
80 let res = self
81 .mgr
82 .fire_task(
83 task_id.as_ref().expect("no task id found"),
84 plan.expect("no plan found").clone(),
85 ComputeNodeContext::create(self.env.clone()),
86 state_reporter,
87 TracingContext::from_protobuf(&tracing_context),
88 expr_context.expect("no expression context found"),
89 )
90 .await;
91 match res {
92 Ok(_) => Ok(Response::new(ReceiverStream::new(
93 state_rx,
99 ))),
100 Err(e) => {
101 error!(error = %e.as_report(), "failed to fire task");
102 Err(e.into())
103 }
104 }
105 }
106
107 async fn cancel_task(
108 &self,
109 req: Request<CancelTaskRequest>,
110 ) -> Result<Response<CancelTaskResponse>, Status> {
111 let req = req.into_inner();
112 tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
113 self.mgr
114 .cancel_task(req.get_task_id().expect("no task id found"));
115 Ok(Response::new(CancelTaskResponse { status: None }))
116 }
117
118 async fn execute(
119 &self,
120 req: Request<ExecuteRequest>,
121 ) -> Result<Response<Self::ExecuteStream>, Status> {
122 let req = req.into_inner();
123 let env = self.env.clone();
124 let mgr = self.mgr.clone();
125 BatchServiceImpl::get_execute_stream(env, mgr, req).await
126 }
127
128 async fn fast_insert(
129 &self,
130 request: Request<FastInsertRequest>,
131 ) -> Result<Response<FastInsertResponse>, Status> {
132 let req = request.into_inner();
133 let res = self.do_fast_insert(req).await;
134 match res {
135 Ok(_) => Ok(Response::new(FastInsertResponse {
136 status: fast_insert_response::Status::Succeeded.into(),
137 error_message: "".to_owned(),
138 })),
139 Err(e) => match e {
140 BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
141 status: fast_insert_response::Status::DmlFailed.into(),
142 error_message: format!("{}", e.as_report()),
143 })),
144 _ => {
145 error!(error = %e.as_report(), "failed to fast insert");
146 Err(e.into())
147 }
148 },
149 }
150 }
151
152 async fn ingest_dml(
153 &self,
154 request: Request<tonic::Streaming<IngestDmlRequest>>,
155 ) -> Result<Response<Self::IngestDmlStream>, Status> {
156 let mut req_stream = request.into_inner();
157 let init = match req_stream.message().await? {
158 Some(req) => match req.request {
159 Some(ingest_dml_request::Request::Init(init)) => init,
160 Some(ingest_dml_request::Request::Payload(_)) => {
161 return Err(Status::invalid_argument(
162 "first ingest dml message must be init",
163 ));
164 }
165 None => return Err(Status::invalid_argument("empty ingest dml request")),
166 },
167 None => return Err(Status::invalid_argument("empty ingest dml stream")),
168 };
169
170 let (tx, rx) = tokio::sync::mpsc::channel(64);
171
172 let (table_dml_handle, request_id, row_id_index) = self.init_ingest_dml(&init)?;
173 let _ = tx.send(Ok(Self::ingest_dml_init_response())).await;
174
175 let dml_manager = self.env.dml_manager_ref();
176 tokio::spawn(async move {
177 let result: Result<(), String> = async {
178 let mut pending_acks = FuturesOrdered::new();
179
180 loop {
181 tokio::select! {
182 req = req_stream.message() => {
183 let req = req
184 .map_err(|err| format!("ingest dml stream read failed: {}", err.as_report()))?
185 .ok_or_else(|| "ingest dml stream closed unexpectedly".to_owned())?;
186 let payload = match req.request {
187 Some(ingest_dml_request::Request::Payload(payload)) => payload,
188 Some(ingest_dml_request::Request::Init(_)) | None => {
189 Err("unexpected non-payload request in ingest dml stream".to_owned())?
190 }
191 };
192
193 let dml_batch_id = payload.dml_batch_id;
194 let wait_fut = Self::do_ingest_dml_payload(
195 table_dml_handle.clone(),
196 dml_manager.clone(),
197 request_id,
198 row_id_index,
199 payload,
200 )
201 .await
202 .map_err(|err| format!("ingest dml batch {} failed: {}", dml_batch_id, err.as_report()))?;
203
204 pending_acks.push_back(async move { wait_fut.await.map(|()| dml_batch_id) });
205 }
206 ack = pending_acks.next(), if !pending_acks.is_empty() => {
207 let ack_dml_batch_id = ack
208 .expect("branch guarded by non-empty pending_acks")
209 .map_err(|err: DmlError| format!("ingest dml persistence failed: {}", err.as_report()))?;
210
211 if tx
212 .send(Ok(BatchServiceImpl::ingest_dml_ack_response(ack_dml_batch_id)))
213 .await
214 .is_err()
215 {
216 return Ok(());
217 }
218 }
219 }
220 }
221 }
222 .await;
223
224 if let Err(err) = result {
225 let _ = tx.send(Err(Status::internal(err))).await;
226 }
227 });
228
229 Ok(Response::new(ReceiverStream::new(rx)))
230 }
231}
232
233impl BatchServiceImpl {
234 async fn get_execute_stream(
235 env: BatchEnvironment,
236 mgr: Arc<BatchManager>,
237 req: ExecuteRequest,
238 ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
239 let ExecuteRequest {
240 task_id,
241 plan,
242 tracing_context,
243 expr_context,
244 } = req;
245
246 let task_id = task_id.expect("no task id found");
247 let plan = plan.expect("no plan found").clone();
248 let tracing_context = TracingContext::from_protobuf(&tracing_context);
249 let expr_context = expr_context.expect("no expression context found");
250
251 let context = ComputeNodeContext::create(env.clone());
252 trace!(
253 "local execute request: plan:{:?} with task id:{:?}",
254 plan, task_id
255 );
256 let task = BatchTaskExecution::new(&task_id, plan, context, mgr.runtime())?;
257 let task = Arc::new(task);
258 let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
259 if let Err(e) = task
260 .clone()
261 .async_execute(None, tracing_context, expr_context)
262 .await
263 {
264 error!(
265 error = %e.as_report(),
266 ?task_id,
267 "failed to build executors and trigger execution"
268 );
269 return Err(e.into());
270 }
271
272 let pb_task_output_id = TaskOutputId {
273 task_id: Some(task_id.clone()),
274 output_id: 0,
277 };
278 let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
279 error!(
280 error = %e.as_report(),
281 ?task_id,
282 "failed to get task output in local execution mode",
283 );
284 })?;
285 let mut writer = GrpcExchangeWriter::new(tx.clone());
286 mgr.runtime().spawn(async move {
288 match output.take_data(&mut writer).await {
289 Ok(_) => Ok(()),
290 Err(e) => tx.send(Err(e.into())).await,
291 }
292 });
293 Ok(Response::new(ReceiverStream::new(rx)))
294 }
295
296 async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
297 let wait_for_persistence = insert_req.wait_for_persistence;
298 let (executor, data_chunk) =
299 FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
300 executor
301 .do_execute(data_chunk, wait_for_persistence)
302 .await?;
303 Ok(())
304 }
305
306 fn init_ingest_dml(
307 &self,
308 init: &IngestDmlInitRequest,
309 ) -> Result<(TableDmlHandleRef, u32, Option<u32>), Status> {
310 let table_id = init.table_id;
311 let table_version_id = init.table_version_id;
312 let table_dml_handle = self
313 .env
314 .dml_manager_ref()
315 .table_dml_handle(table_id, table_version_id)
316 .map_err(|err| Status::internal(format!("{}", err.as_report())))?;
317 Ok((table_dml_handle, init.request_id, init.row_id_index))
318 }
319
320 async fn do_ingest_dml_payload(
321 table_dml_handle: TableDmlHandleRef,
322 dml_manager: DmlManagerRef,
323 request_id: u32,
324 row_id_index: Option<u32>,
325 payload: IngestDmlPayloadRequest,
326 ) -> Result<impl Future<Output = risingwave_dml::error::Result<()>> + Send + 'static, BatchError>
327 {
328 let pb_chunk = payload.chunk.ok_or_else(|| {
329 BatchError::Internal(anyhow::anyhow!("no chunk in IngestDmlPayloadRequest"))
330 })?;
331 let mut chunk = StreamChunk::from_protobuf(&pb_chunk)
332 .context("failed to decode chunk")
333 .map_err(BatchError::Internal)?;
334 chunk = inject_optional_row_id_column(chunk, row_id_index.map(|index| index as usize));
335 let txn_id = dml_manager.gen_txn_id();
336 let mut write_handle = table_dml_handle
337 .write_handle(request_id, txn_id)
338 .map_err(BatchError::Dml)?;
339
340 write_handle.begin().map_err(BatchError::Dml)?;
341 write_handle
342 .write_chunk(chunk)
343 .await
344 .map_err(BatchError::Dml)?;
345 let persistence_future = write_handle
346 .end_wait_persistence()
347 .map_err(BatchError::Dml)?;
348 Ok(persistence_future)
349 }
350
351 fn ingest_dml_init_response() -> IngestDmlResponse {
352 IngestDmlResponse {
353 response: Some(ingest_dml_response::Response::Init(
354 IngestDmlInitResponse {},
355 )),
356 }
357 }
358
359 fn ingest_dml_ack_response(dml_batch_id: u64) -> IngestDmlResponse {
360 IngestDmlResponse {
361 response: Some(ingest_dml_response::Response::Ack(IngestDmlAckResponse {
362 dml_batch_id,
363 })),
364 }
365 }
366}