// risingwave_meta/barrier/utils.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
16
17use itertools::Itertools;
18use risingwave_common::catalog::TableId;
19use risingwave_common::util::epoch::EpochPair;
20use risingwave_common::util::stream_graph_visitor::visit_stream_node_cont;
21use risingwave_hummock_sdk::sstable_info::SstableInfo;
22use risingwave_hummock_sdk::table_stats::from_prost_table_stats_map;
23use risingwave_hummock_sdk::table_watermark::{
24    TableWatermarks, merge_multiple_new_table_watermarks,
25};
26use risingwave_hummock_sdk::vector_index::{VectorIndexAdd, VectorIndexDelta};
27use risingwave_hummock_sdk::{HummockSstableObjectId, LocalSstableInfo};
28use risingwave_meta_model::WorkerId;
29use risingwave_pb::catalog::PbTable;
30use risingwave_pb::catalog::table::PbTableType;
31use risingwave_pb::hummock::vector_index_delta::PbVectorIndexInit;
32use risingwave_pb::stream_plan::stream_node::NodeBody;
33use risingwave_pb::stream_service::BarrierCompleteResponse;
34
35use crate::barrier::CreateStreamingJobCommandInfo;
36use crate::barrier::command::PostCollectCommand;
37use crate::barrier::partial_graph::PartialGraphBarrierInfo;
38use crate::hummock::{CommitEpochInfo, NewTableFragmentInfo};
39
40#[expect(clippy::type_complexity)]
41pub(super) fn collect_resp_info(
42    resps: Vec<BarrierCompleteResponse>,
43) -> (
44    HashMap<HummockSstableObjectId, WorkerId>,
45    Vec<LocalSstableInfo>,
46    HashMap<TableId, TableWatermarks>,
47    Vec<SstableInfo>,
48    HashMap<TableId, Vec<VectorIndexAdd>>,
49    HashSet<TableId>,
50) {
51    let mut sst_to_worker: HashMap<HummockSstableObjectId, _> = HashMap::new();
52    let mut synced_ssts: Vec<LocalSstableInfo> = vec![];
53    let mut table_watermarks = Vec::with_capacity(resps.len());
54    let mut old_value_ssts = Vec::with_capacity(resps.len());
55    let mut vector_index_adds = HashMap::new();
56    let mut truncate_tables: HashSet<TableId> = HashSet::new();
57
58    for resp in resps {
59        let ssts_iter = resp.synced_sstables.into_iter().map(|local_sst| {
60            let sst_info = local_sst.sst.expect("field not None");
61            sst_to_worker.insert(sst_info.object_id, resp.worker_id);
62            LocalSstableInfo::new(
63                sst_info.into(),
64                from_prost_table_stats_map(local_sst.table_stats_map),
65                local_sst.created_at,
66            )
67        });
68        synced_ssts.extend(ssts_iter);
69        table_watermarks.push(resp.table_watermarks);
70        old_value_ssts.extend(resp.old_value_sstables.into_iter().map(|s| s.into()));
71        for (table_id, vector_index_add) in resp.vector_index_adds {
72            vector_index_adds
73                .try_insert(
74                    table_id,
75                    vector_index_add
76                        .adds
77                        .into_iter()
78                        .map(VectorIndexAdd::from)
79                        .collect(),
80                )
81                .expect("non-duplicate");
82        }
83        truncate_tables.extend(resp.truncate_tables);
84    }
85
86    (
87        sst_to_worker,
88        synced_ssts,
89        merge_multiple_new_table_watermarks(
90            table_watermarks
91                .into_iter()
92                .map(|watermarks| {
93                    watermarks
94                        .into_iter()
95                        .map(|(table_id, watermarks)| {
96                            (table_id, TableWatermarks::from(&watermarks))
97                        })
98                        .collect()
99                })
100                .collect_vec(),
101        ),
102        old_value_ssts,
103        vector_index_adds,
104        truncate_tables,
105    )
106}
107
108pub(super) fn collect_new_vector_index_info(
109    info: &CreateStreamingJobCommandInfo,
110) -> Option<&PbTable> {
111    let mut vector_index_table = None;
112    {
113        for fragment in info.stream_job_fragments.fragments.values() {
114            visit_stream_node_cont(&fragment.nodes, |node| {
115                match node.node_body.as_ref().unwrap() {
116                    NodeBody::VectorIndexWrite(vector_index_write) => {
117                        let index_table = vector_index_write.table.as_ref().unwrap();
118                        assert_eq!(index_table.table_type, PbTableType::VectorIndex as i32);
119                        vector_index_table = Some(index_table);
120                        false
121                    }
122                    _ => true,
123                }
124            })
125        }
126        vector_index_table
127    }
128}
129
/// Folds the barrier-complete responses of an independently-committed job's
/// partial graph into `commit_info` for the given `epoch`.
///
/// - SSTs, merged table watermarks, vector index adds and truncated tables
///   extracted from `resps` are appended to `commit_info`.
/// - Every table in `barrier_info.table_ids_to_commit` is registered to commit
///   at `epoch` (panics on a duplicate table).
/// - When the post-collect command is `CreateStreamingJob`, a
///   `NewTableFragmentInfo` covering the committed tables is recorded, and if
///   the job contains a vector index writer, the index's `Init` delta is
///   registered as well.
///
/// Asserts that independent jobs produce no old-value SSTs, and panics on any
/// duplicate vector index delta for a table.
pub(super) fn collect_independent_job_commit_epoch_info(
    commit_info: &mut CommitEpochInfo,
    epoch: u64,
    resps: Vec<BarrierCompleteResponse>,
    barrier_info: &PartialGraphBarrierInfo,
) {
    let (
        sst_to_context,
        sstables,
        new_table_watermarks,
        old_value_sst,
        vector_index_adds,
        truncate_tables,
    ) = collect_resp_info(resps);
    // Independent jobs are not expected to emit old-value SSTs.
    assert!(old_value_sst.is_empty());
    commit_info.sst_to_context.extend(sst_to_context);
    commit_info.sstables.extend(sstables);
    commit_info
        .new_table_watermarks
        .extend(new_table_watermarks);
    for (table_id, vector_index_adds) in vector_index_adds {
        // Each table may carry at most one vector index delta per commit.
        commit_info
            .vector_index_delta
            .try_insert(table_id, VectorIndexDelta::Adds(vector_index_adds))
            .expect("non-duplicate");
    }
    commit_info.truncate_tables.extend(truncate_tables);
    // Register every table of this partial graph to commit at `epoch`.
    barrier_info
        .table_ids_to_commit
        .iter()
        .for_each(|table_id| {
            commit_info
                .tables_to_commit
                .try_insert(*table_id, epoch)
                .expect("non duplicate");
        });
    if let PostCollectCommand::CreateStreamingJob { info, .. } = &barrier_info.post_collect_command
    {
        commit_info
            .new_table_fragment_infos
            .push(NewTableFragmentInfo {
                table_ids: barrier_info.table_ids_to_commit.clone(),
            });
        // A newly-created vector index starts with an `Init` delta.
        if let Some(index_table) = collect_new_vector_index_info(info) {
            commit_info
                .vector_index_delta
                .try_insert(
                    index_table.id,
                    VectorIndexDelta::Init(PbVectorIndexInit {
                        info: Some(index_table.vector_index_info.unwrap()),
                    }),
                )
                .expect("non-duplicate");
        }
    };
}
186
187pub(super) type NodeToCollect = HashSet<WorkerId>;
188pub(super) fn is_valid_after_worker_err(
189    node_to_collect: &NodeToCollect,
190    worker_id: WorkerId,
191) -> bool {
192    !node_to_collect.contains(&worker_id)
193}
194
/// A barrier that has been enqueued but is still waiting for items from some
/// keys before it is considered fully collected.
#[derive(Debug)]
struct InflightBarrierNode<K, Item, Info> {
    epoch: EpochPair,
    // Keys we are still waiting to hear from.
    to_collect: HashSet<K>,
    // Items already reported, keyed by their source.
    collected: HashMap<K, Item>,
    // Caller-supplied payload carried alongside the barrier.
    info: Info,
}
202
/// A barrier whose items have all been collected and which is waiting to be
/// consumed (e.g. via `take_collected_if`).
#[derive(Debug)]
struct CollectedBarrierNode<K, Item, Info> {
    epoch: EpochPair,
    // All items reported for this barrier, keyed by their source.
    collected: HashMap<K, Item>,
    // Caller-supplied payload carried alongside the barrier.
    info: Info,
}
209
/// Tracks barriers through two phases: inflight (still waiting for every key
/// in its `to_collect` set to report an item) and collected (fully reported,
/// queued in epoch order until consumed).
#[derive(Debug)]
pub(super) struct BarrierItemCollector<K, Item, Info> {
    /// `prev_epoch` -> barrier
    inflight_barriers: BTreeMap<u64, InflightBarrierNode<K, Item, Info>>,
    /// newer epoch at the back. `push_back` and `pop_front`
    collected_barriers: VecDeque<CollectedBarrierNode<K, Item, Info>>,
}
217
impl<K: std::fmt::Debug + Eq + std::hash::Hash, Item: std::fmt::Debug, Info: std::fmt::Debug>
    BarrierItemCollector<K, Item, Info>
{
    /// Creates a collector with no inflight and no collected barriers.
    pub(super) fn new() -> Self {
        Self {
            inflight_barriers: Default::default(),
            collected_barriers: Default::default(),
        }
    }

    /// True when no barrier is tracked in either phase.
    pub(super) fn is_empty(&self) -> bool {
        self.inflight_barriers.is_empty() && self.collected_barriers.is_empty()
    }

    /// Number of barriers still waiting for collection.
    pub(super) fn inflight_barrier_num(&self) -> usize {
        self.inflight_barriers.len()
    }

    /// Number of fully-collected barriers not yet taken by the consumer.
    pub(super) fn collected_barrier_num(&self) -> usize {
        self.collected_barriers.len()
    }

    /// Registers a new inflight barrier keyed by `epoch.prev`.
    ///
    /// Panics if `to_collect` is empty, if `epoch` does not directly follow
    /// the most recently enqueued barrier, or if `epoch.prev` is already
    /// enqueued.
    pub(super) fn enqueue(&mut self, epoch: EpochPair, to_collect: HashSet<K>, info: Info) {
        assert!(!to_collect.is_empty());
        if let Some((last_prev_epoch, last_barrier)) = self.inflight_barriers.last_key_value() {
            // Epochs must form a contiguous, strictly increasing chain.
            assert_eq!(last_barrier.epoch.curr, epoch.prev);
            assert!(*last_prev_epoch < epoch.prev);
        }
        self.inflight_barriers
            .try_insert(
                epoch.prev,
                InflightBarrierNode {
                    epoch,
                    to_collect,
                    collected: Default::default(),
                    info,
                },
            )
            .expect("non-duplicated");
    }

    /// Records `item` reported by `key` for the barrier at `prev_epoch`.
    ///
    /// Panics if the barrier is unknown, if `key` was not expected for it, or
    /// if `key` already reported for this barrier.
    pub(super) fn collect(&mut self, prev_epoch: u64, key: K, item: Item) {
        let inflight_barrier = self
            .inflight_barriers
            .get_mut(&prev_epoch)
            .expect("should exist");
        assert!(inflight_barrier.to_collect.remove(&key));
        inflight_barrier
            .collected
            .try_insert(key, item)
            .expect("non-duplicate");
    }

    /// If the oldest inflight barrier has received all expected items, moves
    /// it to the collected queue and returns its epoch together with a
    /// reference to its info; otherwise returns `None`.
    pub(super) fn barrier_collected(&mut self) -> Option<(EpochPair, &Info)> {
        if let Some(entry) = self.inflight_barriers.first_entry()
            && entry.get().to_collect.is_empty()
        {
            let InflightBarrierNode {
                epoch,
                collected,
                info,
                ..
            } = entry.remove();
            if let Some(prev_barrier) = self.collected_barriers.back() {
                // The collected queue must stay epoch-contiguous too.
                assert_eq!(prev_barrier.epoch.curr, epoch.prev);
            }
            self.collected_barriers.push_back(CollectedBarrierNode {
                epoch,
                collected,
                info,
            });
            Some((
                epoch,
                &self.collected_barriers.back().expect("non-empty").info,
            ))
        } else {
            None
        }
    }

    /// Drains every fully-collected inflight barrier into the collected queue.
    pub(super) fn advance_collected(&mut self) {
        while self.barrier_collected().is_some() {}
    }

    /// Epoch of the oldest inflight barrier, if any.
    pub(super) fn first_inflight_epoch(&self) -> Option<EpochPair> {
        self.inflight_barriers
            .first_key_value()
            .map(|(_, barrier)| barrier.epoch)
    }

    /// The newest collected barrier as (epoch, items, info), if any.
    pub(super) fn last_collected(&self) -> Option<(EpochPair, &HashMap<K, Item>, &Info)> {
        self.collected_barriers
            .back()
            .map(|barrier| (barrier.epoch, &barrier.collected, &barrier.info))
    }

    /// Iterates the `Info` of every tracked barrier: inflight (oldest to
    /// newest) first, then collected (oldest to newest).
    pub(super) fn iter_infos(&self) -> impl Iterator<Item = &Info> + '_ {
        self.inflight_barriers
            .values()
            .map(|barrier| &barrier.info)
            .chain(self.collected_barriers.iter().map(|barrier| &barrier.info))
    }

    /// Consumes the collector, yielding every barrier's `Info` in the same
    /// order as `iter_infos`.
    pub(super) fn into_infos(self) -> impl Iterator<Item = Info> {
        self.inflight_barriers
            .into_values()
            .map(|barrier| barrier.info)
            .chain(
                self.collected_barriers
                    .into_iter()
                    .map(|barrier| barrier.info),
            )
    }

    /// Iterates the outstanding `to_collect` sets of all inflight barriers.
    pub(super) fn iter_to_collect(&self) -> impl Iterator<Item = &HashSet<K>> + '_ {
        self.inflight_barriers
            .values()
            .map(|barrier| &barrier.to_collect)
    }

    /// Pops the oldest collected barrier when `cond` accepts its epoch,
    /// returning its epoch, collected items, and info.
    pub(super) fn take_collected_if(
        &mut self,
        cond: impl FnOnce(EpochPair) -> bool,
    ) -> Option<(EpochPair, HashMap<K, Item>, Info)> {
        self.collected_barriers
            .pop_front_if(move |barrier| cond(barrier.epoch))
            .map(|barrier| (barrier.epoch, barrier.collected, barrier.info))
    }
}
346}