risingwave_stream/executor/source/
state_table_handler.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15cfg_if::cfg_if! {
16    if #[cfg(test)] {
17        use risingwave_common::catalog::{DatabaseId, SchemaId};
18        use risingwave_pb::catalog::table::TableType;
19        use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType};
20        use risingwave_pb::data::data_type::TypeName;
21        use risingwave_pb::data::DataType;
22        use risingwave_pb::plan_common::{ColumnCatalog, ColumnDesc};
23    }
24}
25
26use std::ops::Deref;
27use std::sync::Arc;
28
29use risingwave_common::bitmap::Bitmap;
30use risingwave_common::row;
31use risingwave_common::row::{OwnedRow, Row};
32use risingwave_common::types::{JsonbVal, ScalarImpl, ScalarRef, ScalarRefImpl};
33use risingwave_common::util::epoch::EpochPair;
34use risingwave_connector::source::{SplitImpl, SplitMetaData};
35use risingwave_pb::catalog::PbTable;
36use risingwave_storage::StateStore;
37
38use crate::common::table::state_table::{StateTable, StateTablePostCommit};
39use crate::executor::StreamExecutorResult;
40
41pub struct SourceStateTableHandler<S: StateStore> {
42    state_table: StateTable<S>,
43}
44
45impl<S: StateStore> SourceStateTableHandler<S> {
46    /// Creates a state table with singleton distribution (only one vnode 0).
47    ///
48    /// Refer to `infer_internal_table_catalog` in `src/frontend/src/optimizer/plan_node/generic/source.rs` for more details.
49    pub async fn from_table_catalog(table_catalog: &PbTable, store: S) -> Self {
50        Self {
51            state_table: StateTable::from_table_catalog(table_catalog, store, None).await,
52        }
53    }
54
55    /// For [`super::FsFetchExecutor`], each actor accesses splits according to the `vnode` computed from `partition_id`.
56    pub async fn from_table_catalog_with_vnodes(
57        table_catalog: &PbTable,
58        store: S,
59        vnodes: Option<Arc<Bitmap>>,
60    ) -> Self {
61        Self {
62            state_table: StateTable::from_table_catalog(table_catalog, store, vnodes).await,
63        }
64    }
65
66    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
67        self.state_table.init_epoch(epoch).await
68    }
69
70    fn str_to_scalar_ref(s: &str) -> ScalarRefImpl<'_> {
71        ScalarRefImpl::Utf8(s)
72    }
73
74    pub(crate) async fn get(&self, key: &str) -> StreamExecutorResult<Option<OwnedRow>> {
75        self.state_table
76            .get_row(row::once(Some(Self::str_to_scalar_ref(key))))
77            .await
78    }
79
80    pub async fn set(&mut self, key: &str, value: JsonbVal) -> StreamExecutorResult<()> {
81        let row = [
82            Some(Self::str_to_scalar_ref(key).into_scalar_impl()),
83            Some(ScalarImpl::Jsonb(value)),
84        ];
85        match self.get(key).await? {
86            Some(prev_row) => {
87                self.state_table.update(prev_row, row);
88            }
89            None => {
90                self.state_table.insert(row);
91            }
92        }
93        Ok(())
94    }
95
96    pub async fn delete(&mut self, key: &str) -> StreamExecutorResult<()> {
97        if let Some(prev_row) = self.get(key).await? {
98            self.state_table.delete(prev_row);
99        }
100
101        Ok(())
102    }
103
104    pub async fn set_states<SS>(&mut self, states: Vec<SS>) -> StreamExecutorResult<()>
105    where
106        SS: SplitMetaData,
107    {
108        for split_impl in states {
109            self.set(split_impl.id().deref(), split_impl.encode_to_json())
110                .await?;
111        }
112        Ok(())
113    }
114
115    pub async fn set_states_json(
116        &mut self,
117        states: impl IntoIterator<Item = (String, JsonbVal)>,
118    ) -> StreamExecutorResult<()> {
119        for (key, value) in states {
120            self.set(&key, value).await?;
121        }
122        Ok(())
123    }
124
125    pub async fn trim_state(&mut self, to_trim: &[SplitImpl]) -> StreamExecutorResult<()> {
126        for split in to_trim {
127            tracing::info!("trimming source state for split {}", split.id());
128            self.delete(&split.id()).await?;
129        }
130
131        Ok(())
132    }
133
134    pub async fn new_committed_reader(
135        &self,
136        epoch: EpochPair,
137    ) -> StreamExecutorResult<SourceStateTableCommittedReader<'_, S>> {
138        self.state_table
139            .try_wait_committed_epoch(epoch.prev)
140            .await?;
141        Ok(SourceStateTableCommittedReader { handle: self })
142    }
143
144    pub fn state_table(&self) -> &StateTable<S> {
145        &self.state_table
146    }
147
148    pub fn state_table_mut(&mut self) -> &mut StateTable<S> {
149        &mut self.state_table
150    }
151
152    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
153        self.state_table.try_flush().await
154    }
155
156    pub async fn commit_may_update_vnode_bitmap(
157        &mut self,
158        epoch: EpochPair,
159    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
160        self.state_table.commit(epoch).await
161    }
162
163    pub async fn commit(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
164        self.state_table
165            .commit_assert_no_update_vnode_bitmap(epoch)
166            .await
167    }
168}
169
170pub struct SourceStateTableCommittedReader<'a, S: StateStore> {
171    handle: &'a SourceStateTableHandler<S>,
172}
173
174impl<S: StateStore> SourceStateTableCommittedReader<'_, S> {
175    pub async fn try_recover_from_state_store(
176        &self,
177        stream_source_split: &SplitImpl,
178    ) -> StreamExecutorResult<Option<SplitImpl>> {
179        Ok(match self.handle.get(&stream_source_split.id()).await? {
180            None => None,
181            Some(row) => match row.datum_at(1) {
182                Some(ScalarRefImpl::Jsonb(jsonb_ref)) => {
183                    Some(SplitImpl::restore_from_json(jsonb_ref.to_owned_scalar())?)
184                }
185                _ => unreachable!(),
186            },
187        })
188    }
189}
190
/// Aligns with the schema defined in `LogicalSource::infer_internal_table_catalog`.
/// The function is used for test purposes only and should not be used in production.
#[cfg(test)]
pub fn default_source_internal_table(id: u32) -> PbTable {
    // Builds a nullable, non-hidden column of the given type with the given id.
    let make_column = |column_type: TypeName, column_id: i32| -> ColumnCatalog {
        ColumnCatalog {
            column_desc: Some(ColumnDesc {
                column_type: Some(DataType {
                    type_name: column_type as i32,
                    ..Default::default()
                }),
                column_id,
                nullable: true,
                ..Default::default()
            }),
            is_hidden: false,
        }
    };

    // Schema: (split_id Varchar, split_state Jsonb), pk on column 0 ascending.
    let columns = vec![
        make_column(TypeName::Varchar, 0),
        make_column(TypeName::Jsonb, 1),
    ];
    PbTable {
        id,
        schema_id: SchemaId::placeholder().schema_id,
        database_id: DatabaseId::placeholder().database_id,
        name: String::new(),
        columns,
        table_type: TableType::Internal as i32,
        value_indices: vec![0, 1],
        pk: vec![PbColumnOrder {
            column_index: 0,
            order_type: Some(PbOrderType {
                direction: PbDirection::Ascending as _,
                nulls_are: PbNullsAre::Largest as _,
            }),
        }],
        ..Default::default()
    }
}

