risingwave_batch/rpc/service/
task_service.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::sync::Arc;
17
18use anyhow::Context;
19use futures::stream::{FuturesOrdered, StreamExt};
20use risingwave_common::array::StreamChunk;
21use risingwave_common::util::tracing::TracingContext;
22use risingwave_dml::TableDmlHandleRef;
23use risingwave_dml::dml_manager::DmlManagerRef;
24use risingwave_dml::error::DmlError;
25use risingwave_pb::batch_plan::TaskOutputId;
26use risingwave_pb::task_service::task_service_server::TaskService;
27use risingwave_pb::task_service::{
28    CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
29    FastInsertResponse, GetDataResponse, IngestDmlAckResponse, IngestDmlInitRequest,
30    IngestDmlInitResponse, IngestDmlPayloadRequest, IngestDmlRequest, IngestDmlResponse,
31    TaskInfoResponse, fast_insert_response, ingest_dml_request, ingest_dml_response,
32};
33use thiserror_ext::AsReport;
34use tokio_stream::wrappers::ReceiverStream;
35use tonic::{Request, Response, Status};
36
37use crate::error::BatchError;
38use crate::executor::{FastInsertExecutor, inject_optional_row_id_column};
39use crate::rpc::service::exchange::GrpcExchangeWriter;
40use crate::task::{
41    BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter,
42    TASK_STATUS_BUFFER_SIZE,
43};
44
/// gRPC service implementing [`TaskService`] for batch query execution on a
/// compute node: task creation/cancellation, local execution, and DML ingestion.
#[derive(Clone)]
pub struct BatchServiceImpl {
    // Owns and tracks running batch tasks on this node (fire/cancel, runtime, config).
    mgr: Arc<BatchManager>,
    // Shared environment; used to create per-task `ComputeNodeContext`s and to
    // reach the DML manager for insert/ingest paths.
    env: BatchEnvironment,
}
50
51impl BatchServiceImpl {
52    pub fn new(mgr: Arc<BatchManager>, env: BatchEnvironment) -> Self {
53        BatchServiceImpl { mgr, env }
54    }
55}
56
/// Item type of the task-status stream returned by `create_task`.
pub type TaskInfoResponseResult = Result<TaskInfoResponse, Status>;
/// Item type of the data stream returned by `execute`.
pub type GetDataResponseResult = Result<GetDataResponse, Status>;
/// Item type of the response stream returned by `ingest_dml`.
pub type IngestDmlResponseResult = Result<IngestDmlResponse, Status>;
60
61#[async_trait::async_trait]
62impl TaskService for BatchServiceImpl {
63    type CreateTaskStream = ReceiverStream<TaskInfoResponseResult>;
64    type ExecuteStream = ReceiverStream<GetDataResponseResult>;
65    type IngestDmlStream = ReceiverStream<IngestDmlResponseResult>;
66
67    async fn create_task(
68        &self,
69        request: Request<CreateTaskRequest>,
70    ) -> Result<Response<Self::CreateTaskStream>, Status> {
71        let CreateTaskRequest {
72            task_id,
73            plan,
74            tracing_context,
75            expr_context,
76        } = request.into_inner();
77
78        let (state_tx, state_rx) = tokio::sync::mpsc::channel(TASK_STATUS_BUFFER_SIZE);
79        let state_reporter = StateReporter::new_with_dist_sender(state_tx);
80        let res = self
81            .mgr
82            .fire_task(
83                task_id.as_ref().expect("no task id found"),
84                plan.expect("no plan found").clone(),
85                ComputeNodeContext::create(self.env.clone()),
86                state_reporter,
87                TracingContext::from_protobuf(&tracing_context),
88                expr_context.expect("no expression context found"),
89            )
90            .await;
91        match res {
92            Ok(_) => Ok(Response::new(ReceiverStream::new(
93                // Create receiver stream from state receiver.
94                // The state receiver is init in `.async_execute()`.
95                // Will be used for receive task status update.
96                // Note: we introduce this hack cuz `.execute()` do not produce a status stream,
97                // but still share `.async_execute()` and `.try_execute()`.
98                state_rx,
99            ))),
100            Err(e) => {
101                error!(error = %e.as_report(), "failed to fire task");
102                Err(e.into())
103            }
104        }
105    }
106
107    async fn cancel_task(
108        &self,
109        req: Request<CancelTaskRequest>,
110    ) -> Result<Response<CancelTaskResponse>, Status> {
111        let req = req.into_inner();
112        tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap());
113        self.mgr
114            .cancel_task(req.get_task_id().expect("no task id found"));
115        Ok(Response::new(CancelTaskResponse { status: None }))
116    }
117
118    async fn execute(
119        &self,
120        req: Request<ExecuteRequest>,
121    ) -> Result<Response<Self::ExecuteStream>, Status> {
122        let req = req.into_inner();
123        let env = self.env.clone();
124        let mgr = self.mgr.clone();
125        BatchServiceImpl::get_execute_stream(env, mgr, req).await
126    }
127
128    async fn fast_insert(
129        &self,
130        request: Request<FastInsertRequest>,
131    ) -> Result<Response<FastInsertResponse>, Status> {
132        let req = request.into_inner();
133        let res = self.do_fast_insert(req).await;
134        match res {
135            Ok(_) => Ok(Response::new(FastInsertResponse {
136                status: fast_insert_response::Status::Succeeded.into(),
137                error_message: "".to_owned(),
138            })),
139            Err(e) => match e {
140                BatchError::Dml(e) => Ok(Response::new(FastInsertResponse {
141                    status: fast_insert_response::Status::DmlFailed.into(),
142                    error_message: format!("{}", e.as_report()),
143                })),
144                _ => {
145                    error!(error = %e.as_report(), "failed to fast insert");
146                    Err(e.into())
147                }
148            },
149        }
150    }
151
152    async fn ingest_dml(
153        &self,
154        request: Request<tonic::Streaming<IngestDmlRequest>>,
155    ) -> Result<Response<Self::IngestDmlStream>, Status> {
156        let mut req_stream = request.into_inner();
157        let init = match req_stream.message().await? {
158            Some(req) => match req.request {
159                Some(ingest_dml_request::Request::Init(init)) => init,
160                Some(ingest_dml_request::Request::Payload(_)) => {
161                    return Err(Status::invalid_argument(
162                        "first ingest dml message must be init",
163                    ));
164                }
165                None => return Err(Status::invalid_argument("empty ingest dml request")),
166            },
167            None => return Err(Status::invalid_argument("empty ingest dml stream")),
168        };
169
170        let (tx, rx) = tokio::sync::mpsc::channel(64);
171
172        let (table_dml_handle, request_id, row_id_index) = self.init_ingest_dml(&init)?;
173        let _ = tx.send(Ok(Self::ingest_dml_init_response())).await;
174
175        let dml_manager = self.env.dml_manager_ref();
176        tokio::spawn(async move {
177            let result: Result<(), String> = async {
178                let mut pending_acks = FuturesOrdered::new();
179
180                loop {
181                    tokio::select! {
182                        req = req_stream.message() => {
183                            let req = req
184                                .map_err(|err| format!("ingest dml stream read failed: {}", err.as_report()))?
185                                .ok_or_else(|| "ingest dml stream closed unexpectedly".to_owned())?;
186                            let payload = match req.request {
187                                Some(ingest_dml_request::Request::Payload(payload)) => payload,
188                                Some(ingest_dml_request::Request::Init(_)) | None => {
189                                    Err("unexpected non-payload request in ingest dml stream".to_owned())?
190                                }
191                            };
192
193                            let dml_batch_id = payload.dml_batch_id;
194                            let wait_fut = Self::do_ingest_dml_payload(
195                                table_dml_handle.clone(),
196                                dml_manager.clone(),
197                                request_id,
198                                row_id_index,
199                                payload,
200                            )
201                            .await
202                            .map_err(|err| format!("ingest dml batch {} failed: {}", dml_batch_id, err.as_report()))?;
203
204                            pending_acks.push_back(async move { wait_fut.await.map(|()| dml_batch_id) });
205                        }
206                        ack = pending_acks.next(), if !pending_acks.is_empty() => {
207                            let ack_dml_batch_id = ack
208                                .expect("branch guarded by non-empty pending_acks")
209                                .map_err(|err: DmlError| format!("ingest dml persistence failed: {}", err.as_report()))?;
210
211                            if tx
212                                .send(Ok(BatchServiceImpl::ingest_dml_ack_response(ack_dml_batch_id)))
213                                .await
214                                .is_err()
215                            {
216                                return Ok(());
217                            }
218                        }
219                    }
220                }
221            }
222            .await;
223
224            if let Err(err) = result {
225                let _ = tx.send(Err(Status::internal(err))).await;
226            }
227        });
228
229        Ok(Response::new(ReceiverStream::new(rx)))
230    }
231}
232
233impl BatchServiceImpl {
234    async fn get_execute_stream(
235        env: BatchEnvironment,
236        mgr: Arc<BatchManager>,
237        req: ExecuteRequest,
238    ) -> Result<Response<ReceiverStream<GetDataResponseResult>>, Status> {
239        let ExecuteRequest {
240            task_id,
241            plan,
242            tracing_context,
243            expr_context,
244        } = req;
245
246        let task_id = task_id.expect("no task id found");
247        let plan = plan.expect("no plan found").clone();
248        let tracing_context = TracingContext::from_protobuf(&tracing_context);
249        let expr_context = expr_context.expect("no expression context found");
250
251        let context = ComputeNodeContext::create(env.clone());
252        trace!(
253            "local execute request: plan:{:?} with task id:{:?}",
254            plan, task_id
255        );
256        let task = BatchTaskExecution::new(&task_id, plan, context, mgr.runtime())?;
257        let task = Arc::new(task);
258        let (tx, rx) = tokio::sync::mpsc::channel(mgr.config().developer.local_execute_buffer_size);
259        if let Err(e) = task
260            .clone()
261            .async_execute(None, tracing_context, expr_context)
262            .await
263        {
264            error!(
265                error = %e.as_report(),
266                ?task_id,
267                "failed to build executors and trigger execution"
268            );
269            return Err(e.into());
270        }
271
272        let pb_task_output_id = TaskOutputId {
273            task_id: Some(task_id.clone()),
274            // Since this is local execution path, the exchange would follow single distribution,
275            // therefore we would only have one data output.
276            output_id: 0,
277        };
278        let mut output = task.get_task_output(&pb_task_output_id).inspect_err(|e| {
279            error!(
280                error = %e.as_report(),
281                ?task_id,
282                "failed to get task output in local execution mode",
283            );
284        })?;
285        let mut writer = GrpcExchangeWriter::new(tx.clone());
286        // Always spawn a task and do not block current function.
287        mgr.runtime().spawn(async move {
288            match output.take_data(&mut writer).await {
289                Ok(_) => Ok(()),
290                Err(e) => tx.send(Err(e.into())).await,
291            }
292        });
293        Ok(Response::new(ReceiverStream::new(rx)))
294    }
295
296    async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
297        let wait_for_persistence = insert_req.wait_for_persistence;
298        let (executor, data_chunk) =
299            FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
300        executor
301            .do_execute(data_chunk, wait_for_persistence)
302            .await?;
303        Ok(())
304    }
305
306    fn init_ingest_dml(
307        &self,
308        init: &IngestDmlInitRequest,
309    ) -> Result<(TableDmlHandleRef, u32, Option<u32>), Status> {
310        let table_id = init.table_id;
311        let table_version_id = init.table_version_id;
312        let table_dml_handle = self
313            .env
314            .dml_manager_ref()
315            .table_dml_handle(table_id, table_version_id)
316            .map_err(|err| Status::internal(format!("{}", err.as_report())))?;
317        Ok((table_dml_handle, init.request_id, init.row_id_index))
318    }
319
320    async fn do_ingest_dml_payload(
321        table_dml_handle: TableDmlHandleRef,
322        dml_manager: DmlManagerRef,
323        request_id: u32,
324        row_id_index: Option<u32>,
325        payload: IngestDmlPayloadRequest,
326    ) -> Result<impl Future<Output = risingwave_dml::error::Result<()>> + Send + 'static, BatchError>
327    {
328        let pb_chunk = payload.chunk.ok_or_else(|| {
329            BatchError::Internal(anyhow::anyhow!("no chunk in IngestDmlPayloadRequest"))
330        })?;
331        let mut chunk = StreamChunk::from_protobuf(&pb_chunk)
332            .context("failed to decode chunk")
333            .map_err(BatchError::Internal)?;
334        chunk = inject_optional_row_id_column(chunk, row_id_index.map(|index| index as usize));
335        let txn_id = dml_manager.gen_txn_id();
336        let mut write_handle = table_dml_handle
337            .write_handle(request_id, txn_id)
338            .map_err(BatchError::Dml)?;
339
340        write_handle.begin().map_err(BatchError::Dml)?;
341        write_handle
342            .write_chunk(chunk)
343            .await
344            .map_err(BatchError::Dml)?;
345        let persistence_future = write_handle
346            .end_wait_persistence()
347            .map_err(BatchError::Dml)?;
348        Ok(persistence_future)
349    }
350
351    fn ingest_dml_init_response() -> IngestDmlResponse {
352        IngestDmlResponse {
353            response: Some(ingest_dml_response::Response::Init(
354                IngestDmlInitResponse {},
355            )),
356        }
357    }
358
359    fn ingest_dml_ack_response(dml_batch_id: u64) -> IngestDmlResponse {
360        IngestDmlResponse {
361            response: Some(ingest_dml_response::Response::Ack(IngestDmlAckResponse {
362                dml_batch_id,
363            })),
364        }
365    }
366}