risingwave_batch/task/task_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, hash_map};
use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::Context;
use parking_lot::Mutex;
use risingwave_common::config::BatchConfig;
use risingwave_common::memory::MemoryContext;
use risingwave_common::util::runtime::BackgroundShutdownRuntime;
use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::task_info_response::TaskStatus;
use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
use tokio::sync::mpsc::Sender;
use tonic::Status;

use super::BatchTaskContext;
use crate::error::Result;
use crate::monitor::BatchManagerMetrics;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::task::{BatchTaskExecution, StateReporter, TaskId, TaskOutput, TaskOutputId};

/// `BatchManager` is responsible for managing all batch tasks.
#[derive(Clone)]
pub struct BatchManager {
    /// Every task id has a corresponding task execution.
    tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution>>>>,

    /// Runtime for the batch manager.
    runtime: Arc<BackgroundShutdownRuntime>,

    /// Batch configuration.
    config: BatchConfig,

    /// Memory context used for batch tasks on the compute node.
    mem_context: MemoryContext,

    /// Metrics for the batch manager.
    metrics: Arc<BatchManagerMetrics>,
}

impl BatchManager {
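    /// Creates a [`BatchManager`] backed by a dedicated multi-threaded Tokio runtime and a
    /// root [`MemoryContext`] with `mem_limit` as its byte quota.
    ///
    /// A minimal construction sketch, mirroring the tests at the bottom of this file
    /// (`BatchManagerMetrics::for_test` is a test-only helper):
    ///
    /// ```ignore
    /// let manager = Arc::new(BatchManager::new(
    ///     BatchConfig::default(),
    ///     BatchManagerMetrics::for_test(),
    ///     u64::MAX, // effectively unlimited memory
    /// ));
    /// ```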
    pub fn new(config: BatchConfig, metrics: Arc<BatchManagerMetrics>, mem_limit: u64) -> Self {
        let runtime = {
            let mut builder = tokio::runtime::Builder::new_multi_thread();
            if let Some(worker_threads_num) = config.worker_threads_num {
                builder.worker_threads(worker_threads_num);
            }
            builder
                .thread_name("rw-batch")
                .enable_all()
                .build()
                .unwrap()
        };

        let mem_context = MemoryContext::root(metrics.batch_total_mem.clone(), mem_limit);
        BatchManager {
            tasks: Arc::new(Mutex::new(HashMap::new())),
            runtime: Arc::new(runtime.into()),
            config,
            metrics,
            mem_context,
        }
    }

    pub(crate) fn metrics(&self) -> Arc<BatchManagerMetrics> {
        self.metrics.clone()
    }

    pub fn memory_context_ref(&self) -> MemoryContext {
        self.mem_context.clone()
    }

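    /// Creates a [`BatchTaskExecution`] from `plan`, registers it in the task map, spawns a
    /// heartbeat worker for it, and then starts its execution.
    ///
    /// Fails if a task with the same id is already registered; if execution itself fails, the
    /// task is cancelled and removed from the map.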
    pub async fn fire_task(
        self: &Arc<Self>,
        tid: &PbTaskId,
        plan: PlanFragment,
        context: Arc<dyn BatchTaskContext>, // ComputeNodeContext
        state_reporter: StateReporter,
        tracing_context: TracingContext,
        expr_context: ExprContext,
    ) -> Result<()> {
        trace!("Received task id: {:?}, plan: {:?}", tid, plan);
        let task = BatchTaskExecution::new(tid, plan, context, self.runtime())?;
        let task_id = task.get_task_id().clone();
        let task = Arc::new(task);
        // Insert the task id into `self.tasks` before calling `.async_execute`: once
        // `TaskStatus::Running` is sent there, the query runner may schedule the next stage,
        // which could otherwise fail to find the parent task id.
        let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
            e.insert(task.clone());

            let this = self.clone();
            let task_id = task_id.clone();
            let state_reporter = state_reporter.clone();
            let heartbeat_join_handle = self.runtime.spawn(async move {
                this.start_task_heartbeat(state_reporter, task_id).await;
            });
            task.set_heartbeat_join_handle(heartbeat_join_handle);

            Ok(())
        } else {
            bail!(
                "cannot create a duplicate task with the same id: {:?}",
                task_id,
            );
        };
        task.async_execute(Some(state_reporter), tracing_context, expr_context)
            .await
            .inspect_err(|_| {
                self.cancel_task(&task_id.to_prost());
            })?;
        ret
    }

    #[cfg(test)]
    async fn fire_task_for_test(
        self: &Arc<Self>,
        tid: &PbTaskId,
        plan: PlanFragment,
    ) -> Result<()> {
        use crate::task::ComputeNodeContext;

        self.fire_task(
            tid,
            plan,
            ComputeNodeContext::for_test(),
            StateReporter::new_with_test(),
            TracingContext::none(),
            ExprContext {
                time_zone: "UTC".to_owned(),
                strict_mode: false,
            },
        )
        .await
    }

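    /// Periodically reports [`TaskStatus::Ping`] for `task_id` through `state_reporter`.
    /// The worker exits once the task leaves the task map; if a ping cannot be delivered
    /// (e.g. the frontend is gone), the task is cancelled.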
    async fn start_task_heartbeat(&self, mut state_reporter: StateReporter, task_id: TaskId) {
        let _metric_guard = scopeguard::guard((), |_| {
            tracing::debug!("heartbeat worker for task {:?} stopped", task_id);
            self.metrics.batch_heartbeat_worker_num.dec();
        });
        tracing::debug!("heartbeat worker for task {:?} started", task_id);
        self.metrics.batch_heartbeat_worker_num.inc();
        // The heartbeat ensures the task is cancelled even when the frontend's cancellation
        // request fails to reach the compute node (e.g. the RPC fails or the frontend crashes).
        let mut heartbeat_interval = tokio::time::interval(core::time::Duration::from_secs(60));
        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        heartbeat_interval.reset();
        loop {
            heartbeat_interval.tick().await;
            if !self.tasks.lock().contains_key(&task_id) {
                break;
            }
            if state_reporter
                .send(TaskInfoResponse {
                    task_id: Some(task_id.to_prost()),
                    task_status: TaskStatus::Ping.into(),
                    error_message: "".to_owned(),
                })
                .await
                .is_err()
            {
                tracing::warn!("cancelling task {:?} because its heartbeat failed", task_id);
                // The task may have already been cancelled, but it's fine to `cancel_task` again.
                self.cancel_task(&task_id.to_prost());
                break;
            }
        }
    }

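    /// Serves an exchange RPC: takes the requested task output and spawns a job on the batch
    /// runtime that streams its data chunks to `tx`.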
    pub fn get_data(
        &self,
        tx: Sender<std::result::Result<GetDataResponse, Status>>,
        peer_addr: SocketAddr,
        pb_task_output_id: &PbTaskOutputId,
    ) -> Result<()> {
        let task_id = TaskOutputId::try_from(pb_task_output_id)?;
        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, from = ?task_id, "serve exchange RPC");
        let mut task_output = self.take_output(pb_task_output_id)?;
        self.runtime.spawn(async move {
            let mut writer = GrpcExchangeWriter::new(tx.clone());
            match task_output.take_data(&mut writer).await {
                Ok(_) => {
                    tracing::trace!(
                        from = ?task_id,
                        "exchanged {} chunks",
                        writer.written_chunks(),
                    );
                    Ok(())
                }
                Err(e) => tx.send(Err(e.into())).await,
            }
        });
        Ok(())
    }

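    /// Takes an output of the task identified by `output_id`, erroring if the task is not found.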
    pub fn take_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
        let task_id = TaskId::from(output_id.get_task_id()?);
        self.tasks
            .lock()
            .get(&task_id)
            .with_context(|| format!("task {:?} not found", task_id))?
            .get_task_output(output_id)
    }

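    /// Removes the task from the task map and cancels it, aborting its heartbeat worker along
    /// the way. Only logs a warning if the task is not found, since it may have already
    /// finished or been cancelled.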
    pub fn cancel_task(&self, sid: &PbTaskId) {
        let sid = TaskId::from(sid);
        match self.tasks.lock().remove(&sid) {
            Some(task) => {
                tracing::trace!("Removed task: {:?}", task.get_task_id());
                // Use `cancel` rather than `abort` here, since this is not an error that should
                // be propagated upstream.
                task.cancel();
                if let Some(heartbeat_join_handle) = task.heartbeat_join_handle() {
                    heartbeat_join_handle.abort();
                }
            }
            None => {
                warn!("Task {:?} not found for cancel", sid)
            }
        };
    }

    /// Returns an error if the task is not found or not running.
    pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
        match self.tasks.lock().get(task_id) {
            Some(task) => task.check_if_running(),
            None => bail!("task {:?} not found", task_id),
        }
    }

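    /// Returns whether the task has been aborted, or an error if the task is not found.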
    pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
        match self.tasks.lock().get(task_id) {
            Some(task) => task.check_if_aborted(),
            None => bail!("task {:?} not found", task_id),
        }
    }

    #[cfg(test)]
    async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
        use std::time::Duration;
        loop {
            match self.tasks.lock().get(task_id) {
                Some(task) => match task.check_if_aborted() {
                    Ok(true) => return Ok(()),
                    Ok(false) => {}
                    Err(err) => return Err(err),
                },
                None => bail!("task {:?} not found", task_id),
            }
            tokio::time::sleep(Duration::from_millis(100)).await
        }
    }

    pub fn runtime(&self) -> Arc<BackgroundShutdownRuntime> {
        self.runtime.clone()
    }

    pub fn config(&self) -> &BatchConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use risingwave_common::config::BatchConfig;
    use risingwave_pb::batch_plan::exchange_info::DistributionMode;
    use risingwave_pb::batch_plan::plan_node::NodeBody;
    use risingwave_pb::batch_plan::{
        ExchangeInfo, PbTaskId, PbTaskOutputId, PlanFragment, PlanNode,
    };

    use crate::monitor::BatchManagerMetrics;
    use crate::task::{BatchManager, TaskId};

    #[tokio::test]
    async fn test_task_not_found() {
        let manager = Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ));
        let task_id = TaskId {
            task_id: 0,
            stage_id: 0,
            query_id: "abc".to_owned(),
        };

        let error = manager.check_if_task_running(&task_id).unwrap_err();
        assert!(error.to_string().contains("not found"), "{:?}", error);

        let output_id = PbTaskOutputId {
            task_id: Some(risingwave_pb::batch_plan::TaskId {
                stage_id: 0,
                task_id: 0,
                query_id: "".to_owned(),
            }),
            output_id: 0,
        };
        let error = manager.take_output(&output_id).unwrap_err();
        assert!(error.to_string().contains("not found"), "{:?}", error);
    }

    #[tokio::test]
    // see https://github.com/risingwavelabs/risingwave/issues/11979
    #[ignore]
    async fn test_task_cancel_for_busy_loop() {
        let manager = Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ));
        let plan = PlanFragment {
            root: Some(PlanNode {
                children: vec![],
                identity: "".to_owned(),
                node_body: Some(NodeBody::BusyLoopExecutor(true)),
            }),
            exchange_info: Some(ExchangeInfo {
                mode: DistributionMode::Single as i32,
                distribution: None,
            }),
        };
        let task_id = PbTaskId {
            query_id: "".to_owned(),
            stage_id: 0,
            task_id: 0,
        };
        manager.fire_task_for_test(&task_id, plan).await.unwrap();
        manager.cancel_task(&task_id);
        let task_id = TaskId::from(&task_id);
        assert!(!manager.tasks.lock().contains_key(&task_id));
    }

355
356    #[tokio::test]
357    // see https://github.com/risingwavelabs/risingwave/issues/11979
358    #[ignore]
359    async fn test_task_abort_for_busy_loop() {
360        let manager = Arc::new(BatchManager::new(
361            BatchConfig::default(),
362            BatchManagerMetrics::for_test(),
363            u64::MAX,
364        ));
365        let plan = PlanFragment {
366            root: Some(PlanNode {
367                children: vec![],
368                identity: "".to_owned(),
369                node_body: Some(NodeBody::BusyLoopExecutor(true)),
370            }),
371            exchange_info: Some(ExchangeInfo {
372                mode: DistributionMode::Single as i32,
373                distribution: None,
374            }),
375        };
376        let task_id = PbTaskId {
377            query_id: "".to_owned(),
378            stage_id: 0,
379            task_id: 0,
380        };
381        manager.fire_task_for_test(&task_id, plan).await.unwrap();
382        let task_id = TaskId::from(&task_id);
383        manager
384            .tasks
385            .lock()
386            .get(&task_id)
387            .unwrap()
388            .abort("Abort Test".to_owned());
389        assert!(manager.wait_until_task_aborted(&task_id).await.is_ok());
390    }
391}