risingwave_stream/executor/source/state_table_handler.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

cfg_if::cfg_if! {
    if #[cfg(test)] {
        use risingwave_common::catalog::{DatabaseId, SchemaId};
        use risingwave_pb::catalog::table::TableType;
        use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType};
        use risingwave_pb::data::data_type::TypeName;
        use risingwave_pb::data::DataType;
        use risingwave_pb::plan_common::{ColumnCatalog, ColumnDesc};
    }
}

use std::ops::Deref;
use std::sync::Arc;

use risingwave_common::bitmap::Bitmap;
use risingwave_common::row;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{JsonbVal, ScalarImpl, ScalarRef, ScalarRefImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_connector::source::{SplitImpl, SplitMetaData};
use risingwave_pb::catalog::PbTable;
use risingwave_storage::StateStore;

use crate::common::table::state_table::{StateTable, StateTableBuilder, StateTablePostCommit};
use crate::executor::StreamExecutorResult;

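/// Stores the per-split ingestion state of a source executor.
///
/// Each row of the underlying `StateTable` is `(split_id: varchar, split_state: jsonb)`,
/// with `split_id` as the primary key; the JSONB value is the split metadata produced by
/// [`SplitMetaData::encode_to_json`].
///
/// A minimal lifecycle sketch (illustrative only; `catalog`, `split` and the epochs are
/// hypothetical values assumed to be in scope):
///
/// ```ignore
/// let mut handler =
///     SourceStateTableHandler::from_table_catalog(&catalog, MemoryStateStore::new()).await;
/// handler.init_epoch(epoch_1).await?;
/// handler.set_states(vec![split.clone()]).await?; // upsert the split's state
/// handler.commit(epoch_2).await?; // seal this epoch's writes
/// // Later, e.g. during recovery, read back what was committed:
/// let reader = handler.new_committed_reader(epoch_2).await?;
/// let restored = reader.try_recover_from_state_store(&split).await?;
/// ```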
pub struct SourceStateTableHandler<S: StateStore> {
    state_table: StateTable<S>,
}

impl<S: StateStore> SourceStateTableHandler<S> {
    /// Creates a state table with singleton distribution (i.e., only vnode 0).
    ///
    /// Refer to `infer_internal_table_catalog` in `src/frontend/src/optimizer/plan_node/generic/source.rs` for more details.
    pub async fn from_table_catalog(table_catalog: &PbTable, store: S) -> Self {
        Self {
            // Note: do not enable `preload_all_rows` for a source's `StateTable`, because it
            // uses storage to synchronize state across different parallelisms, a special
            // access pattern that the in-memory state table does not support yet.
            state_table: StateTableBuilder::new(table_catalog, store, None)
                .forbid_preload_all_rows()
                .build()
                .await,
        }
    }

    /// For [`super::FsFetchExecutor`], each actor accesses splits according to the `vnode` computed from `partition_id`.
    pub async fn from_table_catalog_with_vnodes(
        table_catalog: &PbTable,
        store: S,
        vnodes: Option<Arc<Bitmap>>,
    ) -> Self {
        Self {
            // Note: do not enable `preload_all_rows` for a source's `StateTable`, because it
            // uses storage to synchronize state across different parallelisms, a special
            // access pattern that the in-memory state table does not support yet.
            state_table: StateTableBuilder::new(table_catalog, store, vnodes)
                .forbid_preload_all_rows()
                .build()
                .await,
        }
    }

    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state_table.init_epoch(epoch).await
    }

    fn str_to_scalar_ref(s: &str) -> ScalarRefImpl<'_> {
        ScalarRefImpl::Utf8(s)
    }

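    /// Looks up the state row for `key` (a split id), encoding the key as a
    /// single-column utf8 row to match the table's primary key.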
    pub(crate) async fn get(&self, key: &str) -> StreamExecutorResult<Option<OwnedRow>> {
        self.state_table
            .get_row(row::once(Some(Self::str_to_scalar_ref(key))))
            .await
    }

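    /// Upserts the state for `key`: reads the previous row first, then issues an
    /// `update` if the key already exists, or an `insert` otherwise.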
    pub async fn set(&mut self, key: &str, value: JsonbVal) -> StreamExecutorResult<()> {
        let row = [
            Some(Self::str_to_scalar_ref(key).into_scalar_impl()),
            Some(ScalarImpl::Jsonb(value)),
        ];
        match self.get(key).await? {
            Some(prev_row) => {
                self.state_table.update(prev_row, row);
            }
            None => {
                self.state_table.insert(row);
            }
        }
        Ok(())
    }

    pub async fn delete(&mut self, key: &str) -> StreamExecutorResult<()> {
        if let Some(prev_row) = self.get(key).await? {
            self.state_table.delete(prev_row);
        }

        Ok(())
    }

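    /// Persists a batch of split states, keyed by [`SplitMetaData::id`] and encoded
    /// to JSONB via [`SplitMetaData::encode_to_json`].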
    pub async fn set_states<SS>(&mut self, states: Vec<SS>) -> StreamExecutorResult<()>
    where
        SS: SplitMetaData,
    {
        for split_impl in states {
            self.set(split_impl.id().deref(), split_impl.encode_to_json())
                .await?;
        }
        Ok(())
    }

    pub async fn set_states_json(
        &mut self,
        states: impl IntoIterator<Item = (String, JsonbVal)>,
    ) -> StreamExecutorResult<()> {
        for (key, value) in states {
            self.set(&key, value).await?;
        }
        Ok(())
    }

    pub async fn trim_state(&mut self, to_trim: &[SplitImpl]) -> StreamExecutorResult<()> {
        for split in to_trim {
            tracing::info!("trimming source state for split {}", split.id());
            self.delete(&split.id()).await?;
        }

        Ok(())
    }

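    /// Returns a reader over committed state, first waiting until `epoch.prev` has
    /// been committed in the state store so that reads observe a consistent snapshot.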
    pub async fn new_committed_reader(
        &self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<SourceStateTableCommittedReader<'_, S>> {
        self.state_table
            .try_wait_committed_epoch(epoch.prev)
            .await?;
        Ok(SourceStateTableCommittedReader { handle: self })
    }

    pub fn state_table(&self) -> &StateTable<S> {
        &self.state_table
    }

    pub fn state_table_mut(&mut self) -> &mut StateTable<S> {
        &mut self.state_table
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state_table.try_flush().await
    }

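    /// Commits the epoch and returns a post-commit handle through which a vnode bitmap
    /// update (e.g. caused by scaling) may be applied. Use [`Self::commit`] when the
    /// vnode bitmap is known not to change.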
    pub async fn commit_may_update_vnode_bitmap(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
        self.state_table.commit(epoch).await
    }

    pub async fn commit(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state_table
            .commit_assert_no_update_vnode_bitmap(epoch)
            .await
    }
}

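/// A read-only view of [`SourceStateTableHandler`] whose reads are guaranteed to observe
/// state committed up to the epoch passed to [`SourceStateTableHandler::new_committed_reader`].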
pub struct SourceStateTableCommittedReader<'a, S: StateStore> {
    handle: &'a SourceStateTableHandler<S>,
}

impl<S: StateStore> SourceStateTableCommittedReader<'_, S> {
    pub async fn try_recover_from_state_store(
        &self,
        stream_source_split: &SplitImpl,
    ) -> StreamExecutorResult<Option<SplitImpl>> {
        Ok(match self.handle.get(&stream_source_split.id()).await? {
            None => None,
            Some(row) => match row.datum_at(1) {
                Some(ScalarRefImpl::Jsonb(jsonb_ref)) => {
                    Some(SplitImpl::restore_from_json(jsonb_ref.to_owned_scalar())?)
                }
                _ => unreachable!(),
            },
        })
    }
}

/// Aligns with the schema defined in `LogicalSource::infer_internal_table_catalog`. This
/// function is for test purposes only and should not be used in production.
#[cfg(test)]
pub fn default_source_internal_table(id: u32) -> PbTable {
    let make_column = |column_type: TypeName, column_id: i32| -> ColumnCatalog {
        ColumnCatalog {
            column_desc: Some(ColumnDesc {
                column_type: Some(DataType {
                    type_name: column_type as i32,
                    ..Default::default()
                }),
                column_id,
                nullable: Some(true),
                ..Default::default()
            }),
            is_hidden: false,
        }
    };

    let columns = vec![
        make_column(TypeName::Varchar, 0),
        make_column(TypeName::Jsonb, 1),
    ];
    PbTable {
        id,
        schema_id: SchemaId::placeholder().schema_id,
        database_id: DatabaseId::placeholder().database_id,
        name: String::new(),
        columns,
        table_type: TableType::Internal as i32,
        value_indices: vec![0, 1],
        pk: vec![PbColumnOrder {
            column_index: 0,
            order_type: Some(PbOrderType {
                direction: PbDirection::Ascending as _,
                nulls_are: PbNullsAre::Largest as _,
            }),
        }],
        ..Default::default()
    }
}

#[cfg(test)]
pub(crate) mod tests {

    use risingwave_common::types::Datum;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_connector::source::kafka::KafkaSplit;
    use risingwave_storage::memory::MemoryStateStore;
    use serde_json::Value;

    use super::*;

    #[tokio::test]
    async fn test_from_table_catalog() {
        let store = MemoryStateStore::new();
        let mut state_table =
            StateTable::from_table_catalog(&default_source_internal_table(0x2333), store, None)
                .await;
        let a: Arc<str> = String::from("a").into();
        let a: Datum = Some(ScalarImpl::Utf8(a.as_ref().into()));
        let b: JsonbVal = serde_json::from_str::<Value>("{\"k1\": \"v1\", \"k2\": 11}")
            .unwrap()
            .into();
        let b: Datum = Some(ScalarImpl::Jsonb(b));

        let init_epoch_num = test_epoch(1);
        let init_epoch = EpochPair::new_test_epoch(init_epoch_num);
        let next_epoch = EpochPair::new_test_epoch(init_epoch_num + test_epoch(1));

        state_table.init_epoch(init_epoch).await.unwrap();
        state_table.insert(OwnedRow::new(vec![a.clone(), b.clone()]));
        state_table.commit_for_test(next_epoch).await.unwrap();

        let a: Arc<str> = String::from("a").into();
        let a: Datum = Some(ScalarImpl::Utf8(a.as_ref().into()));
        let _resp = state_table.get_row(&OwnedRow::new(vec![a])).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_and_get() -> StreamExecutorResult<()> {
        let store = MemoryStateStore::new();
        let mut state_table_handler = SourceStateTableHandler::from_table_catalog(
            &default_source_internal_table(0x2333),
            store,
        )
        .await;
        let split_impl = SplitImpl::Kafka(KafkaSplit::new(0, Some(0), None, "test".into()));
        let serialized = split_impl.encode_to_bytes();
        let serialized_json = split_impl.encode_to_json();

        let epoch_1 = EpochPair::new_test_epoch(test_epoch(1));
        let epoch_2 = EpochPair::new_test_epoch(test_epoch(2));
        let epoch_3 = EpochPair::new_test_epoch(test_epoch(3));

        state_table_handler.init_epoch(epoch_1).await?;
        state_table_handler
            .set_states(vec![split_impl.clone()])
            .await?;
        state_table_handler
            .state_table
            .commit_for_test(epoch_2)
            .await?;

        state_table_handler
            .state_table
            .commit_for_test(epoch_3)
            .await?;

        match state_table_handler
            .new_committed_reader(epoch_3)
            .await?
            .try_recover_from_state_store(&split_impl)
            .await?
        {
            Some(s) => {
                assert_eq!(s.encode_to_bytes(), serialized);
                assert_eq!(s.encode_to_json(), serialized_json);
            }
            None => unreachable!(),
        }
        Ok(())
    }
}