risingwave_stream/executor/mview/
refresh_progress_table.rs1use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23
24use risingwave_common::bitmap::Bitmap;
25use risingwave_common::hash::VirtualNode;
26use risingwave_common::row::{OwnedRow, Row};
27use risingwave_common::types::{DataType, ScalarImpl, ScalarRefImpl};
28use risingwave_common::util::epoch::EpochPair;
29use risingwave_storage::StateStore;
30
31use crate::common::table::state_table::StateTablePostCommit;
32use crate::executor::StreamExecutorResult;
33use crate::executor::prelude::StateTable;
34
35pub struct RefreshProgressTable<S: StateStore> {
41 pub state_table: StateTable<S>,
43 cache: HashMap<VirtualNode, RefreshProgressEntry>,
45 pk_len: usize,
47}
48
49#[derive(Debug, Clone, PartialEq)]
51pub struct RefreshProgressEntry {
52 pub vnode: VirtualNode,
53 pub current_pos: Option<OwnedRow>,
54 pub is_completed: bool,
55 pub processed_rows: u64,
56}
57
/// Stage of a refresh operation.
///
/// `#[repr(i32)]` pins the discriminants so a stage can round-trip through an `i32`
/// via `stage as i32` and `ProgressRefreshStage::from(value)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum ProgressRefreshStage {
    Normal = 0,
    Refreshing = 1,
    Merging = 2,
    Cleanup = 3,
}

impl ProgressRefreshStage {
    /// Converts an `i32` discriminant back into a stage.
    ///
    /// Returns `None` for values that do not correspond to any variant, so callers decoding
    /// untrusted or possibly corrupt data can handle the error instead of panicking.
    pub const fn from_i32_checked(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Normal),
            1 => Some(Self::Refreshing),
            2 => Some(Self::Merging),
            3 => Some(Self::Cleanup),
            _ => None,
        }
    }
}

