risingwave_stream/executor/
row_id_gen.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::{Array, ArrayBuilder, ArrayRef, Op, SerialArrayBuilder};
16use risingwave_common::bitmap::Bitmap;
17use risingwave_common::hash::VnodeBitmapExt;
18use risingwave_common::types::Serial;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_common::util::row_id::RowIdGenerator;
21
22use crate::executor::prelude::*;
23
24/// [`RowIdGenExecutor`] generates row id for data, where the user has not specified a pk.
25pub struct RowIdGenExecutor {
26    ctx: ActorContextRef,
27
28    upstream: Option<Executor>,
29
30    row_id_index: usize,
31
32    row_id_generator: RowIdGenerator,
33}
34
35impl RowIdGenExecutor {
36    pub fn new(
37        ctx: ActorContextRef,
38        upstream: Executor,
39        row_id_index: usize,
40        vnodes: Bitmap,
41    ) -> Self {
42        Self {
43            ctx,
44            upstream: Some(upstream),
45            row_id_index,
46            row_id_generator: Self::new_generator(&vnodes),
47        }
48    }
49
50    /// Create a new row id generator based on the assigned vnodes.
51    fn new_generator(vnodes: &Bitmap) -> RowIdGenerator {
52        RowIdGenerator::new(vnodes.iter_vnodes(), vnodes.len())
53    }
54
55    /// Generate a row ID column according to ops.
56    fn gen_row_id_column_by_op(
57        &mut self,
58        column: &ArrayRef,
59        ops: &'_ [Op],
60        vis: &Bitmap,
61    ) -> ArrayRef {
62        let len = column.len();
63        let mut builder = SerialArrayBuilder::new(len);
64
65        for ((datum, op), vis) in column.iter().zip_eq_fast(ops).zip_eq_fast(vis.iter()) {
66            // Only refill row_id for insert operation.
67            match op {
68                Op::Insert => builder.append(Some(self.row_id_generator.next().into())),
69                _ => {
70                    if vis {
71                        builder.append(Some(Serial::try_from(datum.unwrap()).unwrap()))
72                    } else {
73                        builder.append(None)
74                    }
75                }
76            }
77        }
78
79        builder.finish().into_ref()
80    }
81
82    #[try_stream(ok = Message, error = StreamExecutorError)]
83    async fn execute_inner(mut self) {
84        let mut upstream = self.upstream.take().unwrap().execute();
85
86        // The first barrier mush propagated.
87        let barrier = expect_first_barrier(&mut upstream).await?;
88        yield Message::Barrier(barrier);
89
90        #[for_await]
91        for msg in upstream {
92            let msg = msg?;
93
94            match msg {
95                Message::Chunk(chunk) => {
96                    // For chunk message, we fill the row id column and then yield it.
97                    let (ops, mut columns, bitmap) = chunk.into_inner();
98                    columns[self.row_id_index] =
99                        self.gen_row_id_column_by_op(&columns[self.row_id_index], &ops, &bitmap);
100                    yield Message::Chunk(StreamChunk::with_visibility(ops, columns, bitmap));
101                }
102                Message::Barrier(barrier) => {
103                    // Update row id generator if vnode mapping changes.
104                    // Note that: since `Update` barrier only exists between `Pause` and `Resume`
105                    // barrier, duplicated row id won't be generated.
106                    if let Some(vnodes) = barrier.as_update_vnode_bitmap(self.ctx.id) {
107                        self.row_id_generator = Self::new_generator(&vnodes);
108                    }
109                    yield Message::Barrier(barrier);
110                }
111                Message::Watermark(watermark) => yield Message::Watermark(watermark),
112            }
113        }
114    }
115}
116
117impl Execute for RowIdGenExecutor {
118    fn execute(self: Box<Self>) -> super::BoxedMessageStream {
119        self.execute_inner().boxed()
120    }
121}
122
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use risingwave_common::array::PrimitiveArray;
    use risingwave_common::catalog::Field;
    use risingwave_common::hash::VirtualNode;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    #[tokio::test]
    async fn test_row_id_gen_executor() {
        // This test only works when vnode count is 256.
        assert_eq!(VirtualNode::COUNT_FOR_TEST, 256);

        let schema = Schema::new(vec![
            Field::unnamed(DataType::Serial),
            Field::unnamed(DataType::Int64),
        ]);
        let pk_indices = vec![0];
        let row_id_index = 0;
        // Every vnode is assigned to the single actor under test.
        let vnodes = Bitmap::ones(VirtualNode::COUNT_FOR_TEST);
        let (mut tx, upstream) = MockSource::channel();
        let upstream = upstream.into_executor(schema, pk_indices);

        let row_id_gen_executor =
            RowIdGenExecutor::new(ActorContext::for_test(233), upstream, row_id_index, vnodes);
        let mut row_id_gen_executor = row_id_gen_executor.boxed().execute();

        // Init barrier
        tx.push_barrier(test_epoch(1), false);
        row_id_gen_executor.next().await.unwrap().unwrap();

        // Insert operation
        let chunk1 = StreamChunk::from_pretty(
            " SRL I
            + . 1
            + . 2
            + . 6
            + . 7",
        );
        tx.push_chunk(chunk1);
        let chunk: StreamChunk = row_id_gen_executor
            .next()
            .await
            .unwrap()
            .unwrap()
            .into_chunk()
            .unwrap();
        let row_id_col: &PrimitiveArray<Serial> = chunk.column_at(row_id_index).as_serial();
        // Should generate a row id for every insert operation, and the ids must not collide.
        let row_ids: Vec<Serial> = row_id_col
            .iter()
            .map(|row_id| row_id.expect("insert rows should get a generated row id"))
            .collect();
        assert_eq!(row_ids.len(), 4);
        let distinct: HashSet<Serial> = row_ids.iter().copied().collect();
        assert_eq!(distinct.len(), row_ids.len(), "generated row ids must be unique");

        // Update operation
        let chunk2 = StreamChunk::from_pretty(
            "      SRL        I
            U- 32874283748  1
            U+ 32874283748 999",
        );
        tx.push_chunk(chunk2);
        let chunk: StreamChunk = row_id_gen_executor
            .next()
            .await
            .unwrap()
            .unwrap()
            .into_chunk()
            .unwrap();
        let row_id_col: &PrimitiveArray<Serial> = chunk.column_at(row_id_index).as_serial();
        // Should not generate row id for update operations.
        assert_eq!(row_id_col.value_at(0).unwrap(), Serial::from(32874283748));
        assert_eq!(row_id_col.value_at(1).unwrap(), Serial::from(32874283748));

        // Delete operation
        let chunk3 = StreamChunk::from_pretty(
            "      SRL       I
            - 84629409685  1",
        );
        tx.push_chunk(chunk3);
        let chunk: StreamChunk = row_id_gen_executor
            .next()
            .await
            .unwrap()
            .unwrap()
            .into_chunk()
            .unwrap();
        let row_id_col: &PrimitiveArray<Serial> = chunk.column_at(row_id_index).as_serial();
        // Should not generate row id for delete operations.
        assert_eq!(row_id_col.value_at(0).unwrap(), Serial::from(84629409685));
    }
}