risingwave_stream/task/barrier_worker/managed_state.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::LazyCell;
16use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
17use std::fmt::{Debug, Display, Formatter};
18use std::future::{Future, pending, poll_fn};
19use std::mem::replace;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use anyhow::anyhow;
25use futures::FutureExt;
26use futures::stream::FuturesOrdered;
27use prometheus::HistogramTimer;
28use risingwave_common::catalog::{DatabaseId, TableId};
29use risingwave_common::util::epoch::EpochPair;
30use risingwave_pb::stream_plan::barrier::BarrierKind;
31use risingwave_pb::stream_service::barrier_complete_response::PbCdcTableBackfillProgress;
32use risingwave_storage::StateStoreImpl;
33use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
34use tokio::sync::{mpsc, oneshot};
35use tokio::task::JoinHandle;
36
37use crate::error::{StreamError, StreamResult};
38use crate::executor::Barrier;
39use crate::executor::monitor::StreamingMetrics;
40use crate::task::progress::BackfillState;
41use crate::task::{
42    ActorId, LocalBarrierEvent, LocalBarrierManager, NewOutputRequest, PartialGraphId,
43    StreamActorManager, UpDownActorIds,
44};
45
46struct IssuedState {
47    /// Actor ids remaining to be collected.
48    pub remaining_actors: BTreeSet<ActorId>,
49
50    pub barrier_inflight_latency: HistogramTimer,
51}
52
53impl Debug for IssuedState {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("IssuedState")
56            .field("remaining_actors", &self.remaining_actors)
57            .finish()
58    }
59}
60
61/// The state machine of local barrier manager.
62#[derive(Debug)]
63enum ManagedBarrierStateInner {
64    /// Meta service has issued a `send_barrier` request. We're collecting barriers now.
65    Issued(IssuedState),
66
67    /// The barrier has been collected by all remaining actors
68    AllCollected {
69        create_mview_progress: Vec<PbCreateMviewProgress>,
70        load_finished_source_ids: Vec<u32>,
71        cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,
72        truncate_tables: Vec<u32>,
73        refresh_finished_tables: Vec<u32>,
74    },
75}
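// Lifecycle of a barrier within one partial graph: it enters `epoch_barrier_state_map` as `Issued`
// with the full set of actors to collect from; once the last actor reports collection,
// `may_have_collected_all` replaces it with `AllCollected`, capturing the per-epoch progress and
// report payloads that are later handed back via `pop_barrier_to_complete` and reported to meta.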
76
77#[derive(Debug)]
78struct BarrierState {
79    barrier: Barrier,
80    /// Only `Some(_)` when `barrier.kind` is `Checkpoint`.
81    table_ids: Option<HashSet<TableId>>,
82    inner: ManagedBarrierStateInner,
83}
84
85use risingwave_common::must_match;
86use risingwave_pb::stream_plan::SubscriptionUpstreamInfo;
87use risingwave_pb::stream_service::InjectBarrierRequest;
88use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress;
89use risingwave_pb::stream_service::streaming_control_stream_request::{
90    DatabaseInitialPartialGraph, InitialPartialGraph,
91};
92
93use crate::executor::exchange::permit;
94use crate::executor::exchange::permit::channel_from_config;
95use crate::task::barrier_worker::ScoredStreamError;
96use crate::task::barrier_worker::await_epoch_completed_future::AwaitEpochCompletedFuture;
97use crate::task::cdc_progress::CdcTableBackfillState;
98
99pub(super) struct ManagedBarrierStateDebugInfo<'a> {
100    running_actors: BTreeSet<ActorId>,
101    graph_states: &'a HashMap<PartialGraphId, PartialGraphManagedBarrierState>,
102}
103
104impl Display for ManagedBarrierStateDebugInfo<'_> {
105    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
106        write!(f, "running_actors: ")?;
107        for actor_id in &self.running_actors {
108            write!(f, "{}, ", actor_id)?;
109        }
110        for (partial_graph_id, graph_states) in self.graph_states {
111            writeln!(f, "--- Partial Group {}", partial_graph_id.0)?;
112            write!(f, "{}", graph_states)?;
113        }
114        Ok(())
115    }
116}
117
118impl Display for &'_ PartialGraphManagedBarrierState {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        let mut prev_epoch = 0u64;
121        for (epoch, barrier_state) in &self.epoch_barrier_state_map {
122            write!(f, "> Epoch {}: ", epoch)?;
123            match &barrier_state.inner {
124                ManagedBarrierStateInner::Issued(state) => {
125                    write!(
126                        f,
127                        "Issued [{:?}]. Remaining actors: [",
128                        barrier_state.barrier.kind
129                    )?;
130                    let mut is_prev_epoch_issued = false;
131                    if prev_epoch != 0 {
132                        let bs = &self.epoch_barrier_state_map[&prev_epoch];
133                        if let ManagedBarrierStateInner::Issued(IssuedState {
134                            remaining_actors: remaining_actors_prev,
135                            ..
136                        }) = &bs.inner
137                        {
138                            // Only show the actors that are not in the previous epoch.
139                            is_prev_epoch_issued = true;
140                            let mut duplicates = 0usize;
141                            for actor_id in &state.remaining_actors {
142                                if !remaining_actors_prev.contains(actor_id) {
143                                    write!(f, "{}, ", actor_id)?;
144                                } else {
145                                    duplicates += 1;
146                                }
147                            }
148                            if duplicates > 0 {
149                                write!(f, "...and {} actors in prev epoch", duplicates)?;
150                            }
151                        }
152                    }
153                    if !is_prev_epoch_issued {
154                        for actor_id in &state.remaining_actors {
155                            write!(f, "{}, ", actor_id)?;
156                        }
157                    }
158                    write!(f, "]")?;
159                }
160                ManagedBarrierStateInner::AllCollected { .. } => {
161                    write!(f, "AllCollected")?;
162                }
163            }
164            prev_epoch = *epoch;
165            writeln!(f)?;
166        }
167
168        if !self.create_mview_progress.is_empty() {
169            writeln!(f, "Create MView Progress:")?;
170            for (epoch, progress) in &self.create_mview_progress {
171                write!(f, "> Epoch {}:", epoch)?;
172                for (actor_id, state) in progress {
173                    write!(f, ">> Actor {}: {}, ", actor_id, state)?;
174                }
175            }
176        }
177
178        Ok(())
179    }
180}
181
182enum InflightActorStatus {
183    /// The actor has been issued some barriers, but has not collected the first barrier
184    IssuedFirst(Vec<Barrier>),
185    /// The actor has been issued some barriers, and has collected the first barrier
186    Running(u64),
187}
188
189impl InflightActorStatus {
190    fn max_issued_epoch(&self) -> u64 {
191        match self {
192            InflightActorStatus::Running(epoch) => *epoch,
193            InflightActorStatus::IssuedFirst(issued_barriers) => {
194                issued_barriers.last().expect("non-empty").epoch.prev
195            }
196        }
197    }
198}
199
200pub(crate) struct InflightActorState {
201    actor_id: ActorId,
202    barrier_senders: Vec<mpsc::UnboundedSender<Barrier>>,
203    /// `prev_epoch` -> partial graph id
204    pub(crate) inflight_barriers: BTreeMap<u64, PartialGraphId>,
205    status: InflightActorStatus,
206    /// Whether the actor has been issued a stop barrier
207    is_stopping: bool,
208
209    new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
210    join_handle: JoinHandle<()>,
211    monitor_task_handle: Option<JoinHandle<()>>,
212}
213
214impl InflightActorState {
215    pub(super) fn start(
216        actor_id: ActorId,
217        initial_partial_graph_id: PartialGraphId,
218        initial_barrier: &Barrier,
219        new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
220        join_handle: JoinHandle<()>,
221        monitor_task_handle: Option<JoinHandle<()>>,
222    ) -> Self {
223        Self {
224            actor_id,
225            barrier_senders: vec![],
226            inflight_barriers: BTreeMap::from_iter([(
227                initial_barrier.epoch.prev,
228                initial_partial_graph_id,
229            )]),
230            status: InflightActorStatus::IssuedFirst(vec![initial_barrier.clone()]),
231            is_stopping: false,
232            new_output_request_tx,
233            join_handle,
234            monitor_task_handle,
235        }
236    }
237
238    pub(super) fn issue_barrier(
239        &mut self,
240        partial_graph_id: PartialGraphId,
241        barrier: &Barrier,
242        is_stop: bool,
243    ) -> StreamResult<()> {
244        assert!(barrier.epoch.prev > self.status.max_issued_epoch());
245
246        for barrier_sender in &self.barrier_senders {
247            barrier_sender.send(barrier.clone()).map_err(|_| {
248                StreamError::barrier_send(
249                    barrier.clone(),
250                    self.actor_id,
251                    "failed to send to registered sender",
252                )
253            })?;
254        }
255
256        assert!(
257            self.inflight_barriers
258                .insert(barrier.epoch.prev, partial_graph_id)
259                .is_none()
260        );
261
262        match &mut self.status {
263            InflightActorStatus::IssuedFirst(pending_barriers) => {
264                pending_barriers.push(barrier.clone());
265            }
266            InflightActorStatus::Running(prev_epoch) => {
267                *prev_epoch = barrier.epoch.prev;
268            }
269        };
270
271        if is_stop {
272            assert!(!self.is_stopping, "stopped actor should not issue barrier");
273            self.is_stopping = true;
274        }
275        Ok(())
276    }
277
278    pub(super) fn collect(&mut self, epoch: EpochPair) -> (PartialGraphId, bool) {
279        let (prev_epoch, prev_partial_graph_id) =
280            self.inflight_barriers.pop_first().expect("should exist");
281        assert_eq!(prev_epoch, epoch.prev);
282        match &self.status {
283            InflightActorStatus::IssuedFirst(pending_barriers) => {
284                assert_eq!(
285                    prev_epoch,
286                    pending_barriers.first().expect("non-empty").epoch.prev
287                );
288                self.status = InflightActorStatus::Running(
289                    pending_barriers.last().expect("non-empty").epoch.prev,
290                );
291            }
292            InflightActorStatus::Running(_) => {}
293        }
294        (
295            prev_partial_graph_id,
296            self.inflight_barriers.is_empty() && self.is_stopping,
297        )
298    }
299}
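// Per-actor bookkeeping: `issue_barrier` records the barrier's `prev_epoch` -> partial graph id in
// `inflight_barriers` and forwards the barrier to every registered sender; `collect` pops entries in
// epoch order and returns `true` as its second element once a stopping actor has drained all of its
// in-flight barriers, signalling that the actor's state can be removed.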
300
301/// Part of [`DatabaseManagedBarrierState`]
302pub(crate) struct PartialGraphManagedBarrierState {
303    /// Record barrier state for each epoch of concurrent checkpoints.
304    ///
305    /// The key is `prev_epoch`; the value tracks the issued barrier (whose `EpochPair` gives `curr_epoch`) and its collection state.
306    epoch_barrier_state_map: BTreeMap<u64, BarrierState>,
307
308    prev_barrier_table_ids: Option<(EpochPair, HashSet<TableId>)>,
309
310    mv_depended_subscriptions: HashMap<TableId, HashSet<u32>>,
311
312    /// Record the progress updates of creating mviews for each epoch of concurrent checkpoints.
313    ///
314    /// The process of progress reporting is as follows:
315    /// 1. updated by [`crate::task::barrier_manager::CreateMviewProgressReporter::update`]
316    /// 2. converted to [`ManagedBarrierStateInner`] in [`Self::may_have_collected_all`]
317    /// 3. handled by [`Self::pop_barrier_to_complete`]
318    /// 4. put in [`crate::task::barrier_worker::BarrierCompleteResult`] and reported to meta.
319    pub(crate) create_mview_progress: HashMap<u64, HashMap<ActorId, BackfillState>>,
320
321    /// Record the source load finished reports for each epoch of concurrent checkpoints.
322    /// Used for refreshable batch source.
323    pub(crate) load_finished_source_ids: HashMap<u64, HashSet<u32>>,
324
325    pub(crate) cdc_table_backfill_progress: HashMap<u64, HashMap<ActorId, CdcTableBackfillState>>,
326
327    /// Record the tables to truncate for each epoch of concurrent checkpoints.
328    pub(crate) truncate_tables: HashMap<u64, HashSet<u32>>,
329    /// Record the tables that have finished refresh for each epoch of concurrent checkpoints.
330    /// Used for materialized view refresh completion reporting.
331    pub(crate) refresh_finished_tables: HashMap<u64, HashSet<u32>>,
332
333    state_store: StateStoreImpl,
334
335    streaming_metrics: Arc<StreamingMetrics>,
336}
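// All of the per-epoch report maps above (`create_mview_progress`, `load_finished_source_ids`,
// `cdc_table_backfill_progress`, `truncate_tables`, `refresh_finished_tables`) are keyed by
// `curr_epoch` and drained in `may_have_collected_all` once the corresponding barrier has been
// collected from all actors.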
337
338impl PartialGraphManagedBarrierState {
339    pub(super) fn new(actor_manager: &StreamActorManager) -> Self {
340        Self::new_inner(
341            actor_manager.env.state_store(),
342            actor_manager.streaming_metrics.clone(),
343        )
344    }
345
346    fn new_inner(state_store: StateStoreImpl, streaming_metrics: Arc<StreamingMetrics>) -> Self {
347        Self {
348            epoch_barrier_state_map: Default::default(),
349            prev_barrier_table_ids: None,
350            mv_depended_subscriptions: Default::default(),
351            create_mview_progress: Default::default(),
352            load_finished_source_ids: Default::default(),
353            cdc_table_backfill_progress: Default::default(),
354            truncate_tables: Default::default(),
355            refresh_finished_tables: Default::default(),
356            state_store,
357            streaming_metrics,
358        }
359    }
360
361    #[cfg(test)]
362    pub(crate) fn for_test() -> Self {
363        Self::new_inner(
364            StateStoreImpl::for_test(),
365            Arc::new(StreamingMetrics::unused()),
366        )
367    }
368
369    pub(super) fn is_empty(&self) -> bool {
370        self.epoch_barrier_state_map.is_empty()
371    }
372}
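// Illustrative flow, mirroring the unit tests at the bottom of this file (a sketch only, not a
// doctest: these APIs are crate-private and `for_test` is gated behind `cfg(test)`):
//
//     let mut state = PartialGraphManagedBarrierState::for_test();
//     let barrier = Barrier::new_test_barrier(test_epoch(1));
//     state.transform_to_issued(&barrier, [1, 2], HashSet::new());
//     state.collect(1, barrier.epoch);
//     state.collect(2, barrier.epoch);
//     assert!(state.may_have_collected_all().is_some());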
373
374pub(crate) struct SuspendedDatabaseState {
375    pub(super) suspend_time: Instant,
376    inner: DatabaseManagedBarrierState,
377    failure: Option<(Option<ActorId>, StreamError)>,
378}
379
380impl SuspendedDatabaseState {
381    fn new(
382        state: DatabaseManagedBarrierState,
383        failure: Option<(Option<ActorId>, StreamError)>,
384        _completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>, /* discard the completing futures */
385    ) -> Self {
386        Self {
387            suspend_time: Instant::now(),
388            inner: state,
389            failure,
390        }
391    }
392
393    async fn reset(mut self) -> ResetDatabaseOutput {
394        let root_err = self.inner.try_find_root_actor_failure(self.failure).await;
395        self.inner.abort_and_wait_actors().await;
396        if let Some(hummock) = self.inner.actor_manager.env.state_store().as_hummock() {
397            hummock.clear_tables(self.inner.table_ids).await;
398        }
399        ResetDatabaseOutput { root_err }
400    }
401}
402
403pub(crate) struct ResettingDatabaseState {
404    join_handle: JoinHandle<ResetDatabaseOutput>,
405    reset_request_id: u32,
406}
407
408pub(crate) struct ResetDatabaseOutput {
409    pub(crate) root_err: Option<ScoredStreamError>,
410}
411
412pub(crate) enum DatabaseStatus {
413    ReceivedExchangeRequest(
414        Vec<(
415            UpDownActorIds,
416            oneshot::Sender<StreamResult<permit::Receiver>>,
417        )>,
418    ),
419    Running(DatabaseManagedBarrierState),
420    Suspended(SuspendedDatabaseState),
421    Resetting(ResettingDatabaseState),
422    /// Temporary placeholder.
423    Unspecified,
424}
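// Status transitions: a database normally sits in `Running`; on a failure it moves to `Suspended`
// (keeping the failure around for root-cause analysis) and, once meta requests it, to `Resetting`,
// which runs a background task that aborts all actors and clears the database's tables from the
// state store. `ReceivedExchangeRequest` buffers exchange requests that arrive before the database
// has a running state, and `Unspecified` is only a transient placeholder used while swapping states.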
425
426impl DatabaseStatus {
427    pub(crate) async fn abort(&mut self) {
428        match self {
429            DatabaseStatus::ReceivedExchangeRequest(pending_requests) => {
430                for (_, sender) in pending_requests.drain(..) {
431                    let _ = sender.send(Err(anyhow!("database aborted").into()));
432                }
433            }
434            DatabaseStatus::Running(state) => {
435                state.abort_and_wait_actors().await;
436            }
437            DatabaseStatus::Suspended(SuspendedDatabaseState { inner: state, .. }) => {
438                state.abort_and_wait_actors().await;
439            }
440            DatabaseStatus::Resetting(state) => {
441                (&mut state.join_handle)
442                    .await
443                    .expect("failed to join reset database join handle");
444            }
445            DatabaseStatus::Unspecified => {
446                unreachable!()
447            }
448        }
449    }
450
451    pub(crate) fn state_for_request(&mut self) -> Option<&mut DatabaseManagedBarrierState> {
452        match self {
453            DatabaseStatus::ReceivedExchangeRequest(_) => {
454                unreachable!("should not handle request")
455            }
456            DatabaseStatus::Running(state) => Some(state),
457            DatabaseStatus::Suspended(_) => None,
458            DatabaseStatus::Resetting(_) => {
459                unreachable!("should not receive further request during cleaning")
460            }
461            DatabaseStatus::Unspecified => {
462                unreachable!()
463            }
464        }
465    }
466
467    pub(super) fn poll_next_event(
468        &mut self,
469        cx: &mut Context<'_>,
470    ) -> Poll<ManagedBarrierStateEvent> {
471        match self {
472            DatabaseStatus::ReceivedExchangeRequest(_) => Poll::Pending,
473            DatabaseStatus::Running(state) => state.poll_next_event(cx),
474            DatabaseStatus::Suspended(_) => Poll::Pending,
475            DatabaseStatus::Resetting(state) => state.join_handle.poll_unpin(cx).map(|result| {
476                let output = result.expect("should be able to join");
477                ManagedBarrierStateEvent::DatabaseReset(output, state.reset_request_id)
478            }),
479            DatabaseStatus::Unspecified => {
480                unreachable!()
481            }
482        }
483    }
484
485    pub(super) fn suspend(
486        &mut self,
487        failed_actor: Option<ActorId>,
488        err: StreamError,
489        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
490    ) {
491        let state = must_match!(replace(self, DatabaseStatus::Unspecified), DatabaseStatus::Running(state) => state);
492        *self = DatabaseStatus::Suspended(SuspendedDatabaseState::new(
493            state,
494            Some((failed_actor, err)),
495            completing_futures,
496        ));
497    }
498
499    pub(super) fn start_reset(
500        &mut self,
501        database_id: DatabaseId,
502        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
503        reset_request_id: u32,
504    ) {
505        let join_handle = match replace(self, DatabaseStatus::Unspecified) {
506            DatabaseStatus::ReceivedExchangeRequest(pending_requests) => {
507                for (_, sender) in pending_requests {
508                    let _ = sender.send(Err(anyhow!("database reset").into()));
509                }
510                tokio::spawn(async move { ResetDatabaseOutput { root_err: None } })
511            }
512            DatabaseStatus::Running(state) => {
513                assert_eq!(database_id, state.database_id);
514                info!(
515                    database_id = database_id.database_id,
516                    reset_request_id, "start database reset from Running"
517                );
518                tokio::spawn(SuspendedDatabaseState::new(state, None, completing_futures).reset())
519            }
520            DatabaseStatus::Suspended(state) => {
521                assert!(
522                    completing_futures.is_none(),
523                    "should have been cleared when suspended"
524                );
525                assert_eq!(database_id, state.inner.database_id);
526                info!(
527                    database_id = database_id.database_id,
528                    reset_request_id,
529                    suspend_elapsed = ?state.suspend_time.elapsed(),
530                    "start database reset after suspended"
531                );
532                tokio::spawn(state.reset())
533            }
534            DatabaseStatus::Resetting(state) => {
535                let prev_request_id = state.reset_request_id;
536                info!(
537                    database_id = database_id.database_id,
538                    reset_request_id, prev_request_id, "receive duplicate reset request"
539                );
540                assert!(reset_request_id > prev_request_id);
541                state.join_handle
542            }
543            DatabaseStatus::Unspecified => {
544                unreachable!()
545            }
546        };
547        *self = DatabaseStatus::Resetting(ResettingDatabaseState {
548            join_handle,
549            reset_request_id,
550        });
551    }
552}
553
554pub(crate) struct ManagedBarrierState {
555    pub(crate) databases: HashMap<DatabaseId, DatabaseStatus>,
556}
557
558pub(super) enum ManagedBarrierStateEvent {
559    BarrierCollected {
560        partial_graph_id: PartialGraphId,
561        barrier: Barrier,
562    },
563    ActorError {
564        actor_id: ActorId,
565        err: StreamError,
566    },
567    DatabaseReset(ResetDatabaseOutput, u32),
568}
569
570impl ManagedBarrierState {
571    pub(super) fn new(
572        actor_manager: Arc<StreamActorManager>,
573        initial_partial_graphs: Vec<DatabaseInitialPartialGraph>,
574        term_id: String,
575    ) -> Self {
576        let mut databases = HashMap::new();
577        for database in initial_partial_graphs {
578            let database_id = DatabaseId::new(database.database_id);
579            assert!(!databases.contains_key(&database_id));
580            let state = DatabaseManagedBarrierState::new(
581                database_id,
582                term_id.clone(),
583                actor_manager.clone(),
584                database.graphs,
585            );
586            databases.insert(database_id, DatabaseStatus::Running(state));
587        }
588
589        Self { databases }
590    }
591
592    pub(super) fn next_event(
593        &mut self,
594    ) -> impl Future<Output = (DatabaseId, ManagedBarrierStateEvent)> + '_ {
595        poll_fn(|cx| {
596            for (database_id, database) in &mut self.databases {
597                if let Poll::Ready(event) = database.poll_next_event(cx) {
598                    return Poll::Ready((*database_id, event));
599                }
600            }
601            Poll::Pending
602        })
603    }
604}
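// `next_event` polls every database in turn and yields the first ready event together with its
// `DatabaseId`. When nothing is ready the underlying receivers have registered the waker, so the
// returned future is woken as soon as any database produces an event.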
605
606/// Per-database barrier state manager. Handles barriers for one specific database.
607/// Part of [`ManagedBarrierState`] in [`super::LocalBarrierWorker`].
608///
609/// See [`crate::task`] for architecture overview.
610pub(crate) struct DatabaseManagedBarrierState {
611    database_id: DatabaseId,
612    pub(crate) actor_states: HashMap<ActorId, InflightActorState>,
613    pub(super) actor_pending_new_output_requests:
614        HashMap<ActorId, Vec<(ActorId, NewOutputRequest)>>,
615
616    pub(crate) graph_states: HashMap<PartialGraphId, PartialGraphManagedBarrierState>,
617
618    table_ids: HashSet<TableId>,
619
620    actor_manager: Arc<StreamActorManager>,
621
622    pub(super) local_barrier_manager: LocalBarrierManager,
623
624    barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,
625    pub(super) actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
626}
627
628impl DatabaseManagedBarrierState {
629    /// Create a barrier manager state. This will be called only once.
630    pub(super) fn new(
631        database_id: DatabaseId,
632        term_id: String,
633        actor_manager: Arc<StreamActorManager>,
634        initial_partial_graphs: Vec<InitialPartialGraph>,
635    ) -> Self {
636        let (local_barrier_manager, barrier_event_rx, actor_failure_rx) =
637            LocalBarrierManager::new(database_id, term_id, actor_manager.env.clone());
638        Self {
639            database_id,
640            actor_states: Default::default(),
641            actor_pending_new_output_requests: Default::default(),
642            graph_states: initial_partial_graphs
643                .into_iter()
644                .map(|graph| {
645                    let mut state = PartialGraphManagedBarrierState::new(&actor_manager);
646                    state.add_subscriptions(graph.subscriptions);
647                    (PartialGraphId::new(graph.partial_graph_id), state)
648                })
649                .collect(),
650            table_ids: Default::default(),
651            actor_manager,
652            local_barrier_manager,
653            barrier_event_rx,
654            actor_failure_rx,
655        }
656    }
657
658    pub(super) fn to_debug_info(&self) -> ManagedBarrierStateDebugInfo<'_> {
659        ManagedBarrierStateDebugInfo {
660            running_actors: self.actor_states.keys().cloned().collect(),
661            graph_states: &self.graph_states,
662        }
663    }
664
665    async fn abort_and_wait_actors(&mut self) {
666        for (actor_id, state) in &self.actor_states {
667            tracing::debug!("force stopping actor {}", actor_id);
668            state.join_handle.abort();
669            if let Some(monitor_task_handle) = &state.monitor_task_handle {
670                monitor_task_handle.abort();
671            }
672        }
673
674        for (actor_id, state) in self.actor_states.drain() {
675            tracing::debug!("join actor {}", actor_id);
676            let result = state.join_handle.await;
677            assert!(result.is_ok() || result.unwrap_err().is_cancelled());
678        }
679    }
680}
681
682impl InflightActorState {
683    pub(super) fn register_barrier_sender(
684        &mut self,
685        tx: mpsc::UnboundedSender<Barrier>,
686    ) -> StreamResult<()> {
687        match &self.status {
688            InflightActorStatus::IssuedFirst(pending_barriers) => {
689                for barrier in pending_barriers {
690                    tx.send(barrier.clone()).map_err(|_| {
691                        StreamError::barrier_send(
692                            barrier.clone(),
693                            self.actor_id,
694                            "failed to send pending barriers to newly registered sender",
695                        )
696                    })?;
697                }
698                self.barrier_senders.push(tx);
699            }
700            InflightActorStatus::Running(_) => {
701                unreachable!("should not register barrier sender when entering Running status")
702            }
703        }
704        Ok(())
705    }
706}
707
708impl DatabaseManagedBarrierState {
709    pub(super) fn register_barrier_sender(
710        &mut self,
711        actor_id: ActorId,
712        tx: mpsc::UnboundedSender<Barrier>,
713    ) -> StreamResult<()> {
714        self.actor_states
715            .get_mut(&actor_id)
716            .expect("should exist")
717            .register_barrier_sender(tx)
718    }
719}
720
721impl PartialGraphManagedBarrierState {
722    pub(super) fn add_subscriptions(&mut self, subscriptions: Vec<SubscriptionUpstreamInfo>) {
723        for subscription_to_add in subscriptions {
724            if !self
725                .mv_depended_subscriptions
726                .entry(TableId::new(subscription_to_add.upstream_mv_table_id))
727                .or_default()
728                .insert(subscription_to_add.subscriber_id)
729            {
730                if cfg!(debug_assertions) {
731                    panic!("add an existing subscription: {:?}", subscription_to_add);
732                }
733                warn!(?subscription_to_add, "add an existing subscription");
734            }
735        }
736    }
737
738    pub(super) fn remove_subscriptions(&mut self, subscriptions: Vec<SubscriptionUpstreamInfo>) {
739        for subscription_to_remove in subscriptions {
740            let upstream_table_id = TableId::new(subscription_to_remove.upstream_mv_table_id);
741            let Some(subscribers) = self.mv_depended_subscriptions.get_mut(&upstream_table_id)
742            else {
743                if cfg!(debug_assertions) {
744                    panic!(
745                        "unable to find upstream mv table to remove: {:?}",
746                        subscription_to_remove
747                    );
748                }
749                warn!(
750                    ?subscription_to_remove,
751                    "unable to find upstream mv table to remove"
752                );
753                continue;
754            };
755            if !subscribers.remove(&subscription_to_remove.subscriber_id) {
756                if cfg!(debug_assertions) {
757                    panic!(
758                        "unable to find subscriber to remove: {:?}",
759                        subscription_to_remove
760                    );
761                }
762                warn!(
763                    ?subscription_to_remove,
764                    "unable to find subscriber to remove"
765                );
766            }
767            if subscribers.is_empty() {
768                self.mv_depended_subscriptions.remove(&upstream_table_id);
769            }
770        }
771    }
772}
773
774impl DatabaseManagedBarrierState {
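    /// Handle an `InjectBarrierRequest` for one partial graph: update the graph's subscriptions,
    /// mark the barrier as issued on the graph state (which also starts the epoch on the Hummock
    /// state store, if any), spawn any newly built actors, and finally forward the barrier to the
    /// pre-existing actors that are expected to collect it.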
775    pub(super) fn transform_to_issued(
776        &mut self,
777        barrier: &Barrier,
778        request: InjectBarrierRequest,
779    ) -> StreamResult<()> {
780        let partial_graph_id = PartialGraphId::new(request.partial_graph_id);
781        let actor_to_stop = barrier.all_stop_actors();
782        let is_stop_actor = |actor_id| {
783            actor_to_stop
784                .map(|actors| actors.contains(&actor_id))
785                .unwrap_or(false)
786        };
787        let graph_state = self
788            .graph_states
789            .get_mut(&partial_graph_id)
790            .expect("should exist");
791
792        graph_state.add_subscriptions(request.subscriptions_to_add);
793        graph_state.remove_subscriptions(request.subscriptions_to_remove);
794
795        let table_ids =
796            HashSet::from_iter(request.table_ids_to_sync.iter().cloned().map(TableId::new));
797        self.table_ids.extend(table_ids.iter().cloned());
798
799        graph_state.transform_to_issued(
800            barrier,
801            request.actor_ids_to_collect.iter().cloned(),
802            table_ids,
803        );
804
805        let mut new_actors = HashSet::new();
806        let subscriptions =
807            LazyCell::new(|| Arc::new(graph_state.mv_depended_subscriptions.clone()));
808        for (node, fragment_id, actor) in
809            request
810                .actors_to_build
811                .into_iter()
812                .flat_map(|fragment_actors| {
813                    let node = Arc::new(fragment_actors.node.unwrap());
814                    fragment_actors
815                        .actors
816                        .into_iter()
817                        .map(move |actor| (node.clone(), fragment_actors.fragment_id, actor))
818                })
819        {
820            let actor_id = actor.actor_id;
821            assert!(!is_stop_actor(actor_id));
822            assert!(new_actors.insert(actor_id));
823            assert!(request.actor_ids_to_collect.contains(&actor_id));
824            let (new_output_request_tx, new_output_request_rx) = unbounded_channel();
825            if let Some(pending_requests) = self.actor_pending_new_output_requests.remove(&actor_id)
826            {
827                for request in pending_requests {
828                    let _ = new_output_request_tx.send(request);
829                }
830            }
831            let (join_handle, monitor_join_handle) = self.actor_manager.spawn_actor(
832                actor,
833                fragment_id,
834                node,
835                (*subscriptions).clone(),
836                self.local_barrier_manager.clone(),
837                new_output_request_rx,
838            );
839            assert!(
840                self.actor_states
841                    .try_insert(
842                        actor_id,
843                        InflightActorState::start(
844                            actor_id,
845                            partial_graph_id,
846                            barrier,
847                            new_output_request_tx,
848                            join_handle,
849                            monitor_join_handle
850                        )
851                    )
852                    .is_ok()
853            );
854        }
855
856        // Spawn a trivial join handle to stay compatible with the unit tests. In the unit tests that involve the local barrier manager,
857        // actors are spawned by the local test logic, but we assume that there is an entry for each spawned actor in `actor_states`,
858        // so under cfg!(test) we add a dummy entry for each new actor.
859        if cfg!(test) {
860            for actor_id in &request.actor_ids_to_collect {
861                if !self.actor_states.contains_key(actor_id) {
862                    let (tx, rx) = unbounded_channel();
863                    let join_handle = self.actor_manager.runtime.spawn(async move {
864                        // The rx is spawned so that tx.send() will not fail.
865                        let _ = rx;
866                        pending().await
867                    });
868                    assert!(
869                        self.actor_states
870                            .try_insert(
871                                *actor_id,
872                                InflightActorState::start(
873                                    *actor_id,
874                                    partial_graph_id,
875                                    barrier,
876                                    tx,
877                                    join_handle,
878                                    None,
879                                )
880                            )
881                            .is_ok()
882                    );
883                    new_actors.insert(*actor_id);
884                }
885            }
886        }
887
888        // Note: it's important to issue the barrier to actors only after issuing it to the graph, to ensure
889        // that `start_epoch` is called on the graph state before the actors receive the barrier.
890        for actor_id in &request.actor_ids_to_collect {
891            if new_actors.contains(actor_id) {
892                continue;
893            }
894            self.actor_states
895                .get_mut(actor_id)
896                .unwrap_or_else(|| {
897                    panic!(
898                        "should exist: {} {:?}",
899                        actor_id, request.actor_ids_to_collect
900                    );
901                })
902                .issue_barrier(partial_graph_id, barrier, is_stop_actor(*actor_id))?;
903        }
904
905        Ok(())
906    }
907
908    pub(super) fn new_actor_remote_output_request(
909        &mut self,
910        actor_id: ActorId,
911        upstream_actor_id: ActorId,
912        result_sender: oneshot::Sender<StreamResult<permit::Receiver>>,
913    ) {
914        let (tx, rx) = channel_from_config(self.local_barrier_manager.env.config());
915        self.new_actor_output_request(actor_id, upstream_actor_id, NewOutputRequest::Remote(tx));
916        let _ = result_sender.send(Ok(rx));
917    }
918
919    pub(super) fn new_actor_output_request(
920        &mut self,
921        actor_id: ActorId,
922        upstream_actor_id: ActorId,
923        request: NewOutputRequest,
924    ) {
925        if let Some(actor) = self.actor_states.get_mut(&upstream_actor_id) {
926            let _ = actor.new_output_request_tx.send((actor_id, request));
927        } else {
928            self.actor_pending_new_output_requests
929                .entry(upstream_actor_id)
930                .or_default()
931                .push((actor_id, request));
932        }
933    }
934
935    /// Handles [`LocalBarrierEvent`] from [`crate::task::barrier_manager::LocalBarrierManager`].
936    pub(super) fn poll_next_event(
937        &mut self,
938        cx: &mut Context<'_>,
939    ) -> Poll<ManagedBarrierStateEvent> {
940        if let Poll::Ready(option) = self.actor_failure_rx.poll_recv(cx) {
941            let (actor_id, err) = option.expect("non-empty when tx in local_barrier_manager");
942            return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
943        }
944        // yield some pending collected epochs
945        for (partial_graph_id, graph_state) in &mut self.graph_states {
946            if let Some(barrier) = graph_state.may_have_collected_all() {
947                return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
948                    partial_graph_id: *partial_graph_id,
949                    barrier,
950                });
951            }
952        }
953        while let Poll::Ready(event) = self.barrier_event_rx.poll_recv(cx) {
954            match event.expect("non-empty when tx in local_barrier_manager") {
955                LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => {
956                    if let Some((partial_graph_id, barrier)) = self.collect(actor_id, epoch) {
957                        return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
958                            partial_graph_id,
959                            barrier,
960                        });
961                    }
962                }
963                LocalBarrierEvent::ReportCreateProgress {
964                    epoch,
965                    actor,
966                    state,
967                } => {
968                    self.update_create_mview_progress(epoch, actor, state);
969                }
970                LocalBarrierEvent::ReportSourceLoadFinished {
971                    epoch,
972                    actor_id,
973                    table_id,
974                    associated_source_id,
975                } => {
976                    self.report_source_load_finished(
977                        epoch,
978                        actor_id,
979                        table_id,
980                        associated_source_id,
981                    );
982                }
983                LocalBarrierEvent::RefreshFinished {
984                    epoch,
985                    actor_id,
986                    table_id,
987                    staging_table_id,
988                } => {
989                    self.report_refresh_finished(epoch, actor_id, table_id, staging_table_id);
990                }
991                LocalBarrierEvent::RegisterBarrierSender {
992                    actor_id,
993                    barrier_sender,
994                } => {
995                    if let Err(err) = self.register_barrier_sender(actor_id, barrier_sender) {
996                        return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
997                    }
998                }
999                LocalBarrierEvent::RegisterLocalUpstreamOutput {
1000                    actor_id,
1001                    upstream_actor_id,
1002                    tx,
1003                } => {
1004                    self.new_actor_output_request(
1005                        actor_id,
1006                        upstream_actor_id,
1007                        NewOutputRequest::Local(tx),
1008                    );
1009                }
1010                LocalBarrierEvent::ReportCdcTableBackfillProgress {
1011                    actor_id,
1012                    epoch,
1013                    state,
1014                } => {
1015                    self.update_cdc_table_backfill_progress(epoch, actor_id, state);
1016                }
1017            }
1018        }
1019
1020        debug_assert!(
1021            self.graph_states
1022                .values_mut()
1023                .all(|graph_state| graph_state.may_have_collected_all().is_none())
1024        );
1025        Poll::Pending
1026    }
1027}
1028
1029impl DatabaseManagedBarrierState {
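    /// Record that `actor_id` has collected the barrier for `epoch`: removes the actor's state once
    /// it has finished stopping, forwards the collection to the owning partial graph, and returns
    /// the graph id and barrier if this collection made the graph's oldest barrier fully collected.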
1030    #[must_use]
1031    pub(super) fn collect(
1032        &mut self,
1033        actor_id: ActorId,
1034        epoch: EpochPair,
1035    ) -> Option<(PartialGraphId, Barrier)> {
1036        let (prev_partial_graph_id, is_finished) = self
1037            .actor_states
1038            .get_mut(&actor_id)
1039            .expect("should exist")
1040            .collect(epoch);
1041        if is_finished {
1042            let state = self.actor_states.remove(&actor_id).expect("should exist");
1043            if let Some(monitor_task_handle) = state.monitor_task_handle {
1044                monitor_task_handle.abort();
1045            }
1046        }
1047        let prev_graph_state = self
1048            .graph_states
1049            .get_mut(&prev_partial_graph_id)
1050            .expect("should exist");
1051        prev_graph_state.collect(actor_id, epoch);
1052        prev_graph_state
1053            .may_have_collected_all()
1054            .map(|barrier| (prev_partial_graph_id, barrier))
1055    }
1056
1057    #[allow(clippy::type_complexity)]
1058    pub(super) fn pop_barrier_to_complete(
1059        &mut self,
1060        partial_graph_id: PartialGraphId,
1061        prev_epoch: u64,
1062    ) -> BarrierToComplete {
1063        self.graph_states
1064            .get_mut(&partial_graph_id)
1065            .expect("should exist")
1066            .pop_barrier_to_complete(prev_epoch)
1067    }
1068
1069    /// Collect actor errors for a while and find the one that might be the root cause.
1070    ///
1071    /// Returns `None` if there's no actor error received.
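    ///
    /// Waits a few seconds for additional actor errors to arrive, then scores every gathered error
    /// with `with_score` and returns the highest-scoring one as the probable root cause.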
1072    async fn try_find_root_actor_failure(
1073        &mut self,
1074        first_failure: Option<(Option<ActorId>, StreamError)>,
1075    ) -> Option<ScoredStreamError> {
1076        let mut later_errs = vec![];
1077        // fetch more actor errors within a timeout
1078        let _ = tokio::time::timeout(Duration::from_secs(3), async {
1079            let mut uncollected_actors: HashSet<_> = self.actor_states.keys().cloned().collect();
1080            if let Some((Some(failed_actor), _)) = &first_failure {
1081                uncollected_actors.remove(failed_actor);
1082            }
1083            while !uncollected_actors.is_empty()
1084                && let Some((actor_id, error)) = self.actor_failure_rx.recv().await
1085            {
1086                uncollected_actors.remove(&actor_id);
1087                later_errs.push(error);
1088            }
1089        })
1090        .await;
1091
1092        first_failure
1093            .into_iter()
1094            .map(|(_, err)| err)
1095            .chain(later_errs.into_iter())
1096            .map(|e| e.with_score())
1097            .max_by_key(|e| e.score)
1098    }
1099
1100    /// Report that a source has finished loading for a specific epoch
1101    pub(super) fn report_source_load_finished(
1102        &mut self,
1103        epoch: EpochPair,
1104        actor_id: ActorId,
1105        _table_id: u32,
1106        associated_source_id: u32,
1107    ) {
1108        // Find the correct partial graph state by matching the actor's partial graph id
1109        if let Some(actor_state) = self.actor_states.get(&actor_id)
1110            && let Some(partial_graph_id) = actor_state.inflight_barriers.get(&epoch.prev)
1111            && let Some(graph_state) = self.graph_states.get_mut(partial_graph_id)
1112        {
1113            graph_state
1114                .load_finished_source_ids
1115                .entry(epoch.curr)
1116                .or_default()
1117                .insert(associated_source_id);
1118        } else {
1119            warn!(
1120                ?epoch,
1121                actor_id, associated_source_id, "ignore source load finished"
1122            );
1123        }
1124    }
1125
1126    /// Report that a table has finished refreshing for a specific epoch
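    ///
    /// Records both the refreshed `table_id` (reported as refresh-finished) and its
    /// `staging_table_id` (scheduled for truncation) under the same checkpoint epoch.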
1127    pub(super) fn report_refresh_finished(
1128        &mut self,
1129        epoch: EpochPair,
1130        actor_id: ActorId,
1131        table_id: u32,
1132        staging_table_id: u32,
1133    ) {
1134        // Find the correct partial graph state by matching the actor's partial graph id
1135        let Some(actor_state) = self.actor_states.get(&actor_id) else {
1136            warn!(
1137                ?epoch,
1138                actor_id, table_id, "ignore refresh finished table: actor_state not found"
1139            );
1140            return;
1141        };
1142        let Some(partial_graph_id) = actor_state.inflight_barriers.get(&epoch.prev) else {
1143            let inflight_barriers = actor_state.inflight_barriers.keys().collect::<Vec<_>>();
1144            warn!(
1145                ?epoch,
1146                actor_id,
1147                table_id,
1148                ?inflight_barriers,
1149                "ignore refresh finished table: partial_graph_id not found in inflight_barriers"
1150            );
1151            return;
1152        };
1153        let Some(graph_state) = self.graph_states.get_mut(partial_graph_id) else {
1154            warn!(
1155                ?epoch,
1156                actor_id, table_id, "ignore refresh finished table: graph_state not found"
1157            );
1158            return;
1159        };
1160        graph_state
1161            .refresh_finished_tables
1162            .entry(epoch.curr)
1163            .or_default()
1164            .insert(table_id);
1165        graph_state
1166            .truncate_tables
1167            .entry(epoch.curr)
1168            .or_default()
1169            .insert(staging_table_id);
1170    }
1171}
1172
1173impl PartialGraphManagedBarrierState {
1174    /// Called whenever the collection state of a barrier may have changed. If the oldest `Issued`
1175    /// barrier has been collected from all of its remaining actors, transform its state to
1176    /// `AllCollected` and return that barrier so that the caller can start completing it.
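    ///
    /// Epochs are scanned in ascending `prev_epoch` order and the scan stops at the first entry that
    /// is still waiting on actors, so barriers transition to `AllCollected` strictly in epoch order.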
1177    fn may_have_collected_all(&mut self) -> Option<Barrier> {
1178        for barrier_state in self.epoch_barrier_state_map.values_mut() {
1179            match &barrier_state.inner {
1180                ManagedBarrierStateInner::Issued(IssuedState {
1181                    remaining_actors, ..
1182                }) if remaining_actors.is_empty() => {}
1183                ManagedBarrierStateInner::AllCollected { .. } => {
1184                    continue;
1185                }
1186                ManagedBarrierStateInner::Issued(_) => {
1187                    break;
1188                }
1189            }
1190
1191            self.streaming_metrics.barrier_manager_progress.inc();
1192
1193            let create_mview_progress = self
1194                .create_mview_progress
1195                .remove(&barrier_state.barrier.epoch.curr)
1196                .unwrap_or_default()
1197                .into_iter()
1198                .map(|(actor, state)| state.to_pb(actor))
1199                .collect();
1200
1201            let load_finished_source_ids = self
1202                .load_finished_source_ids
1203                .remove(&barrier_state.barrier.epoch.curr)
1204                .unwrap_or_default()
1205                .into_iter()
1206                .collect();
1207
1208            let cdc_table_backfill_progress = self
1209                .cdc_table_backfill_progress
1210                .remove(&barrier_state.barrier.epoch.curr)
1211                .unwrap_or_default()
1212                .into_iter()
1213                .map(|(actor, state)| state.to_pb(actor, barrier_state.barrier.epoch.curr))
1214                .collect();
1215
1216            let truncate_tables = self
1217                .truncate_tables
1218                .remove(&barrier_state.barrier.epoch.curr)
1219                .unwrap_or_default()
1220                .into_iter()
1221                .collect();
1222
1223            let refresh_finished_tables = self
1224                .refresh_finished_tables
1225                .remove(&barrier_state.barrier.epoch.curr)
1226                .unwrap_or_default()
1227                .into_iter()
1228                .collect();
1229            let prev_state = replace(
1230                &mut barrier_state.inner,
1231                ManagedBarrierStateInner::AllCollected {
1232                    create_mview_progress,
1233                    load_finished_source_ids,
1234                    truncate_tables,
1235                    refresh_finished_tables,
1236                    cdc_table_backfill_progress,
1237                },
1238            );
1239
1240            must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState {
1241                barrier_inflight_latency: timer,
1242                ..
1243            }) => {
1244                timer.observe_duration();
1245            });
1246
1247            return Some(barrier_state.barrier.clone());
1248        }
1249        None
1250    }
1251
1252    fn pop_barrier_to_complete(&mut self, prev_epoch: u64) -> BarrierToComplete {
1253        let (popped_prev_epoch, barrier_state) = self
1254            .epoch_barrier_state_map
1255            .pop_first()
1256            .expect("should exist");
1257
1258        assert_eq!(prev_epoch, popped_prev_epoch);
1259
1260        let (
1261            create_mview_progress,
1262            load_finished_source_ids,
1263            cdc_table_backfill_progress,
1264            truncate_tables,
1265            refresh_finished_tables,
1266        ) = must_match!(barrier_state.inner, ManagedBarrierStateInner::AllCollected {
1267            create_mview_progress,
1268            load_finished_source_ids,
1269            truncate_tables,
1270            refresh_finished_tables,
1271            cdc_table_backfill_progress,
1272        } => {
1273            (create_mview_progress, load_finished_source_ids, cdc_table_backfill_progress, truncate_tables, refresh_finished_tables)
1274        });
1275        BarrierToComplete {
1276            barrier: barrier_state.barrier,
1277            table_ids: barrier_state.table_ids,
1278            create_mview_progress,
1279            load_finished_source_ids,
1280            truncate_tables,
1281            refresh_finished_tables,
1282            cdc_table_backfill_progress,
1283        }
1284    }
1285}
1286
1287pub(crate) struct BarrierToComplete {
1288    pub barrier: Barrier,
1289    pub table_ids: Option<HashSet<TableId>>,
1290    pub create_mview_progress: Vec<PbCreateMviewProgress>,
1291    pub load_finished_source_ids: Vec<u32>,
1292    pub truncate_tables: Vec<u32>,
1293    pub refresh_finished_tables: Vec<u32>,
1294    pub cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,
1295}
1296
1297impl PartialGraphManagedBarrierState {
1298    /// Collect a `barrier` from the actor with `actor_id`.
1299    pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) {
1300        tracing::debug!(
1301            target: "events::stream::barrier::manager::collect",
1302            ?epoch, actor_id, state = ?self.epoch_barrier_state_map,
1303            "collect_barrier",
1304        );
1305
1306        match self.epoch_barrier_state_map.get_mut(&epoch.prev) {
1307            None => {
1308                // A missing entry means the barrier has not been injected by the barrier manager yet,
1309                // or the barrier message is still blocked at the `RemoteInput` side waiting for injection.
1310                // In either case, an actor cannot legitimately report collection for this epoch.
1311                panic!(
1312                    "cannot collect new actor barrier {:?} at current state: None",
1313                    epoch,
1314                )
1315            }
1316            Some(&mut BarrierState {
1317                ref barrier,
1318                inner:
1319                    ManagedBarrierStateInner::Issued(IssuedState {
1320                        ref mut remaining_actors,
1321                        ..
1322                    }),
1323                ..
1324            }) => {
1325                let exist = remaining_actors.remove(&actor_id);
1326                assert!(
1327                    exist,
1328                    "the actor doesn't exist. actor_id: {:?}, curr_epoch: {:?}",
1329                    actor_id, epoch.curr
1330                );
1331                assert_eq!(barrier.epoch.curr, epoch.curr);
1332            }
1333            Some(BarrierState { inner, .. }) => {
1334                panic!(
1335                    "cannot collect new actor barrier {:?} at current state: {:?}",
1336                    epoch, inner
1337                )
1338            }
1339        }
1340    }
1341
1342    /// When the meta service issues a `send_barrier` request, call this function to record the
1343    /// barrier as `Issued` and start collecting it from the given actors.
1344    pub(super) fn transform_to_issued(
1345        &mut self,
1346        barrier: &Barrier,
1347        actor_ids_to_collect: impl IntoIterator<Item = ActorId>,
1348        table_ids: HashSet<TableId>,
1349    ) {
1350        let timer = self
1351            .streaming_metrics
1352            .barrier_inflight_latency
1353            .start_timer();
1354
1355        if let Some(hummock) = self.state_store.as_hummock() {
1356            hummock.start_epoch(barrier.epoch.curr, table_ids.clone());
1357        }
1358
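        // `BarrierState::table_ids` is only `Some(_)` for checkpoint barriers; in that case it is taken
        // from the table set recorded at the previous barrier (or an empty set if the epochs are not
        // consecutive), while non-checkpoint barriers only update `prev_barrier_table_ids`.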
1359        let table_ids = match barrier.kind {
1360            BarrierKind::Unspecified => {
1361                unreachable!()
1362            }
1363            BarrierKind::Initial => {
1364                assert!(
1365                    self.prev_barrier_table_ids.is_none(),
1366                    "non empty table_ids at initial barrier: {:?}",
1367                    self.prev_barrier_table_ids
1368                );
1369                info!(epoch = ?barrier.epoch, "initialize at Initial barrier");
1370                self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
1371                None
1372            }
1373            BarrierKind::Barrier => {
1374                if let Some((prev_epoch, prev_table_ids)) = self.prev_barrier_table_ids.as_mut() {
1375                    assert_eq!(prev_epoch.curr, barrier.epoch.prev);
1376                    assert_eq!(prev_table_ids, &table_ids);
1377                    *prev_epoch = barrier.epoch;
1378                } else {
1379                    info!(epoch = ?barrier.epoch, "initialize at non-checkpoint barrier");
1380                    self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
1381                }
1382                None
1383            }
1384            BarrierKind::Checkpoint => Some(
1385                if let Some((prev_epoch, prev_table_ids)) = self
1386                    .prev_barrier_table_ids
1387                    .replace((barrier.epoch, table_ids))
1388                    && prev_epoch.curr == barrier.epoch.prev
1389                {
1390                    prev_table_ids
1391                } else {
1392                    debug!(epoch = ?barrier.epoch, "reinitialize at Checkpoint barrier");
1393                    HashSet::new()
1394                },
1395            ),
1396        };
1397
1398        if let Some(&mut BarrierState { ref inner, .. }) =
1399            self.epoch_barrier_state_map.get_mut(&barrier.epoch.prev)
1400        {
1401            {
1402                panic!(
1403                    "barrier epoch {:?} state has already been `Issued`. Current state: {:?}",
1404                    barrier.epoch, inner
1405                );
1406            }
1407        };
1408
1409        self.epoch_barrier_state_map.insert(
1410            barrier.epoch.prev,
1411            BarrierState {
1412                barrier: barrier.clone(),
1413                inner: ManagedBarrierStateInner::Issued(IssuedState {
1414                    remaining_actors: BTreeSet::from_iter(actor_ids_to_collect),
1415                    barrier_inflight_latency: timer,
1416                }),
1417                table_ids,
1418            },
1419        );
1420    }
1421
1422    #[cfg(test)]
1423    async fn pop_next_completed_epoch(&mut self) -> u64 {
1424        if let Some(barrier) = self.may_have_collected_all() {
1425            self.pop_barrier_to_complete(barrier.epoch.prev);
1426            return barrier.epoch.prev;
1427        }
1428        pending().await
1429    }
1430}
1431
1432#[cfg(test)]
1433mod tests {
1434    use std::collections::HashSet;
1435
1436    use risingwave_common::util::epoch::test_epoch;
1437
1438    use crate::executor::Barrier;
1439    use crate::task::barrier_worker::managed_state::PartialGraphManagedBarrierState;
1440
1441    #[tokio::test]
1442    async fn test_managed_state_add_actor() {
1443        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1444        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1445        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1446        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1447        let actor_ids_to_collect1 = HashSet::from([1, 2]);
1448        let actor_ids_to_collect2 = HashSet::from([1, 2]);
1449        let actor_ids_to_collect3 = HashSet::from([1, 2, 3]);
1450        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1451        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1452        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1453        managed_barrier_state.collect(1, barrier1.epoch);
1454        managed_barrier_state.collect(2, barrier1.epoch);
1455        assert_eq!(
1456            managed_barrier_state.pop_next_completed_epoch().await,
1457            test_epoch(0)
1458        );
1459        assert_eq!(
1460            managed_barrier_state
1461                .epoch_barrier_state_map
1462                .first_key_value()
1463                .unwrap()
1464                .0,
1465            &test_epoch(1)
1466        );
1467        managed_barrier_state.collect(1, barrier2.epoch);
1468        managed_barrier_state.collect(1, barrier3.epoch);
1469        managed_barrier_state.collect(2, barrier2.epoch);
1470        assert_eq!(
1471            managed_barrier_state.pop_next_completed_epoch().await,
1472            test_epoch(1)
1473        );
1474        assert_eq!(
1475            managed_barrier_state
1476                .epoch_barrier_state_map
1477                .first_key_value()
1478                .unwrap()
1479                .0,
1480            &test_epoch(2)
1481        );
1482        managed_barrier_state.collect(2, barrier3.epoch);
1483        managed_barrier_state.collect(3, barrier3.epoch);
1484        assert_eq!(
1485            managed_barrier_state.pop_next_completed_epoch().await,
1486            test_epoch(2)
1487        );
1488        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1489    }
1490
1491    #[tokio::test]
1492    async fn test_managed_state_stop_actor() {
1493        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1494        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1495        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1496        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1497        let actor_ids_to_collect1 = HashSet::from([1, 2, 3, 4]);
1498        let actor_ids_to_collect2 = HashSet::from([1, 2, 3]);
1499        let actor_ids_to_collect3 = HashSet::from([1, 2]);
1500        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1501        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1502        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1503
1504        managed_barrier_state.collect(1, barrier1.epoch);
1505        managed_barrier_state.collect(1, barrier2.epoch);
1506        managed_barrier_state.collect(1, barrier3.epoch);
1507        managed_barrier_state.collect(2, barrier1.epoch);
1508        managed_barrier_state.collect(2, barrier2.epoch);
1509        managed_barrier_state.collect(2, barrier3.epoch);
1510        assert_eq!(
1511            managed_barrier_state
1512                .epoch_barrier_state_map
1513                .first_key_value()
1514                .unwrap()
1515                .0,
1516            &0
1517        );
1518        managed_barrier_state.collect(3, barrier1.epoch);
1519        managed_barrier_state.collect(3, barrier2.epoch);
1520        assert_eq!(
1521            managed_barrier_state
1522                .epoch_barrier_state_map
1523                .first_key_value()
1524                .unwrap()
1525                .0,
1526            &0
1527        );
1528        managed_barrier_state.collect(4, barrier1.epoch);
1529        assert_eq!(
1530            managed_barrier_state.pop_next_completed_epoch().await,
1531            test_epoch(0)
1532        );
1533        assert_eq!(
1534            managed_barrier_state.pop_next_completed_epoch().await,
1535            test_epoch(1)
1536        );
1537        assert_eq!(
1538            managed_barrier_state.pop_next_completed_epoch().await,
1539            test_epoch(2)
1540        );
1541        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1542    }
1543}