impl From<i32> for ProgressRefreshStage {
    /// Converts an `i32` discriminant into a stage.
    ///
    /// # Panics
    /// Panics, reporting the offending value, if `value` is not a known discriminant.
    /// Use [`ProgressRefreshStage::from_i32_checked`] to handle unknown values gracefully.
    fn from(value: i32) -> Self {
        Self::from_i32_checked(value)
            .unwrap_or_else(|| panic!("invalid ProgressRefreshStage value: {value}"))
    }
}
80
81impl<S: StateStore> RefreshProgressTable<S> {
82 pub fn new(state_table: StateTable<S>, pk_len: usize) -> Self {
84 Self {
85 state_table,
86 cache: HashMap::new(),
87 pk_len,
88 }
89 }
90
91 pub async fn recover(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
93 self.state_table.init_epoch(epoch).await?;
94
95 let mut loaded_count = 0;
97
98 for vnode in self.state_table.vnodes().iter_ones() {
99 let row = self
100 .state_table
101 .get_row(OwnedRow::new(vec![
102 VirtualNode::from_index(vnode).to_datum(),
103 ]))
104 .await?;
105 if row.is_some()
106 && let Some(entry) = self.parse_row_to_entry(&row, self.pk_len)
107 {
108 self.cache.insert(entry.vnode, entry);
109 loaded_count += 1;
110 }
111 }
112
113 tracing::debug!(
114 loaded_count,
115 "Loading existing progress entries during initialization"
116 );
117
118 tracing::info!(
119 loaded_entries = self.cache.len(),
120 "Initialized refresh progress table"
121 );
122
123 Ok(())
124 }
125
126 pub fn set_progress(
128 &mut self,
129 vnode: VirtualNode,
130 current_pos: Option<OwnedRow>,
131 is_completed: bool,
132 processed_rows: u64,
133 ) -> StreamExecutorResult<()> {
134 let entry = RefreshProgressEntry {
135 vnode,
136 current_pos,
137 is_completed,
138 processed_rows,
139 };
140
141 self.cache.insert(vnode, entry.clone());
143
144 let row = self.entry_to_row(&entry, self.pk_len);
146 self.state_table.insert(&row);
147
148 Ok(())
149 }
150
151 pub fn get_progress(&self, vnode: VirtualNode) -> Option<&RefreshProgressEntry> {
153 self.cache.get(&vnode)
154 }
155
156 pub fn get_all_progress(&self) -> &HashMap<VirtualNode, RefreshProgressEntry> {
158 &self.cache
159 }
160
161 pub fn get_completed_vnodes(&self) -> HashSet<VirtualNode> {
163 self.cache
164 .iter()
165 .filter(|(_, entry)| entry.is_completed)
166 .map(|(&vnode, _)| vnode)
167 .collect()
168 }
169
170 pub fn get_vnodes_in_stage(&self, _stage: ProgressRefreshStage) -> Vec<VirtualNode> {
173 tracing::warn!(
174 "get_vnodes_in_stage called on simplified progress table - stage info no longer stored"
175 );
176 Vec::new()
177 }
178
179 pub fn clear_progress(&mut self, vnode: VirtualNode) -> StreamExecutorResult<()> {
181 if let Some(entry) = self.cache.remove(&vnode) {
182 let row = self.entry_to_row(&entry, self.pk_len);
183 self.state_table.delete(&row);
184 }
185
186 Ok(())
187 }
188
189 pub fn clear_all_progress(&mut self) -> StreamExecutorResult<()> {
191 for vnode in self.cache.keys().copied().collect::<Vec<_>>() {
192 self.clear_progress(vnode)?;
193 }
194 Ok(())
195 }
196
197 pub fn get_total_processed_rows(&self) -> u64 {
199 self.cache.values().map(|entry| entry.processed_rows).sum()
200 }
201
202 pub fn get_progress_stats(&self) -> RefreshProgressStats {
204 let total_vnodes = self.cache.len();
205 let completed_vnodes = self.get_completed_vnodes().len();
206 let total_processed_rows = self.get_total_processed_rows();
207
208 RefreshProgressStats {
209 total_vnodes,
210 completed_vnodes,
211 total_processed_rows,
212 }
213 }
214
215 pub async fn commit(
217 &mut self,
218 epoch: EpochPair,
219 ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
220 self.state_table.commit(epoch).await
221 }
222
223 fn entry_to_row(&self, entry: &RefreshProgressEntry, pk_len: usize) -> OwnedRow {
226 let mut row_data = vec![entry.vnode.to_datum()];
227
228 if let Some(ref pos) = entry.current_pos {
230 row_data.extend(pos.iter().map(|d| d.map(|s| s.into_scalar_impl())));
231 } else {
232 for _ in 0..pk_len {
234 row_data.push(None);
235 }
236 }
237
238 row_data.push(Some(ScalarImpl::Bool(entry.is_completed)));
240 row_data.push(Some(ScalarImpl::Int64(entry.processed_rows as i64)));
241
242 OwnedRow::new(row_data)
243 }
244
245 fn parse_row_to_entry(&self, row: &impl Row, pk_len: usize) -> Option<RefreshProgressEntry> {
249 let datums = row.iter().collect::<Vec<_>>();
250 let expected_len = 1 + pk_len + 2; if datums.len() != expected_len {
253 tracing::warn!(
254 "Row length mismatch: got {}, expected {} (pk_len={}), row: {:?}",
255 datums.len(),
256 expected_len,
257 pk_len,
258 row,
259 );
260 return None;
261 }
262
263 let vnode = VirtualNode::from_index(match datums[0]? {
265 ScalarRefImpl::Int32(val) => val as usize,
266 _ => return None,
267 });
268
269 let current_pos = if pk_len > 0 {
271 let pos_datums: Vec<_> = datums[1..1 + pk_len]
272 .iter()
273 .map(|d| d.map(|s| s.into_scalar_impl()))
274 .collect();
275 if pos_datums.iter().all(|d| d.is_none()) {
277 None
278 } else {
279 Some(OwnedRow::new(pos_datums))
280 }
281 } else {
282 None
283 };
284
285 let is_completed = match datums[1 + pk_len]? {
287 ScalarRefImpl::Bool(val) => val,
288 _ => return None,
289 };
290
291 let processed_rows = match datums[1 + pk_len + 1]? {
293 ScalarRefImpl::Int64(val) => val as u64,
294 _ => return None,
295 };
296
297 Some(RefreshProgressEntry {
298 vnode,
299 current_pos,
300 is_completed,
301 processed_rows,
302 })
303 }
304
305 pub fn expected_schema(pk_data_types: &[DataType]) -> Vec<DataType> {
308 let mut schema = vec![DataType::Int32]; schema.extend(pk_data_types.iter().cloned()); schema.push(DataType::Boolean); schema.push(DataType::Int64); schema
313 }
314
315 pub fn column_names(pk_column_names: &[&str]) -> Vec<String> {
318 let mut names = vec!["vnode".to_owned()];
319 for pk_name in pk_column_names {
320 names.push(format!("pos_{}", pk_name));
321 }
322 names.push("is_completed".to_owned());
323 names.push("processed_rows".to_owned());
324 names
325 }
326
327 pub fn vnodes(&self) -> &Arc<Bitmap> {
328 self.state_table.vnodes()
329 }
330}
331
/// Aggregate counters describing refresh progress across tracked vnodes.
#[derive(Debug, Clone)]
pub struct RefreshProgressStats {
    /// Number of vnodes with a progress entry.
    pub total_vnodes: usize,
    /// Number of those vnodes marked completed.
    pub completed_vnodes: usize,
    /// Sum of processed rows over all tracked vnodes.
    pub total_processed_rows: u64,
}

impl RefreshProgressStats {
    /// Returns `true` only when at least one vnode is tracked and the completed count
    /// equals the tracked count.
    pub fn is_complete(&self) -> bool {
        match self.total_vnodes {
            0 => false,
            tracked => tracked == self.completed_vnodes,
        }
    }
}