// risingwave_storage/hummock/event_handler/uploader/task_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;

17#[derive(Debug)]
18pub(super) enum UploadingTaskStatus {
19    Spilling(HashSet<TableId>),
20    Sync(SyncId),
21}
22
23#[derive(Debug)]
24struct TaskEntry {
25    task: UploadingTask,
26    status: UploadingTaskStatus,
27}
28
29#[derive(Default, Debug)]
30pub(super) struct TaskManager {
31    tasks: HashMap<UploadingTaskId, TaskEntry>,
32    // newer task at the front
33    task_order: VecDeque<UploadingTaskId>,
34    next_task_id: usize,
35}
36
37impl TaskManager {
38    fn add_task(
39        &mut self,
40        task: UploadingTask,
41        status: UploadingTaskStatus,
42    ) -> &UploadingTaskStatus {
43        let task_id = task.task_id;
44        self.task_order.push_front(task.task_id);
45        assert!(
46            self.tasks
47                .insert(task.task_id, TaskEntry { task, status })
48                .is_none()
49        );
50        &self.tasks.get(&task_id).expect("should exist").status
51    }
52
53    fn poll_task(
54        &mut self,
55        cx: &mut Context<'_>,
56        task_id: UploadingTaskId,
57    ) -> Poll<Result<Arc<StagingSstableInfo>, (SyncId, HummockError)>> {
58        let entry = self.tasks.get_mut(&task_id).expect("should exist");
59        let result = match &entry.status {
60            UploadingTaskStatus::Spilling(_) => {
61                let sst = ready!(entry.task.poll_ok_with_retry(cx));
62                Ok(sst)
63            }
64            UploadingTaskStatus::Sync(sync_id) => {
65                let result = ready!(entry.task.poll_result(cx));
66                result.map_err(|e| (*sync_id, e))
67            }
68        };
69        Poll::Ready(result)
70    }
71
72    fn get_next_task_id(&mut self) -> UploadingTaskId {
73        let task_id = self.next_task_id;
74        self.next_task_id += 1;
75        UploadingTaskId(task_id)
76    }
77
78    #[expect(clippy::type_complexity)]
79    pub(super) fn poll_task_result(
80        &mut self,
81        cx: &mut Context<'_>,
82    ) -> Poll<
83        Option<(
84            UploadingTaskId,
85            UploadingTaskStatus,
86            Result<Arc<StagingSstableInfo>, (SyncId, HummockError)>,
87        )>,
88    > {
89        if let Some(task_id) = self.task_order.back() {
90            let task_id = *task_id;
91            let result = ready!(self.poll_task(cx, task_id));
92            self.task_order.pop_back();
93            let entry = self.tasks.remove(&task_id).expect("should exist");
94
95            Poll::Ready(Some((task_id, entry.status, result)))
96        } else {
97            Poll::Ready(None)
98        }
99    }
100
101    pub(super) fn abort_all_tasks(self) {
102        for task in self.tasks.into_values() {
103            task.task.join_handle.abort();
104        }
105    }
106
107    pub(super) fn abort_task(&mut self, task_id: UploadingTaskId) -> Option<UploadingTaskStatus> {
108        self.tasks.remove(&task_id).map(|entry| {
109            entry.task.join_handle.abort();
110            self.task_order
111                .retain(|inflight_task_id| *inflight_task_id != task_id);
112            entry.status
113        })
114    }
115
116    pub(super) fn spill(
117        &mut self,
118        context: &UploaderContext,
119        table_ids: HashSet<TableId>,
120        imms: HashMap<LocalInstanceId, Vec<UploaderImm>>,
121    ) -> (UploadingTaskId, usize, &HashSet<TableId>) {
122        assert!(!imms.is_empty());
123        let task = UploadingTask::new(self.get_next_task_id(), imms, context);
124        context.stats.spill_task_counts_from_unsealed.inc();
125        context
126            .stats
127            .spill_task_size_from_unsealed
128            .inc_by(task.task_info.task_size as u64);
129        info!("Spill data. Task: {}", task.get_task_info());
130        let size = task.task_info.task_size;
131        let id = task.task_id;
132        let status = self.add_task(task, UploadingTaskStatus::Spilling(table_ids));
133        (
134            id,
135            size,
136            must_match!(status, UploadingTaskStatus::Spilling(table_ids) => table_ids),
137        )
138    }
139
140    pub(super) fn sync(
141        &mut self,
142        context: &UploaderContext,
143        sync_id: SyncId,
144        unflushed_payload: UploadTaskInput,
145        spill_task_ids: impl Iterator<Item = UploadingTaskId>,
146        sync_table_ids: &HashSet<TableId>,
147    ) -> Option<UploadingTaskId> {
148        let task = if unflushed_payload.is_empty() {
149            None
150        } else {
151            Some(UploadingTask::new(
152                self.get_next_task_id(),
153                unflushed_payload,
154                context,
155            ))
156        };
157
158        for task_id in spill_task_ids {
159            let entry = self.tasks.get_mut(&task_id).expect("should exist");
160            must_match!(&entry.status, UploadingTaskStatus::Spilling(table_ids) => {
161                assert!(table_ids.is_subset(sync_table_ids), "spill table_ids: {table_ids:?}, sync_table_ids: {sync_table_ids:?}");
162            });
163            entry.status = UploadingTaskStatus::Sync(sync_id);
164        }
165
166        task.map(|task| {
167            let id = task.task_id;
168            self.add_task(task, UploadingTaskStatus::Sync(sync_id));
169            id
170        })
171    }
172
173    #[cfg(debug_assertions)]
174    pub(super) fn tasks(&self) -> impl Iterator<Item = (UploadingTaskId, &UploadingTaskStatus)> {
175        self.tasks
176            .iter()
177            .map(|(task_id, entry)| (*task_id, &entry.status))
178    }
179}