#[cfg(test)]
pub(crate) mod tests {

    use risingwave_common::types::Datum;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_connector::source::kafka::KafkaSplit;
    use risingwave_storage::memory::MemoryStateStore;
    use serde_json::Value;

    use super::*;

    /// Smoke test: a raw `StateTable` built from the default catalog accepts an
    /// insert and a point-get across an epoch commit.
    #[tokio::test]
    async fn test_from_table_catalog() {
        let store = MemoryStateStore::new();
        let mut state_table =
            StateTable::from_table_catalog(&default_source_internal_table(0x2333), store, None)
                .await;
        let a: Arc<str> = String::from("a").into();
        let a: Datum = Some(ScalarImpl::Utf8(a.as_ref().into()));
        let b: JsonbVal = serde_json::from_str::<Value>("{\"k1\": \"v1\", \"k2\": 11}")
            .unwrap()
            .into();
        let b: Datum = Some(ScalarImpl::Jsonb(b));

        let init_epoch_num = test_epoch(1);
        let init_epoch = EpochPair::new_test_epoch(init_epoch_num);
        let next_epoch = EpochPair::new_test_epoch(init_epoch_num + test_epoch(1));

        state_table.init_epoch(init_epoch).await.unwrap();
        state_table.insert(OwnedRow::new(vec![a.clone(), b.clone()]));
        state_table.commit_for_test(next_epoch).await.unwrap();

        let a: Arc<str> = String::from("a").into();
        let a: Datum = Some(ScalarImpl::Utf8(a.as_ref().into()));
        let _resp = state_table.get_row(&OwnedRow::new(vec![a])).await.unwrap();
    }

    /// Round-trip test: a split written via `set_states` is recovered unchanged
    /// by a committed reader after the epoch is committed.
    #[tokio::test]
    async fn test_set_and_get() -> StreamExecutorResult<()> {
        let store = MemoryStateStore::new();
        let mut state_table_handler = SourceStateTableHandler::from_table_catalog(
            &default_source_internal_table(0x2333),
            store,
        )
        .await;
        let split_impl = SplitImpl::Kafka(KafkaSplit::new(0, Some(0), None, "test".into()));
        let serialized = split_impl.encode_to_bytes();
        let serialized_json = split_impl.encode_to_json();

        let epoch_1 = EpochPair::new_test_epoch(test_epoch(1));
        let epoch_2 = EpochPair::new_test_epoch(test_epoch(2));
        let epoch_3 = EpochPair::new_test_epoch(test_epoch(3));

        state_table_handler.init_epoch(epoch_1).await?;
        state_table_handler
            .set_states(vec![split_impl.clone()])
            .await?;
        state_table_handler
            .state_table
            .commit_for_test(epoch_2)
            .await?;

        state_table_handler
            .state_table
            .commit_for_test(epoch_3)
            .await?;

        match state_table_handler
            .new_committed_reader(epoch_3)
            .await?
            .try_recover_from_state_store(&split_impl)
            .await?
        {
            Some(s) => {
                assert_eq!(s.encode_to_bytes(), serialized);
                assert_eq!(s.encode_to_json(), serialized_json);
            }
            None => unreachable!(),
        }
        Ok(())
    }
}