risingwave_common/util/
row_id.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::time::SystemTime;
17
18use super::epoch::UNIX_RISINGWAVE_DATE_EPOCH;
19use crate::hash::VirtualNode;
20
/// The number of low-order bits of a row id occupied by the vnode part and the sequence part
/// together, i.e. how far the timestamp part is shifted to the left.
const TIMESTAMP_SHIFT_BITS: u32 = 22;

/// The number of bits occupied by the vnode part of a row id in the previous version. It is
/// still the minimum width of the vnode part, see [`bit_for_vnode`].
const COMPAT_VNODE_BITS: u32 = 10;
26
/// `RowIdGenerator` generates unique row ids using the snowflake algorithm, in the following
/// format:
///
/// | timestamp | vnode & sequence |
/// |-----------|------------------|
/// |  41 bits  |     22 bits      |
///
/// The vnode part can occupy 10..=15 bits, which is determined by the vnode count. Thus,
/// the sequence part will occupy 7..=12 bits. See [`bit_for_vnode`] for more details.
#[derive(Debug)]
pub struct RowIdGenerator {
    /// Base timestamp that the timestamp part of row ids is measured from.
    base: SystemTime,

    /// Timestamp part currently used for row ids, in milliseconds elapsed since `base`.
    last_timestamp_ms: i64,

    /// The number of bits used for the vnode part.
    vnode_bit: u32,

    /// Virtual nodes used by this generator. Consecutive row ids cycle through them.
    vnodes: Vec<VirtualNode>,

    /// Index into `vnodes` of the vnode to use for the next row id.
    vnodes_index: u16,

