// risingwave_batch/task/task_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::{HashMap, hash_map};
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use anyhow::Context;
20use parking_lot::Mutex;
21use risingwave_common::config::BatchConfig;
22use risingwave_common::memory::MemoryContext;
23use risingwave_common::util::runtime::BackgroundShutdownRuntime;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
26use risingwave_pb::common::BatchQueryEpoch;
27use risingwave_pb::plan_common::ExprContext;
28use risingwave_pb::task_service::task_info_response::TaskStatus;
29use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
30use tokio::sync::mpsc::Sender;
31use tonic::Status;
32
33use super::BatchTaskContext;
34use crate::error::Result;
35use crate::monitor::BatchManagerMetrics;
36use crate::rpc::service::exchange::GrpcExchangeWriter;
37use crate::task::{BatchTaskExecution, StateReporter, TaskId, TaskOutput, TaskOutputId};
38
39/// `BatchManager` is responsible for managing all batch tasks.
40#[derive(Clone)]
41pub struct BatchManager {
42    /// Every task id has a corresponding task execution.
43    tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution>>>>,
44
45    /// Runtime for the batch manager.
46    runtime: Arc<BackgroundShutdownRuntime>,
47
48    /// Batch configuration
49    config: BatchConfig,
50
51    /// Memory context used for batch tasks in cn.
52    mem_context: MemoryContext,
53
54    /// Metrics for batch manager.
55    metrics: Arc<BatchManagerMetrics>,
56}
57
58impl BatchManager {
59    pub fn new(config: BatchConfig, metrics: Arc<BatchManagerMetrics>, mem_limit: u64) -> Self {
60        let runtime = {
61            let mut builder = tokio::runtime::Builder::new_multi_thread();
62            if let Some(worker_threads_num) = config.worker_threads_num {
63                builder.worker_threads(worker_threads_num);
64            }
65            builder
66                .thread_name("rw-batch")
67                .enable_all()
68                .build()
69                .unwrap()
70        };
71
72        let mem_context = MemoryContext::root(metrics.batch_total_mem.clone(), mem_limit);
73        BatchManager {
74            tasks: Arc::new(Mutex::new(HashMap::new())),
75            runtime: Arc::new(runtime.into()),
76            config,
77            metrics,
78            mem_context,
79        }
80    }
81
82    pub(crate) fn metrics(&self) -> Arc<BatchManagerMetrics> {
83        self.metrics.clone()
84    }
85
86    pub fn memory_context_ref(&self) -> MemoryContext {
87        self.mem_context.clone()
88    }
89
90    pub async fn fire_task(
91        self: &Arc<Self>,
92        tid: &PbTaskId,
93        plan: PlanFragment,
94        epoch: BatchQueryEpoch,
95        context: Arc<dyn BatchTaskContext>, // ComputeNodeContext
96        state_reporter: StateReporter,
97        tracing_context: TracingContext,
98        expr_context: ExprContext,
99    ) -> Result<()> {
100        trace!("Received task id: {:?}, plan: {:?}", tid, plan);
101        let task = BatchTaskExecution::new(tid, plan, context, epoch, self.runtime())?;
102        let task_id = task.get_task_id().clone();
103        let task = Arc::new(task);
104        // Here the task id insert into self.tasks is put in front of `.async_execute`, cuz when
105        // send `TaskStatus::Running` in `.async_execute`, the query runner may schedule next stage,
106        // it's possible do not found parent task id in theory.
107        let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
108            e.insert(task.clone());
109
110            let this = self.clone();
111            let task_id = task_id.clone();
112            let state_reporter = state_reporter.clone();
113            let heartbeat_join_handle = self.runtime.spawn(async move {
114                this.start_task_heartbeat(state_reporter, task_id).await;
115            });
116            task.set_heartbeat_join_handle(heartbeat_join_handle);
117
118            Ok(())
119        } else {
120            bail!(
121                "can not create duplicate task with the same id: {:?}",
122                task_id,
123            );
124        };
125        task.async_execute(Some(state_reporter), tracing_context, expr_context)
126            .await
127            .inspect_err(|_| {
128                self.cancel_task(&task_id.to_prost());
129            })?;
130        ret
131    }
132
133    #[cfg(test)]
134    async fn fire_task_for_test(
135        self: &Arc<Self>,
136        tid: &PbTaskId,
137        plan: PlanFragment,
138    ) -> Result<()> {
139        use risingwave_hummock_sdk::test_batch_query_epoch;
140
141        use crate::task::ComputeNodeContext;
142
143        self.fire_task(
144            tid,
145            plan,
146            test_batch_query_epoch(),
147            ComputeNodeContext::for_test(),
148            StateReporter::new_with_test(),
149            TracingContext::none(),
150            ExprContext {
151                time_zone: "UTC".to_owned(),
152                strict_mode: false,
153            },
154        )
155        .await
156    }
157
158    async fn start_task_heartbeat(&self, mut state_reporter: StateReporter, task_id: TaskId) {
159        let _metric_guard = scopeguard::guard((), |_| {
160            tracing::debug!("heartbeat worker for task {:?} stopped", task_id);
161            self.metrics.batch_heartbeat_worker_num.dec();
162        });
163        tracing::debug!("heartbeat worker for task {:?} started", task_id);
164        self.metrics.batch_heartbeat_worker_num.inc();
165        // The heartbeat is to ensure task cancellation when frontend's cancellation request fails
166        // to reach compute node (for any reason like RPC fails, frontend crashes).
167        let mut heartbeat_interval = tokio::time::interval(core::time::Duration::from_secs(60));
168        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
169        heartbeat_interval.reset();
170        loop {
171            heartbeat_interval.tick().await;
172            if !self.tasks.lock().contains_key(&task_id) {
173                break;
174            }
175            if state_reporter
176                .send(TaskInfoResponse {
177                    task_id: Some(task_id.to_prost()),
178                    task_status: TaskStatus::Ping.into(),
179                    error_message: "".to_owned(),
180                })
181                .await
182                .is_err()
183            {
184                tracing::warn!("try to cancel task {:?} due to heartbeat", task_id);
185                // Task may have been cancelled, but it's fine to `cancel_task` again.
186                self.cancel_task(&task_id.to_prost());
187                break;
188            }
189        }
190    }
191
192    pub fn get_data(
193        &self,
194        tx: Sender<std::result::Result<GetDataResponse, Status>>,
195        peer_addr: SocketAddr,
196        pb_task_output_id: &PbTaskOutputId,
197    ) -> Result<()> {
198        let task_id = TaskOutputId::try_from(pb_task_output_id)?;
199        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, from = ?task_id, "serve exchange RPC");
200        let mut task_output = self.take_output(pb_task_output_id)?;
201        self.runtime.spawn(async move {
202            let mut writer = GrpcExchangeWriter::new(tx.clone());
203            match task_output.take_data(&mut writer).await {
204                Ok(_) => {
205                    tracing::trace!(
206                        from = ?task_id,
207                        "exchanged {} chunks",
208                        writer.written_chunks(),
209                    );
210                    Ok(())
211                }
212                Err(e) => tx.send(Err(e.into())).await,
213            }
214        });
215        Ok(())
216    }
217
218    pub fn take_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
219        let task_id = TaskId::from(output_id.get_task_id()?);
220        self.tasks
221            .lock()
222            .get(&task_id)
223            .with_context(|| format!("task {:?} not found", task_id))?
224            .get_task_output(output_id)
225    }
226
227    pub fn cancel_task(&self, sid: &PbTaskId) {
228        let sid = TaskId::from(sid);
229        match self.tasks.lock().remove(&sid) {
230            Some(task) => {
231                tracing::trace!("Removed task: {:?}", task.get_task_id());
232                // Use `cancel` rather than `abort` here since this is not an error which should be
233                // propagated to upstream.
234                task.cancel();
235                if let Some(heartbeat_join_handle) = task.heartbeat_join_handle() {
236                    heartbeat_join_handle.abort();
237                }
238            }
239            None => {
240                warn!("Task {:?} not found for cancel", sid)
241            }
242        };
243    }
244
245    /// Returns error if task is not running.
246    pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
247        match self.tasks.lock().get(task_id) {
248            Some(task) => task.check_if_running(),
249            None => bail!("task {:?} not found", task_id),
250        }
251    }
252
253    pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
254        match self.tasks.lock().get(task_id) {
255            Some(task) => task.check_if_aborted(),
256            None => bail!("task {:?} not found", task_id),
257        }
258    }
259
260    #[cfg(test)]
261    async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
262        use std::time::Duration;
263        loop {
264            match self.tasks.lock().get(task_id) {
265                Some(task) => {
266                    let ret = task.check_if_aborted();
267                    match ret {
268                        Ok(true) => return Ok(()),
269                        Ok(false) => {}
270                        Err(err) => return Err(err),
271                    }
272                }
273                None => bail!("task {:?} not found", task_id),
274            }
275            tokio::time::sleep(Duration::from_millis(100)).await
276        }
277    }
278
279    pub fn runtime(&self) -> Arc<BackgroundShutdownRuntime> {
280        self.runtime.clone()
281    }
282
283    pub fn config(&self) -> &BatchConfig {
284        &self.config
285    }
286}
287
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use risingwave_common::config::BatchConfig;
    use risingwave_pb::batch_plan::exchange_info::DistributionMode;
    use risingwave_pb::batch_plan::plan_node::NodeBody;
    use risingwave_pb::batch_plan::{
        ExchangeInfo, PbTaskId, PbTaskOutputId, PlanFragment, PlanNode,
    };

    use crate::monitor::BatchManagerMetrics;
    use crate::task::{BatchManager, TaskId};

    /// A manager with default config, test metrics and an unbounded memory limit.
    fn test_manager() -> Arc<BatchManager> {
        Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ))
    }

    /// A one-node plan whose root spins in a `BusyLoopExecutor`, with single-mode output.
    fn busy_loop_plan() -> PlanFragment {
        let root = PlanNode {
            children: vec![],
            identity: "".to_owned(),
            node_body: Some(NodeBody::BusyLoopExecutor(true)),
        };
        let exchange = ExchangeInfo {
            mode: DistributionMode::Single as i32,
            distribution: None,
        };
        PlanFragment {
            root: Some(root),
            exchange_info: Some(exchange),
        }
    }

    /// The protobuf id shared by the busy-loop tests: empty query, stage 0, task 0.
    fn busy_loop_task_id() -> PbTaskId {
        PbTaskId {
            query_id: "".to_owned(),
            stage_id: 0,
            task_id: 0,
        }
    }

    #[tokio::test]
    async fn test_task_not_found() {
        let manager = test_manager();

        let missing_task = TaskId {
            task_id: 0,
            stage_id: 0,
            query_id: "abc".to_owned(),
        };
        let running_err = manager.check_if_task_running(&missing_task).unwrap_err();
        assert!(running_err.to_string().contains("not found"), "{:?}", running_err);

        let missing_output = PbTaskOutputId {
            task_id: Some(PbTaskId {
                stage_id: 0,
                task_id: 0,
                query_id: "".to_owned(),
            }),
            output_id: 0,
        };
        let output_err = manager.take_output(&missing_output).unwrap_err();
        assert!(output_err.to_string().contains("not found"), "{:?}", output_err);
    }

    #[tokio::test]
    // see https://github.com/risingwavelabs/risingwave/issues/11979
    #[ignore]
    async fn test_task_cancel_for_busy_loop() {
        let manager = test_manager();
        let pb_task_id = busy_loop_task_id();

        manager
            .fire_task_for_test(&pb_task_id, busy_loop_plan())
            .await
            .unwrap();
        manager.cancel_task(&pb_task_id);

        let still_registered = manager
            .tasks
            .lock()
            .contains_key(&TaskId::from(&pb_task_id));
        assert!(!still_registered);
    }

    #[tokio::test]
    // see https://github.com/risingwavelabs/risingwave/issues/11979
    #[ignore]
    async fn test_task_abort_for_busy_loop() {
        let manager = test_manager();
        let pb_task_id = busy_loop_task_id();

        manager
            .fire_task_for_test(&pb_task_id, busy_loop_plan())
            .await
            .unwrap();

        let task_id = TaskId::from(&pb_task_id);
        {
            let tasks = manager.tasks.lock();
            let task = tasks.get(&task_id).unwrap();
            task.abort("Abort Test".to_owned());
        }
        assert!(manager.wait_until_task_aborted(&task_id).await.is_ok());
    }
}