    /// Sequence part to use for the next row id. It is bumped each time `vnodes_index` wraps
    /// around, and reset whenever the timestamp is updated.
    sequence: u16,
}
55
/// A row id, represented as a signed 64-bit integer. See [`RowIdGenerator`] for the layout of
/// generated ids.
pub type RowId = i64;
57
58/// The number of bits occupied by the vnode part of a row id.
59///
60/// In previous versions, this was fixed to 10 bits even if the vnode count was fixed to 256.
61/// For backward compatibility, we still use 10 bits for vnode count less than or equal to 1024.
62/// For larger vnode counts, we use the smallest power of 2 that fits the vnode count.
63fn bit_for_vnode(vnode_count: usize) -> u32 {
64    debug_assert!(
65        vnode_count <= VirtualNode::MAX_COUNT,
66        "invalid vnode count {vnode_count}"
67    );
68
69    if vnode_count <= 1 << COMPAT_VNODE_BITS {
70        COMPAT_VNODE_BITS
71    } else {
72        vnode_count.next_power_of_two().ilog2()
73    }
74}
75
76/// Compute vnode from the given row id.
77///
78/// # `vnode_count`
79///
80/// The given `vnode_count` determines the valid range of the returned vnode. It does not have to
81/// be the same as the vnode count used when the row id was generated with [`RowIdGenerator`].
82///
83/// However, only if they are the same, the vnode retrieved here is guaranteed to be the same as
84/// when it was generated. Otherwise, the vnode can be different and skewed, but the row ids
85/// generated under the same vnode will still yield the same result.
86///
87/// This is okay because we rely on the reversibility only if the serial type (row id) is generated
88/// and persisted in the same fragment, where the vnode count is the same. In other cases, the
89/// serial type is more like a normal integer type, and the algorithm to hash or compute vnode from
90/// it does not matter.
91#[inline]
92pub fn compute_vnode_from_row_id(id: RowId, vnode_count: usize) -> VirtualNode {
93    let vnode_bit = bit_for_vnode(vnode_count);
94    let sequence_bit = TIMESTAMP_SHIFT_BITS - vnode_bit;
95
96    let vnode_part = ((id >> sequence_bit) & ((1 << vnode_bit) - 1)) as usize;
97
98    // If the given `vnode_count` is the same as the one used when the row id was generated, this
99    // is no-op. Otherwise, we clamp the vnode to fit in the given vnode count.
100    VirtualNode::from_index(vnode_part % vnode_count)
101}
102
103impl RowIdGenerator {
104    /// Create a new `RowIdGenerator` with given virtual nodes and vnode count.
105    pub fn new(vnodes: impl IntoIterator<Item = VirtualNode>, vnode_count: usize) -> Self {
106        let base = *UNIX_RISINGWAVE_DATE_EPOCH;
107        let vnode_bit = bit_for_vnode(vnode_count);
108
109        Self {
110            base,
111            last_timestamp_ms: base.elapsed().unwrap().as_millis() as i64,
112            vnode_bit,
113            vnodes: vnodes.into_iter().collect(),
114            vnodes_index: 0,
115            sequence: 0,
116        }
117    }
118
119    /// The upper bound of the sequence part, exclusive.
120    fn sequence_upper_bound(&self) -> u16 {
121        1 << (TIMESTAMP_SHIFT_BITS - self.vnode_bit)
122    }
123
124    /// Update the timestamp, so that the millisecond part of row id is **always** increased.
125    ///
126    /// This method will immediately return if the timestamp is increased or there's remaining
127    /// sequence for the current millisecond. Otherwise, it will spin loop until the timestamp is
128    /// increased.
129    fn try_update_timestamp(&mut self) {
130        let get_current_timestamp_ms = || self.base.elapsed().unwrap().as_millis() as i64;
131
132        let current_timestamp_ms = get_current_timestamp_ms();
133        let to_update = match current_timestamp_ms.cmp(&self.last_timestamp_ms) {
134            Ordering::Less => {
135                tracing::warn!(
136                    "Clock moved backwards: last={}, current={}",
137                    self.last_timestamp_ms,
138                    current_timestamp_ms,
139                );
140                true
141            }
142            Ordering::Equal => {
143                // Update the timestamp if the sequence reaches the upper bound.
144                self.sequence == self.sequence_upper_bound()
145            }
146            Ordering::Greater => true,
147        };
148
149        if to_update {
150            // If the timestamp is not increased, spin loop here and wait for next millisecond. The
151            // case for time going backwards and sequence reaches the upper bound are both covered.
152            let mut current_timestamp_ms = current_timestamp_ms;
153            loop {
154                if current_timestamp_ms > self.last_timestamp_ms {
155                    break;
156                }
157                current_timestamp_ms = get_current_timestamp_ms();
158
159                #[cfg(madsim)]
160                tokio::time::advance(std::time::Duration::from_micros(10));
161                #[cfg(not(madsim))]
162                std::hint::spin_loop();
163            }
164
165            // Reset states. We do not reset the `vnode_index` to make all vnodes are evenly used.
166            self.last_timestamp_ms = current_timestamp_ms;
167            self.sequence = 0;
168        }
169    }
170
171    /// Generate a new `RowId`. Returns `None` if the sequence reaches the upper bound of current
172    /// timestamp, and `try_update_timestamp` should be called to update the timestamp and reset the
173    /// sequence. After that, the next call of this method always returns `Some`.
174    fn next_row_id_in_current_timestamp(&mut self) -> Option<RowId> {
175        if self.sequence >= self.sequence_upper_bound() {
176            return None;
177        }
178
179        let vnode = self.vnodes[self.vnodes_index as usize].to_index();
180        let sequence = self.sequence;
181
182        self.vnodes_index = (self.vnodes_index + 1) % self.vnodes.len() as u16;
183        if self.vnodes_index == 0 {
184            self.sequence += 1;
185        }
186
187        Some(
188            self.last_timestamp_ms << TIMESTAMP_SHIFT_BITS
189                | (vnode << (TIMESTAMP_SHIFT_BITS - self.vnode_bit)) as i64
190                | sequence as i64,
191        )
192    }
193
194    /// Returns an infinite iterator that generates `RowId`s.
195    fn gen_iter(&mut self) -> impl Iterator<Item = RowId> + '_ {
196        std::iter::from_fn(move || {
197            if let Some(next) = self.next_row_id_in_current_timestamp() {
198                Some(next)
199            } else {
200                self.try_update_timestamp();
201                Some(
202                    self.next_row_id_in_current_timestamp()
203                        .expect("timestamp should be updated"),
204                )
205            }
206        })
207    }
208
209    /// Generate a sequence of `RowId`s. Compared to `next`, this method is more efficient as it
210    /// only checks the timestamp once before generating the first `RowId`, instead of doing that
211    /// every `RowId`.
212    ///
213    /// This may block for a while if too many IDs are generated in one millisecond.
214    pub fn next_batch(&mut self, length: usize) -> Vec<RowId> {
215        self.try_update_timestamp();
216
217        let mut ret = Vec::with_capacity(length);
218        ret.extend(self.gen_iter().take(length));
219        assert_eq!(ret.len(), length);
220        ret
221    }
222
223    /// Generate a new `RowId`.
224    ///
225    /// This may block for a while if too many IDs are generated in one millisecond.
226    #[allow(clippy::should_implement_trait)]
227    pub fn next(&mut self) -> RowId {
228        self.try_update_timestamp();
229
230        self.gen_iter().next().unwrap()
231    }
232}
233
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use itertools::Itertools;

    use super::*;

    /// Checks a generator that serves a single vnode: row ids are strictly increasing, a later
    /// timestamp starts again from sequence 0, and a batch larger than the sequence range wraps
    /// the sequence around.
    #[allow(clippy::unused_async)] // `madsim::time::advance` requires to be in async context
    async fn test_generator_with_vnode_count(vnode_count: usize) {
        let mut generator = RowIdGenerator::new([VirtualNode::from_index(0)], vnode_count);
        let sequence_upper_bound = generator.sequence_upper_bound();
        let sequence_mask = sequence_upper_bound - 1;

        // Row ids generated one after another must be strictly increasing.
        let mut previous = generator.next();
        for _ in 0..100000 {
            let current = generator.next();
            assert!(current > previous);
            previous = current;
        }

        // Let the clock move on, so that the next row id gets a new timestamp.
        let pause = Duration::from_millis(10);
        #[cfg(madsim)]
        tokio::time::advance(pause);
        #[cfg(not(madsim))]
        std::thread::sleep(pause);

        let after_pause = generator.next();
        assert!(after_pause > previous);
        assert_ne!(
            after_pause >> TIMESTAMP_SHIFT_BITS,
            previous >> TIMESTAMP_SHIFT_BITS
        );
        assert_eq!(after_pause & i64::from(sequence_mask), 0);

        // With a single vnode, a batch runs through the whole sequence range and then starts
        // over from 0 in the next timestamp.
        let mut generator = RowIdGenerator::new([VirtualNode::from_index(1)], vnode_count);
        let batch = generator.next_batch(usize::from(sequence_upper_bound) + 10);

        let actual_sequences = batch
            .into_iter()
            .map(|id| (id as u16) & sequence_mask)
            .collect_vec();
        let expected_sequences = (0..sequence_upper_bound).chain(0..10).collect_vec();
        assert_eq!(actual_sequences, expected_sequences);
    }

    /// Checks a generator that serves 20 vnodes (the 10 lowest and the 10 highest): a batch that
    /// is one id larger than a timestamp can hold spills into a later timestamp, and the vnodes
    /// decoded from the row ids cycle through the given vnodes in order.
    #[allow(clippy::unused_async)] // `madsim::time::advance` requires to be in async context
    async fn test_generator_multiple_vnodes_with_vnode_count(vnode_count: usize) {
        assert!(vnode_count >= 20);

        let vnodes = || {
            let lowest = 0..10;
            let highest = (vnode_count - 10)..vnode_count;
            lowest.chain(highest).map(VirtualNode::from_index)
        };
        let vnode_of = |row_id: RowId| compute_vnode_from_row_id(row_id, vnode_count);

        let mut generator = RowIdGenerator::new(vnodes(), vnode_count);
        let ids_per_timestamp = (generator.sequence_upper_bound() as usize) * 20;

        // One more id than a single timestamp can hold.
        let row_ids = generator.next_batch(ids_per_timestamp + 1);

        // Check timestamps: all but the last id share one timestamp.
        let timestamps = row_ids
            .iter()
            .map(|&id| id >> TIMESTAMP_SHIFT_BITS)
            .collect_vec();

        let (newest_timestamp, earlier_timestamps) = timestamps.split_last().unwrap();
        let shared_timestamp = earlier_timestamps.iter().unique().exactly_one().unwrap();

        // Check vnodes.
        let expected_vnodes = vnodes().cycle();
        let actual_vnodes = row_ids.iter().map(|&id| vnode_of(id));

        #[expect(clippy::disallowed_methods)] // `expected_vnodes` is an endless cycle iterator
        for (expected, actual) in expected_vnodes.zip(actual_vnodes) {
            assert_eq!(expected, actual);
        }

        assert!(newest_timestamp > shared_timestamp);
    }

    /// Declares the single-vnode and the multiple-vnodes test for one vnode count.
    macro_rules! test {
        ($count:expr, $single:ident, $multi:ident) => {
            #[tokio::test]
            async fn $single() {
                test_generator_with_vnode_count($count).await;
            }

            #[tokio::test]
            async fn $multi() {
                test_generator_multiple_vnodes_with_vnode_count($count).await;
            }
        };
    }

    test!(64, test_64, test_64_mul); // less than default value
    test!(114, test_114, test_114_mul); // not a power of 2, less than default value
    test!(256, test_256, test_256_mul); // default value, backward compatibility
    test!(1 << COMPAT_VNODE_BITS, test_1024, test_1024_mul); // max value with 10 bits
    test!(2048, test_2048, test_2048_mul); // more than 10 bits
    test!(2333, test_2333, test_2333_mul); // not a power of 2, larger than default value
    test!(VirtualNode::MAX_COUNT, test_max, test_max_mul); // max supported
}