risingwave_stream/task/barrier_worker/managed_state.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::fmt::{Debug, Display, Formatter};
17use std::future::{Future, pending, poll_fn};
18use std::mem::replace;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21use std::time::{Duration, Instant};
22
23use anyhow::anyhow;
24use futures::future::BoxFuture;
25use futures::stream::{FuturesOrdered, FuturesUnordered};
26use futures::{FutureExt, StreamExt};
27use prometheus::HistogramTimer;
28use risingwave_common::catalog::TableId;
29use risingwave_common::id::SourceId;
30use risingwave_common::util::epoch::EpochPair;
31use risingwave_pb::stream_plan::barrier::BarrierKind;
32use risingwave_pb::stream_service::barrier_complete_response::{
33    PbCdcSourceOffsetUpdated, PbCdcTableBackfillProgress, PbCreateMviewProgress,
34    PbListFinishedSource, PbLoadFinishedSource,
35};
36use risingwave_storage::StateStoreImpl;
37use tokio::sync::mpsc;
38use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
39use tokio::task::JoinHandle;
40
41use crate::error::{StreamError, StreamResult};
42use crate::executor::Barrier;
43use crate::executor::monitor::StreamingMetrics;
44use crate::task::progress::BackfillState;
45use crate::task::{
46    ActorId, LocalBarrierEvent, LocalBarrierManager, NewOutputRequest, PartialGraphId,
47    StreamActorManager, UpDownActorIds,
48};
49
50struct IssuedState {
51    /// Actor ids remaining to be collected.
52    pub remaining_actors: BTreeSet<ActorId>,
53
54    pub barrier_inflight_latency: HistogramTimer,
55}
56
57impl Debug for IssuedState {
58    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("IssuedState")
60            .field("remaining_actors", &self.remaining_actors)
61            .finish()
62    }
63}
64
65/// The state machine of local barrier manager.
66#[derive(Debug)]
67enum ManagedBarrierStateInner {
68    /// Meta service has issued a `send_barrier` request. We're collecting barriers now.
69    Issued(IssuedState),
70
71    /// The barrier has been collected by all remaining actors
72    AllCollected {
73        create_mview_progress: Vec<PbCreateMviewProgress>,
74        list_finished_source_ids: Vec<PbListFinishedSource>,
75        load_finished_source_ids: Vec<PbLoadFinishedSource>,
76        cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,
77        cdc_source_offset_updated: Vec<PbCdcSourceOffsetUpdated>,
78        truncate_tables: Vec<TableId>,
79        refresh_finished_tables: Vec<TableId>,
80    },
81}
82
83#[derive(Debug)]
84struct BarrierState {
85    barrier: Barrier,
86    /// Only `Some(_)` when `barrier.kind` is `Checkpoint`.
87    table_ids: Option<HashSet<TableId>>,
88    inner: ManagedBarrierStateInner,
89}
90
91use risingwave_common::must_match;
92use risingwave_pb::id::FragmentId;
93use risingwave_pb::stream_service::InjectBarrierRequest;
94
95use crate::executor::exchange::permit;
96use crate::executor::exchange::permit::channel_from_config;
97use crate::task::barrier_worker::await_epoch_completed_future::AwaitEpochCompletedFuture;
98use crate::task::barrier_worker::{ScoredStreamError, TakeReceiverRequest};
99use crate::task::cdc_progress::CdcTableBackfillState;
100
101pub(super) struct ManagedBarrierStateDebugInfo<'a> {
102    running_actors: BTreeSet<ActorId>,
103    graph_state: &'a PartialGraphManagedBarrierState,
104}
105
106impl Display for ManagedBarrierStateDebugInfo<'_> {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        write!(f, "running_actors: ")?;
109        for actor_id in &self.running_actors {
110            write!(f, "{}, ", actor_id)?;
111        }
112        {
113            writeln!(f, "graph states")?;
114            write!(f, "{}", self.graph_state)?;
115        }
116        Ok(())
117    }
118}
119
120impl Display for &'_ PartialGraphManagedBarrierState {
121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122        let mut prev_epoch = 0u64;
123        for (epoch, barrier_state) in &self.epoch_barrier_state_map {
124            write!(f, "> Epoch {}: ", epoch)?;
125            match &barrier_state.inner {
126                ManagedBarrierStateInner::Issued(state) => {
127                    write!(
128                        f,
129                        "Issued [{:?}]. Remaining actors: [",
130                        barrier_state.barrier.kind
131                    )?;
132                    let mut is_prev_epoch_issued = false;
133                    if prev_epoch != 0 {
134                        let bs = &self.epoch_barrier_state_map[&prev_epoch];
135                        if let ManagedBarrierStateInner::Issued(IssuedState {
136                            remaining_actors: remaining_actors_prev,
137                            ..
138                        }) = &bs.inner
139                        {
140                            // Only show the actors that are not in the previous epoch.
141                            is_prev_epoch_issued = true;
142                            let mut duplicates = 0usize;
143                            for actor_id in &state.remaining_actors {
144                                if !remaining_actors_prev.contains(actor_id) {
145                                    write!(f, "{}, ", actor_id)?;
146                                } else {
147                                    duplicates += 1;
148                                }
149                            }
150                            if duplicates > 0 {
151                                write!(f, "...and {} actors in prev epoch", duplicates)?;
152                            }
153                        }
154                    }
155                    if !is_prev_epoch_issued {
156                        for actor_id in &state.remaining_actors {
157                            write!(f, "{}, ", actor_id)?;
158                        }
159                    }
160                    write!(f, "]")?;
161                }
162                ManagedBarrierStateInner::AllCollected { .. } => {
163                    write!(f, "AllCollected")?;
164                }
165            }
166            prev_epoch = *epoch;
167            writeln!(f)?;
168        }
169
170        if !self.create_mview_progress.is_empty() {
171            writeln!(f, "Create MView Progress:")?;
172            for (epoch, progress) in &self.create_mview_progress {
173                write!(f, "> Epoch {}:", epoch)?;
174                for (actor_id, (_, state)) in progress {
175                    write!(f, ">> Actor {}: {}, ", actor_id, state)?;
176                }
177            }
178        }
179
180        Ok(())
181    }
182}
183
184enum InflightActorStatus {
185    /// The actor has been issued some barriers, but has not collected the first barrier
186    IssuedFirst(Vec<Barrier>),
187    /// The actor has been issued some barriers, and has collected the first barrier
188    Running(u64),
189}
190
191impl InflightActorStatus {
192    fn max_issued_epoch(&self) -> u64 {
193        match self {
194            InflightActorStatus::Running(epoch) => *epoch,
195            InflightActorStatus::IssuedFirst(issued_barriers) => {
196                issued_barriers.last().expect("non-empty").epoch.prev
197            }
198        }
199    }
200}
201
202pub(crate) struct InflightActorState {
203    actor_id: ActorId,
204    barrier_senders: Vec<mpsc::UnboundedSender<Barrier>>,
205    /// `prev_epoch`s of barriers issued to but not yet collected from this actor: `push_back` on issue, `pop_front` on collect.
206    pub(in crate::task) inflight_barriers: VecDeque<u64>,
207    status: InflightActorStatus,
208    /// Whether the actor has been issued a stop barrier
209    is_stopping: bool,
210
211    new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
212    join_handle: JoinHandle<()>,
213    monitor_task_handle: Option<JoinHandle<()>>,
214}
215
216impl InflightActorState {
217    pub(super) fn start(
218        actor_id: ActorId,
219        initial_barrier: &Barrier,
220        new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
221        join_handle: JoinHandle<()>,
222        monitor_task_handle: Option<JoinHandle<()>>,
223    ) -> Self {
224        Self {
225            actor_id,
226            barrier_senders: vec![],
227            inflight_barriers: VecDeque::from_iter([initial_barrier.epoch.prev]),
228            status: InflightActorStatus::IssuedFirst(vec![initial_barrier.clone()]),
229            is_stopping: false,
230            new_output_request_tx,
231            join_handle,
232            monitor_task_handle,
233        }
234    }
235
236    pub(super) fn issue_barrier(&mut self, barrier: &Barrier, is_stop: bool) -> StreamResult<()> {
237        assert!(barrier.epoch.prev > self.status.max_issued_epoch());
238
239        for barrier_sender in &self.barrier_senders {
240            barrier_sender.send(barrier.clone()).map_err(|_| {
241                StreamError::barrier_send(
242                    barrier.clone(),
243                    self.actor_id,
244                    "failed to send to registered sender",
245                )
246            })?;
247        }
248
249        if let Some(prev_epoch) = self.inflight_barriers.back() {
250            assert!(*prev_epoch < barrier.epoch.prev);
251        }
252        self.inflight_barriers.push_back(barrier.epoch.prev);
253
254        match &mut self.status {
255            InflightActorStatus::IssuedFirst(pending_barriers) => {
256                pending_barriers.push(barrier.clone());
257            }
258            InflightActorStatus::Running(prev_epoch) => {
259                *prev_epoch = barrier.epoch.prev;
260            }
261        };
262
263        if is_stop {
264            assert!(!self.is_stopping, "should not issue a barrier to an actor that is already stopping");
265            self.is_stopping = true;
266        }
267        Ok(())
268    }
269
270    pub(super) fn collect(&mut self, epoch: EpochPair) -> bool {
271        let prev_epoch = self.inflight_barriers.pop_front().expect("should exist");
272        assert_eq!(prev_epoch, epoch.prev);
273        match &self.status {
274            InflightActorStatus::IssuedFirst(pending_barriers) => {
275                assert_eq!(
276                    prev_epoch,
277                    pending_barriers.first().expect("non-empty").epoch.prev
278                );
279                self.status = InflightActorStatus::Running(
280                    pending_barriers.last().expect("non-empty").epoch.prev,
281                );
282            }
283            InflightActorStatus::Running(_) => {}
284        }
285
286        self.inflight_barriers.is_empty() && self.is_stopping
287    }
288}
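
// A minimal, test-only sketch (toy types, not the production code) of the FIFO discipline that
// `issue_barrier` / `collect` above maintain on `inflight_barriers`: prev epochs are pushed in
// strictly increasing order and popped in the same order, and the actor counts as finished once
// the queue drains after a stop barrier.
#[cfg(test)]
mod inflight_barrier_queue_sketch {
    use std::collections::VecDeque;

    struct ToyActorState {
        inflight_prev_epochs: VecDeque<u64>,
        is_stopping: bool,
    }

    impl ToyActorState {
        fn issue(&mut self, prev_epoch: u64, is_stop: bool) {
            if let Some(last) = self.inflight_prev_epochs.back() {
                assert!(*last < prev_epoch, "barriers must be issued in epoch order");
            }
            self.inflight_prev_epochs.push_back(prev_epoch);
            self.is_stopping |= is_stop;
        }

        /// Returns `true` when all issued barriers are collected after a stop barrier.
        fn collect(&mut self, prev_epoch: u64) -> bool {
            let front = self.inflight_prev_epochs.pop_front().expect("collect without issue");
            assert_eq!(front, prev_epoch, "barriers are collected in issue order");
            self.inflight_prev_epochs.is_empty() && self.is_stopping
        }
    }

    #[test]
    fn issue_then_collect_in_order() {
        let mut actor = ToyActorState {
            inflight_prev_epochs: VecDeque::new(),
            is_stopping: false,
        };
        actor.issue(1, false);
        actor.issue(2, true); // stop barrier
        assert!(!actor.collect(1));
        assert!(actor.collect(2)); // finished: queue drained after the stop barrier
    }
}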
289
290/// Part of [`PartialGraphState`]
291pub(crate) struct PartialGraphManagedBarrierState {
292    /// Record barrier state for each epoch of concurrent checkpoints.
293    ///
294    /// The key is `prev_epoch`; the value holds the barrier itself (and hence its `curr_epoch`) together with its collection state.
295    epoch_barrier_state_map: BTreeMap<u64, BarrierState>,
296
297    prev_barrier_table_ids: Option<(EpochPair, HashSet<TableId>)>,
298
299    /// Record the progress updates of creating mviews for each epoch of concurrent checkpoints.
300    ///
301    /// The process of progress reporting is as follows:
302    /// 1. updated by [`crate::task::barrier_manager::CreateMviewProgressReporter::update`]
303    /// 2. converted to [`ManagedBarrierStateInner`] in [`Self::may_have_collected_all`]
304    /// 3. handled by [`Self::pop_barrier_to_complete`]
305    /// 4. put in [`crate::task::barrier_worker::BarrierCompleteResult`] and reported to meta.
306    pub(crate) create_mview_progress: HashMap<u64, HashMap<ActorId, (FragmentId, BackfillState)>>,
307
308    /// Record the source list finished reports for each epoch of concurrent checkpoints.
309    /// Used for refreshable batch source. The map key is epoch and the value is
310    /// a list of pb messages reported by actors.
311    pub(crate) list_finished_source_ids: HashMap<u64, Vec<PbListFinishedSource>>,
312
313    /// Record the source load finished reports for each epoch of concurrent checkpoints.
314    /// Used for refreshable batch source. The map key is epoch and the value is
315    /// a list of pb messages reported by actors.
316    pub(crate) load_finished_source_ids: HashMap<u64, Vec<PbLoadFinishedSource>>,
317
318    pub(crate) cdc_table_backfill_progress: HashMap<u64, HashMap<ActorId, CdcTableBackfillState>>,
319
320    /// Record CDC source offset updated reports for each epoch of concurrent checkpoints.
321    /// Used to track when CDC sources have updated their offset at least once.
322    pub(crate) cdc_source_offset_updated: HashMap<u64, Vec<PbCdcSourceOffsetUpdated>>,
323
324    /// Record the tables to truncate for each epoch of concurrent checkpoints.
325    pub(crate) truncate_tables: HashMap<u64, HashSet<TableId>>,
326    /// Record the tables that have finished refresh for each epoch of concurrent checkpoints.
327    /// Used for materialized view refresh completion reporting.
328    pub(crate) refresh_finished_tables: HashMap<u64, HashSet<TableId>>,
329
330    state_store: StateStoreImpl,
331
332    streaming_metrics: Arc<StreamingMetrics>,
333}
334
335impl PartialGraphManagedBarrierState {
336    pub(super) fn new(actor_manager: &StreamActorManager) -> Self {
337        Self::new_inner(
338            actor_manager.env.state_store(),
339            actor_manager.streaming_metrics.clone(),
340        )
341    }
342
343    fn new_inner(state_store: StateStoreImpl, streaming_metrics: Arc<StreamingMetrics>) -> Self {
344        Self {
345            epoch_barrier_state_map: Default::default(),
346            prev_barrier_table_ids: None,
347            create_mview_progress: Default::default(),
348            list_finished_source_ids: Default::default(),
349            load_finished_source_ids: Default::default(),
350            cdc_table_backfill_progress: Default::default(),
351            cdc_source_offset_updated: Default::default(),
352            truncate_tables: Default::default(),
353            refresh_finished_tables: Default::default(),
354            state_store,
355            streaming_metrics,
356        }
357    }
358
359    #[cfg(test)]
360    pub(crate) fn for_test() -> Self {
361        Self::new_inner(
362            StateStoreImpl::for_test(),
363            Arc::new(StreamingMetrics::unused()),
364        )
365    }
366
367    pub(super) fn is_empty(&self) -> bool {
368        self.epoch_barrier_state_map.is_empty()
369    }
370}
371
372pub(crate) struct SuspendedPartialGraphState {
373    pub(super) suspend_time: Instant,
374    inner: PartialGraphState,
375    failure: Option<(Option<ActorId>, StreamError)>,
376}
377
378impl SuspendedPartialGraphState {
379    fn new(
380        state: PartialGraphState,
381        failure: Option<(Option<ActorId>, StreamError)>,
382        _completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>, /* discard the completing futures */
383    ) -> Self {
384        Self {
385            suspend_time: Instant::now(),
386            inner: state,
387            failure,
388        }
389    }
390
391    async fn reset(mut self) -> ResetPartialGraphOutput {
392        let root_err = self.inner.try_find_root_actor_failure(self.failure).await;
393        self.inner.abort_and_wait_actors().await;
394        ResetPartialGraphOutput { root_err }
395    }
396}
397
398pub(crate) struct ResetPartialGraphOutput {
399    pub(crate) root_err: Option<ScoredStreamError>,
400}
401
402pub(in crate::task) enum PartialGraphStatus {
403    ReceivedExchangeRequest(Vec<(UpDownActorIds, TakeReceiverRequest)>),
404    Running(PartialGraphState),
405    Suspended(SuspendedPartialGraphState),
406    Resetting,
407    /// Temporary placeholder while the status is being replaced.
408    Unspecified,
409}
410
411impl PartialGraphStatus {
412    pub(crate) async fn abort(&mut self) {
413        match self {
414            PartialGraphStatus::ReceivedExchangeRequest(pending_requests) => {
415                for (_, request) in pending_requests.drain(..) {
416                    if let TakeReceiverRequest::Remote(sender) = request {
417                        let _ = sender.send(Err(anyhow!("partial graph aborted").into()));
418                    }
419                }
420            }
421            PartialGraphStatus::Running(state) => {
422                state.abort_and_wait_actors().await;
423            }
424            PartialGraphStatus::Suspended(SuspendedPartialGraphState { inner: state, .. }) => {
425                state.abort_and_wait_actors().await;
426            }
427            PartialGraphStatus::Resetting => {}
428            PartialGraphStatus::Unspecified => {
429                unreachable!()
430            }
431        }
432    }
433
434    pub(crate) fn state_for_request(&mut self) -> Option<&mut PartialGraphState> {
435        match self {
436            PartialGraphStatus::ReceivedExchangeRequest(_) => {
437                unreachable!("should not handle request")
438            }
439            PartialGraphStatus::Running(state) => Some(state),
440            PartialGraphStatus::Suspended(_) => None,
441            PartialGraphStatus::Resetting => {
442                unreachable!("should not receive further request during cleaning")
443            }
444            PartialGraphStatus::Unspecified => {
445                unreachable!()
446            }
447        }
448    }
449
450    pub(super) fn poll_next_event(
451        &mut self,
452        cx: &mut Context<'_>,
453    ) -> Poll<ManagedBarrierStateEvent> {
454        match self {
455            PartialGraphStatus::ReceivedExchangeRequest(_) => Poll::Pending,
456            PartialGraphStatus::Running(state) => state.poll_next_event(cx),
457            PartialGraphStatus::Suspended(_) | PartialGraphStatus::Resetting => Poll::Pending,
458            PartialGraphStatus::Unspecified => {
459                unreachable!()
460            }
461        }
462    }
463
464    pub(super) fn suspend(
465        &mut self,
466        failed_actor: Option<ActorId>,
467        err: StreamError,
468        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
469    ) {
470        let state = must_match!(replace(self, PartialGraphStatus::Unspecified), PartialGraphStatus::Running(state) => state);
471        *self = PartialGraphStatus::Suspended(SuspendedPartialGraphState::new(
472            state,
473            Some((failed_actor, err)),
474            completing_futures,
475        ));
476    }
477
478    pub(super) fn start_reset(
479        &mut self,
480        partial_graph_id: PartialGraphId,
481        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
482        table_ids_to_clear: &mut HashSet<TableId>,
483    ) -> BoxFuture<'static, ResetPartialGraphOutput> {
484        match replace(self, PartialGraphStatus::Resetting) {
485            PartialGraphStatus::ReceivedExchangeRequest(pending_requests) => {
486                for (_, request) in pending_requests {
487                    if let TakeReceiverRequest::Remote(sender) = request {
488                        let _ = sender.send(Err(anyhow!("partial graph reset").into()));
489                    }
490                }
491                async move { ResetPartialGraphOutput { root_err: None } }.boxed()
492            }
493            PartialGraphStatus::Running(state) => {
494                assert_eq!(partial_graph_id, state.partial_graph_id);
495                info!(
496                    %partial_graph_id,
497                    "start partial graph reset from Running"
498                );
499                table_ids_to_clear.extend(state.table_ids.iter().copied());
500                SuspendedPartialGraphState::new(state, None, completing_futures)
501                    .reset()
502                    .boxed()
503            }
504            PartialGraphStatus::Suspended(state) => {
505                assert!(
506                    completing_futures.is_none(),
507                    "should have been cleared when suspended"
508                );
509                assert_eq!(partial_graph_id, state.inner.partial_graph_id);
510                info!(
511                    %partial_graph_id,
512                    suspend_elapsed = ?state.suspend_time.elapsed(),
513                    "start partial graph reset after suspended"
514                );
515                table_ids_to_clear.extend(state.inner.table_ids.iter().copied());
516                state.reset().boxed()
517            }
518            PartialGraphStatus::Resetting => {
519                unreachable!("should not reset twice");
520            }
521            PartialGraphStatus::Unspecified => {
522                unreachable!()
523            }
524        }
525    }
526}
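
// Test-only sketch (toy types, hypothetical names) of the ownership trick used by `suspend` and
// `start_reset` above: `std::mem::replace` swaps a placeholder variant into `&mut self` so the
// old variant's payload can be moved out and consumed without leaving the enum in an invalid state.
#[cfg(test)]
mod status_replace_sketch {
    use std::mem::replace;

    enum ToyStatus {
        Running(String),
        Suspended(String),
        Unspecified, // temporary placeholder, never observed from outside
    }

    fn suspend(status: &mut ToyStatus, reason: &str) {
        // Take ownership of the current payload by swapping in the placeholder.
        let ToyStatus::Running(state) = replace(status, ToyStatus::Unspecified) else {
            unreachable!("can only suspend a running graph");
        };
        *status = ToyStatus::Suspended(format!("{state}: {reason}"));
    }

    #[test]
    fn suspend_moves_payload() {
        let mut status = ToyStatus::Running("graph-1".to_owned());
        suspend(&mut status, "actor failure");
        assert!(matches!(status, ToyStatus::Suspended(ref s) if s == "graph-1: actor failure"));
    }
}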
527
528#[derive(Default)]
529pub(in crate::task) struct ManagedBarrierState {
530    pub(super) partial_graphs: HashMap<PartialGraphId, PartialGraphStatus>,
531    pub(super) resetting_graphs:
532        FuturesUnordered<JoinHandle<Vec<(PartialGraphId, ResetPartialGraphOutput)>>>,
533}
534
535pub(super) enum ManagedBarrierStateEvent {
536    BarrierCollected {
537        partial_graph_id: PartialGraphId,
538        barrier: Barrier,
539    },
540    ActorError {
541        partial_graph_id: PartialGraphId,
542        actor_id: ActorId,
543        err: StreamError,
544    },
545    PartialGraphsReset(Vec<(PartialGraphId, ResetPartialGraphOutput)>),
546    RegisterLocalUpstreamOutput {
547        actor_id: ActorId,
548        upstream_actor_id: ActorId,
549        upstream_partial_graph_id: PartialGraphId,
550        tx: permit::Sender,
551    },
552}
553
554impl ManagedBarrierState {
555    pub(super) fn next_event(&mut self) -> impl Future<Output = ManagedBarrierStateEvent> + '_ {
556        poll_fn(|cx| {
557            for graph in self.partial_graphs.values_mut() {
558                if let Poll::Ready(event) = graph.poll_next_event(cx) {
559                    return Poll::Ready(event);
560                }
561            }
562            if let Poll::Ready(Some(result)) = self.resetting_graphs.poll_next_unpin(cx) {
563                let outputs = result.expect("failed to join resetting future");
564                for (partial_graph_id, _) in &outputs {
565                    let PartialGraphStatus::Resetting = self
566                        .partial_graphs
567                        .remove(partial_graph_id)
568                        .expect("should exist")
569                    else {
570                        panic!("should be resetting")
571                    };
572                }
573                return Poll::Ready(ManagedBarrierStateEvent::PartialGraphsReset(outputs));
574            }
575            Poll::Pending
576        })
577    }
578}
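
// Test-only sketch of the `poll_fn` multiplexing pattern used by `next_event` above: poll several
// independent sources inside one future and resolve with the first event that is ready, otherwise
// stay `Pending` so the caller's `await` keeps waiting. The channels here are made-up stand-ins.
#[cfg(test)]
mod next_event_poll_sketch {
    use std::future::poll_fn;
    use std::task::Poll;

    use tokio::sync::mpsc::unbounded_channel;

    #[tokio::test]
    async fn first_ready_source_wins() {
        let (tx_a, mut rx_a) = unbounded_channel::<&'static str>();
        let (_tx_b, mut rx_b) = unbounded_channel::<&'static str>();

        tx_a.send("event from a").unwrap();

        let event = poll_fn(|cx| {
            // Poll each source in turn; return the first ready event.
            if let Poll::Ready(Some(event)) = rx_a.poll_recv(cx) {
                return Poll::Ready(event);
            }
            if let Poll::Ready(Some(event)) = rx_b.poll_recv(cx) {
                return Poll::Ready(event);
            }
            Poll::Pending
        })
        .await;

        assert_eq!(event, "event from a");
    }
}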
579
580/// Per-partial-graph barrier state manager. Handles barriers for one specific partial graph.
581/// Part of [`ManagedBarrierState`] in [`super::LocalBarrierWorker`].
582///
583/// See [`crate::task`] for architecture overview.
584pub(crate) struct PartialGraphState {
585    partial_graph_id: PartialGraphId,
586    pub(crate) actor_states: HashMap<ActorId, InflightActorState>,
587    pub(super) actor_pending_new_output_requests:
588        HashMap<ActorId, Vec<(ActorId, NewOutputRequest)>>,
589
590    pub(crate) graph_state: PartialGraphManagedBarrierState,
591
592    table_ids: HashSet<TableId>,
593
594    actor_manager: Arc<StreamActorManager>,
595
596    pub(super) local_barrier_manager: LocalBarrierManager,
597
598    barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,
599    pub(super) actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
600}
601
602impl PartialGraphState {
603    /// Create a barrier manager state. This is called only once per partial graph.
604    pub(super) fn new(
605        partial_graph_id: PartialGraphId,
606        term_id: String,
607        actor_manager: Arc<StreamActorManager>,
608    ) -> Self {
609        let (local_barrier_manager, barrier_event_rx, actor_failure_rx) =
610            LocalBarrierManager::new(term_id, actor_manager.env.clone());
611        Self {
612            partial_graph_id,
613            actor_states: Default::default(),
614            actor_pending_new_output_requests: Default::default(),
615            graph_state: PartialGraphManagedBarrierState::new(&actor_manager),
616            table_ids: Default::default(),
617            actor_manager,
618            local_barrier_manager,
619            barrier_event_rx,
620            actor_failure_rx,
621        }
622    }
623
624    pub(super) fn to_debug_info(&self) -> ManagedBarrierStateDebugInfo<'_> {
625        ManagedBarrierStateDebugInfo {
626            running_actors: self.actor_states.keys().cloned().collect(),
627            graph_state: &self.graph_state,
628        }
629    }
630
631    async fn abort_and_wait_actors(&mut self) {
632        for (actor_id, state) in &self.actor_states {
633            tracing::debug!("force stopping actor {}", actor_id);
634            state.join_handle.abort();
635            if let Some(monitor_task_handle) = &state.monitor_task_handle {
636                monitor_task_handle.abort();
637            }
638        }
639
640        for (actor_id, state) in self.actor_states.drain() {
641            tracing::debug!("join actor {}", actor_id);
642            let result = state.join_handle.await;
643            assert!(result.is_ok() || result.unwrap_err().is_cancelled());
644        }
645    }
646}
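
// Test-only sketch of the abort-then-await pattern in `abort_and_wait_actors` above: aborting a
// tokio task does not wait for it, so the join handle is awaited afterwards, accepting either a
// clean exit or a cancellation error.
#[cfg(test)]
mod abort_and_wait_sketch {
    use std::future::pending;

    #[tokio::test]
    async fn abort_then_join() {
        // Stands in for a long-running actor task.
        let handle = tokio::spawn(pending::<()>());

        handle.abort();
        let result = handle.await;

        // Same invariant as `abort_and_wait_actors`: the task either finished or was cancelled.
        assert!(result.is_ok() || result.unwrap_err().is_cancelled());
    }
}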
647
648impl InflightActorState {
649    pub(super) fn register_barrier_sender(
650        &mut self,
651        tx: mpsc::UnboundedSender<Barrier>,
652    ) -> StreamResult<()> {
653        match &self.status {
654            InflightActorStatus::IssuedFirst(pending_barriers) => {
655                for barrier in pending_barriers {
656                    tx.send(barrier.clone()).map_err(|_| {
657                        StreamError::barrier_send(
658                            barrier.clone(),
659                            self.actor_id,
660                            "failed to send pending barriers to newly registered sender",
661                        )
662                    })?;
663                }
664                self.barrier_senders.push(tx);
665            }
666            InflightActorStatus::Running(_) => {
667                unreachable!("should not register barrier sender when entering Running status")
668            }
669        }
670        Ok(())
671    }
672}
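
// Test-only sketch of the replay-on-register behavior of `register_barrier_sender` above:
// barriers issued before an actor registers its barrier sender are buffered and re-sent to the
// newly registered sender so none are missed. The types here are simplified stand-ins.
#[cfg(test)]
mod register_sender_replay_sketch {
    use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};

    /// Toy stand-in for an actor that buffers issued barriers until a sender is registered.
    struct ToyActor {
        pending_barriers: Vec<u64>,
        senders: Vec<UnboundedSender<u64>>,
    }

    impl ToyActor {
        fn register_sender(&mut self, tx: UnboundedSender<u64>) {
            // Replay everything issued so far, then keep the sender for future barriers.
            for prev_epoch in &self.pending_barriers {
                tx.send(*prev_epoch).unwrap();
            }
            self.senders.push(tx);
        }
    }

    #[tokio::test]
    async fn late_registered_sender_sees_all_barriers() {
        let mut actor = ToyActor {
            pending_barriers: vec![1, 2],
            senders: vec![],
        };
        let (tx, mut rx) = unbounded_channel();
        actor.register_sender(tx);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, Some(2));
    }
}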
673
674impl PartialGraphState {
675    pub(super) fn register_barrier_sender(
676        &mut self,
677        actor_id: ActorId,
678        tx: mpsc::UnboundedSender<Barrier>,
679    ) -> StreamResult<()> {
680        self.actor_states
681            .get_mut(&actor_id)
682            .expect("should exist")
683            .register_barrier_sender(tx)
684    }
685}
686
687impl PartialGraphState {
688    pub(super) fn transform_to_issued(
689        &mut self,
690        barrier: &Barrier,
691        request: InjectBarrierRequest,
692    ) -> StreamResult<()> {
693        assert_eq!(self.partial_graph_id, request.partial_graph_id);
694        let actor_to_stop = barrier.all_stop_actors();
695        let is_stop_actor = |actor_id| {
696            actor_to_stop
697                .map(|actors| actors.contains(&actor_id))
698                .unwrap_or(false)
699        };
700
701        let table_ids = HashSet::from_iter(request.table_ids_to_sync);
702        self.table_ids.extend(table_ids.iter().cloned());
703
704        self.graph_state.transform_to_issued(
705            barrier,
706            request.actor_ids_to_collect.iter().copied(),
707            table_ids,
708        );
709
710        let mut new_actors = HashSet::new();
711        for (node, fragment_id, actor) in
712            request
713                .actors_to_build
714                .into_iter()
715                .flat_map(|fragment_actors| {
716                    let node = Arc::new(fragment_actors.node.unwrap());
717                    fragment_actors
718                        .actors
719                        .into_iter()
720                        .map(move |actor| (node.clone(), fragment_actors.fragment_id, actor))
721                })
722        {
723            let actor_id = actor.actor_id;
724            assert!(!is_stop_actor(actor_id));
725            assert!(new_actors.insert(actor_id));
726            assert!(request.actor_ids_to_collect.contains(&actor_id));
727            let (new_output_request_tx, new_output_request_rx) = unbounded_channel();
728            if let Some(pending_requests) = self.actor_pending_new_output_requests.remove(&actor_id)
729            {
730                for request in pending_requests {
731                    let _ = new_output_request_tx.send(request);
732                }
733            }
734            let (join_handle, monitor_join_handle) = self.actor_manager.spawn_actor(
735                actor,
736                fragment_id,
737                node,
738                self.local_barrier_manager.clone(),
739                new_output_request_rx,
740            );
741            assert!(
742                self.actor_states
743                    .try_insert(
744                        actor_id,
745                        InflightActorState::start(
746                            actor_id,
747                            barrier,
748                            new_output_request_tx,
749                            join_handle,
750                            monitor_join_handle
751                        )
752                    )
753                    .is_ok()
754            );
755        }
756
757    // Spawn a trivial join handle to stay compatible with unit tests. In unit tests that involve the local barrier manager,
758    // actors are spawned by the test logic itself, but we still assume that there is an entry for each spawned actor in `actor_states`,
759    // so under `cfg!(test)` we add a dummy entry for each new actor.
760        if cfg!(test) {
761            for &actor_id in &request.actor_ids_to_collect {
762                if !self.actor_states.contains_key(&actor_id) {
763                    let (tx, rx) = unbounded_channel();
764                    let join_handle = self.actor_manager.runtime.spawn(async move {
765                        // The rx is spawned so that tx.send() will not fail.
766                        let _ = rx;
767                        pending().await
768                    });
769                    assert!(
770                        self.actor_states
771                            .try_insert(
772                                actor_id,
773                                InflightActorState::start(actor_id, barrier, tx, join_handle, None,)
774                            )
775                            .is_ok()
776                    );
777                    new_actors.insert(actor_id);
778                }
779            }
780        }
781
782        // Note: it's important to issue the barrier to actors only after issuing it to the graph state, to ensure that
783        // `start_epoch` is called on the graph state before any actor receives the barrier.
784        for &actor_id in &request.actor_ids_to_collect {
785            if new_actors.contains(&actor_id) {
786                continue;
787            }
788            self.actor_states
789                .get_mut(&actor_id)
790                .unwrap_or_else(|| {
791                    panic!(
792                        "should exist: {} {:?}",
793                        actor_id, request.actor_ids_to_collect
794                    );
795                })
796                .issue_barrier(barrier, is_stop_actor(actor_id))?;
797        }
798
799        Ok(())
800    }
801
802    pub(super) fn new_actor_output_request(
803        &mut self,
804        actor_id: ActorId,
805        upstream_actor_id: ActorId,
806        request: TakeReceiverRequest,
807    ) {
808        let request = match request {
809            TakeReceiverRequest::Remote(result_sender) => {
810                let (tx, rx) = channel_from_config(self.local_barrier_manager.env.global_config());
811                let _ = result_sender.send(Ok(rx));
812                NewOutputRequest::Remote(tx)
813            }
814            TakeReceiverRequest::Local(tx) => NewOutputRequest::Local(tx),
815        };
816        if let Some(actor) = self.actor_states.get_mut(&upstream_actor_id) {
817            let _ = actor.new_output_request_tx.send((actor_id, request));
818        } else {
819            self.actor_pending_new_output_requests
820                .entry(upstream_actor_id)
821                .or_default()
822                .push((actor_id, request));
823        }
824    }
825
826    /// Handles [`LocalBarrierEvent`] from [`crate::task::barrier_manager::LocalBarrierManager`].
827    pub(super) fn poll_next_event(
828        &mut self,
829        cx: &mut Context<'_>,
830    ) -> Poll<ManagedBarrierStateEvent> {
831        if let Poll::Ready(option) = self.actor_failure_rx.poll_recv(cx) {
832            let (actor_id, err) = option.expect("non-empty when tx in local_barrier_manager");
833            return Poll::Ready(ManagedBarrierStateEvent::ActorError {
834                actor_id,
835                err,
836                partial_graph_id: self.partial_graph_id,
837            });
838        }
839        // yield some pending collected epochs
840        {
841            if let Some(barrier) = self.graph_state.may_have_collected_all() {
842                return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
843                    barrier,
844                    partial_graph_id: self.partial_graph_id,
845                });
846            }
847        }
848        while let Poll::Ready(event) = self.barrier_event_rx.poll_recv(cx) {
849            match event.expect("non-empty when tx in local_barrier_manager") {
850                LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => {
851                    if let Some(barrier) = self.collect(actor_id, epoch) {
852                        return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
853                            barrier,
854                            partial_graph_id: self.partial_graph_id,
855                        });
856                    }
857                }
858                LocalBarrierEvent::ReportCreateProgress {
859                    epoch,
860                    fragment_id,
861                    actor,
862                    state,
863                } => {
864                    self.update_create_mview_progress(epoch, fragment_id, actor, state);
865                }
866                LocalBarrierEvent::ReportSourceListFinished {
867                    epoch,
868                    actor_id,
869                    table_id,
870                    associated_source_id,
871                } => {
872                    self.report_source_list_finished(
873                        epoch,
874                        actor_id,
875                        table_id,
876                        associated_source_id,
877                    );
878                }
879                LocalBarrierEvent::ReportSourceLoadFinished {
880                    epoch,
881                    actor_id,
882                    table_id,
883                    associated_source_id,
884                } => {
885                    self.report_source_load_finished(
886                        epoch,
887                        actor_id,
888                        table_id,
889                        associated_source_id,
890                    );
891                }
892                LocalBarrierEvent::RefreshFinished {
893                    epoch,
894                    actor_id,
895                    table_id,
896                    staging_table_id,
897                } => {
898                    self.report_refresh_finished(epoch, actor_id, table_id, staging_table_id);
899                }
900                LocalBarrierEvent::RegisterBarrierSender {
901                    actor_id,
902                    barrier_sender,
903                } => {
904                    if let Err(err) = self.register_barrier_sender(actor_id, barrier_sender) {
905                        return Poll::Ready(ManagedBarrierStateEvent::ActorError {
906                            actor_id,
907                            err,
908                            partial_graph_id: self.partial_graph_id,
909                        });
910                    }
911                }
912                LocalBarrierEvent::RegisterLocalUpstreamOutput {
913                    actor_id,
914                    upstream_actor_id,
915                    upstream_partial_graph_id,
916                    tx,
917                } => {
918                    return Poll::Ready(ManagedBarrierStateEvent::RegisterLocalUpstreamOutput {
919                        actor_id,
920                        upstream_actor_id,
921                        upstream_partial_graph_id,
922                        tx,
923                    });
924                }
925                LocalBarrierEvent::ReportCdcTableBackfillProgress {
926                    actor_id,
927                    epoch,
928                    state,
929                } => {
930                    self.update_cdc_table_backfill_progress(epoch, actor_id, state);
931                }
932                LocalBarrierEvent::ReportCdcSourceOffsetUpdated {
933                    epoch,
934                    actor_id,
935                    source_id,
936                } => {
937                    self.report_cdc_source_offset_updated(epoch, actor_id, source_id);
938                }
939            }
940        }
941
942        debug_assert!(self.graph_state.may_have_collected_all().is_none());
943        Poll::Pending
944    }
945}
946
947impl PartialGraphState {
948    #[must_use]
949    pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) -> Option<Barrier> {
950        let is_finished = self
951            .actor_states
952            .get_mut(&actor_id)
953            .expect("should exist")
954            .collect(epoch);
955        if is_finished {
956            let state = self.actor_states.remove(&actor_id).expect("should exist");
957            if let Some(monitor_task_handle) = state.monitor_task_handle {
958                monitor_task_handle.abort();
959            }
960        }
961        self.graph_state.collect(actor_id, epoch);
962        self.graph_state.may_have_collected_all()
963    }
964
965    pub(super) fn pop_barrier_to_complete(&mut self, prev_epoch: u64) -> BarrierToComplete {
966        self.graph_state.pop_barrier_to_complete(prev_epoch)
967    }
968
969    /// Collect actor errors for a while and find the one that might be the root cause.
970    ///
971    /// Returns `None` if there's no actor error received.
972    async fn try_find_root_actor_failure(
973        &mut self,
974        first_failure: Option<(Option<ActorId>, StreamError)>,
975    ) -> Option<ScoredStreamError> {
976        let mut later_errs = vec![];
977        // fetch more actor errors within a timeout
978        let _ = tokio::time::timeout(Duration::from_secs(3), async {
979            let mut uncollected_actors: HashSet<_> = self.actor_states.keys().cloned().collect();
980            if let Some((Some(failed_actor), _)) = &first_failure {
981                uncollected_actors.remove(failed_actor);
982            }
983            while !uncollected_actors.is_empty()
984                && let Some((actor_id, error)) = self.actor_failure_rx.recv().await
985            {
986                uncollected_actors.remove(&actor_id);
987                later_errs.push(error);
988            }
989        })
990        .await;
991
992        first_failure
993            .into_iter()
994            .map(|(_, err)| err)
995            .chain(later_errs.into_iter())
996            .map(|e| e.with_score())
997            .max_by_key(|e| e.score)
998    }
999
1000    /// Report that a source has finished listing for a specific epoch
1001    pub(super) fn report_source_list_finished(
1002        &mut self,
1003        epoch: EpochPair,
1004        actor_id: ActorId,
1005        table_id: TableId,
1006        associated_source_id: SourceId,
1007    ) {
1008        // Only record the report if the actor is still tracked and the given epoch is in flight for it.
1009        if let Some(actor_state) = self.actor_states.get(&actor_id)
1010            && actor_state.inflight_barriers.contains(&epoch.prev)
1011        {
1012            self.graph_state
1013                .list_finished_source_ids
1014                .entry(epoch.curr)
1015                .or_default()
1016                .push(PbListFinishedSource {
1017                    reporter_actor_id: actor_id,
1018                    table_id,
1019                    associated_source_id,
1020                });
1021        } else {
1022            warn!(
1023                ?epoch,
1024                %actor_id, %table_id, %associated_source_id, "ignore source list finished"
1025            );
1026        }
1027    }
1028
1029    /// Report that a source has finished loading for a specific epoch
1030    pub(super) fn report_source_load_finished(
1031        &mut self,
1032        epoch: EpochPair,
1033        actor_id: ActorId,
1034        table_id: TableId,
1035        associated_source_id: SourceId,
1036    ) {
1037        // Only record the report if the actor is still tracked and the given epoch is in flight for it.
1038        if let Some(actor_state) = self.actor_states.get(&actor_id)
1039            && actor_state.inflight_barriers.contains(&epoch.prev)
1040        {
1041            self.graph_state
1042                .load_finished_source_ids
1043                .entry(epoch.curr)
1044                .or_default()
1045                .push(PbLoadFinishedSource {
1046                    reporter_actor_id: actor_id,
1047                    table_id,
1048                    associated_source_id,
1049                });
1050        } else {
1051            warn!(
1052                ?epoch,
1053                %actor_id, %table_id, %associated_source_id, "ignore source load finished"
1054            );
1055        }
1056    }
1057
1058    /// Report that a CDC source has updated its offset at least once
1059    pub(super) fn report_cdc_source_offset_updated(
1060        &mut self,
1061        epoch: EpochPair,
1062        actor_id: ActorId,
1063        source_id: SourceId,
1064    ) {
1065        if let Some(actor_state) = self.actor_states.get(&actor_id)
1066            && actor_state.inflight_barriers.contains(&epoch.prev)
1067        {
1068            self.graph_state
1069                .cdc_source_offset_updated
1070                .entry(epoch.curr)
1071                .or_default()
1072                .push(PbCdcSourceOffsetUpdated {
1073                    reporter_actor_id: actor_id.as_raw_id(),
1074                    source_id: source_id.as_raw_id(),
1075                });
1076        } else {
1077            warn!(
1078                ?epoch,
1079                %actor_id, %source_id, "ignore cdc source offset updated"
1080            );
1081        }
1082    }
1083
1084    /// Report that a table has finished refreshing for a specific epoch
1085    pub(super) fn report_refresh_finished(
1086        &mut self,
1087        epoch: EpochPair,
1088        actor_id: ActorId,
1089        table_id: TableId,
1090        staging_table_id: TableId,
1091    ) {
1092        // Only record the report if the actor is still tracked and the given epoch is in flight for it.
1093        let Some(actor_state) = self.actor_states.get(&actor_id) else {
1094            warn!(
1095                ?epoch,
1096                %actor_id, %table_id, "ignore refresh finished table: actor_state not found"
1097            );
1098            return;
1099        };
1100        if !actor_state.inflight_barriers.contains(&epoch.prev) {
1101            warn!(
1102                ?epoch,
1103                %actor_id,
1104                %table_id,
1105                inflight_barriers = ?actor_state.inflight_barriers,
1106                "ignore refresh finished table: prev epoch not found in inflight_barriers"
1107            );
1108            return;
1109        };
1110        self.graph_state
1111            .refresh_finished_tables
1112            .entry(epoch.curr)
1113            .or_default()
1114            .insert(table_id);
1115        self.graph_state
1116            .truncate_tables
1117            .entry(epoch.curr)
1118            .or_default()
1119            .insert(staging_table_id);
1120    }
1121}
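
// Test-only sketch of the strategy in `try_find_root_actor_failure` above: keep draining actor
// errors for a bounded amount of time, then pick the error with the highest score as the likely
// root cause. The numeric score here is a stand-in for `ScoredStreamError`.
#[cfg(test)]
mod root_failure_sketch {
    use std::time::Duration;

    use tokio::sync::mpsc::unbounded_channel;

    #[tokio::test]
    async fn pick_highest_scored_error_within_deadline() {
        let (tx, mut rx) = unbounded_channel::<(&'static str, u32)>();
        tx.send(("channel closed", 1)).unwrap();
        tx.send(("executor panicked", 10)).unwrap();
        drop(tx);

        let mut errors = vec![];
        // Collect whatever arrives within the deadline; hitting the timeout is not an error.
        let _ = tokio::time::timeout(Duration::from_millis(100), async {
            while let Some(err) = rx.recv().await {
                errors.push(err);
            }
        })
        .await;

        let root = errors.into_iter().max_by_key(|(_, score)| *score);
        assert_eq!(root, Some(("executor panicked", 10)));
    }
}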
1122
1123impl PartialGraphManagedBarrierState {
1124    /// Called whenever the barrier state of an epoch may have changed, to check whether an
1125    /// `Issued` barrier has now been collected from all actors and, if so, to transform its state
1126    /// to `AllCollected` and return the barrier so that state store `sync` can be started for it.
1127    fn may_have_collected_all(&mut self) -> Option<Barrier> {
1128        for barrier_state in self.epoch_barrier_state_map.values_mut() {
1129            match &barrier_state.inner {
1130                ManagedBarrierStateInner::Issued(IssuedState {
1131                    remaining_actors, ..
1132                }) if remaining_actors.is_empty() => {}
1133                ManagedBarrierStateInner::AllCollected { .. } => {
1134                    continue;
1135                }
1136                ManagedBarrierStateInner::Issued(_) => {
1137                    break;
1138                }
1139            }
1140
1141            self.streaming_metrics.barrier_manager_progress.inc();
1142
1143            let create_mview_progress = self
1144                .create_mview_progress
1145                .remove(&barrier_state.barrier.epoch.curr)
1146                .unwrap_or_default()
1147                .into_iter()
1148                .map(|(actor, (fragment_id, state))| state.to_pb(fragment_id, actor))
1149                .collect();
1150
1151            let list_finished_source_ids = self
1152                .list_finished_source_ids
1153                .remove(&barrier_state.barrier.epoch.curr)
1154                .unwrap_or_default();
1155
1156            let load_finished_source_ids = self
1157                .load_finished_source_ids
1158                .remove(&barrier_state.barrier.epoch.curr)
1159                .unwrap_or_default();
1160
1161            let cdc_table_backfill_progress = self
1162                .cdc_table_backfill_progress
1163                .remove(&barrier_state.barrier.epoch.curr)
1164                .unwrap_or_default()
1165                .into_iter()
1166                .map(|(actor, state)| state.to_pb(actor, barrier_state.barrier.epoch.curr))
1167                .collect();
1168
1169            let cdc_source_offset_updated = self
1170                .cdc_source_offset_updated
1171                .remove(&barrier_state.barrier.epoch.curr)
1172                .unwrap_or_default();
1173
1174            let truncate_tables = self
1175                .truncate_tables
1176                .remove(&barrier_state.barrier.epoch.curr)
1177                .unwrap_or_default()
1178                .into_iter()
1179                .collect();
1180
1181            let refresh_finished_tables = self
1182                .refresh_finished_tables
1183                .remove(&barrier_state.barrier.epoch.curr)
1184                .unwrap_or_default()
1185                .into_iter()
1186                .collect();
1187            let prev_state = replace(
1188                &mut barrier_state.inner,
1189                ManagedBarrierStateInner::AllCollected {
1190                    create_mview_progress,
1191                    list_finished_source_ids,
1192                    load_finished_source_ids,
1193                    truncate_tables,
1194                    refresh_finished_tables,
1195                    cdc_table_backfill_progress,
1196                    cdc_source_offset_updated,
1197                },
1198            );
1199
1200            must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState {
1201                barrier_inflight_latency: timer,
1202                ..
1203            }) => {
1204                timer.observe_duration();
1205            });
1206
1207            return Some(barrier_state.barrier.clone());
1208        }
1209        None
1210    }
1211
1212    fn pop_barrier_to_complete(&mut self, prev_epoch: u64) -> BarrierToComplete {
1213        let (popped_prev_epoch, barrier_state) = self
1214            .epoch_barrier_state_map
1215            .pop_first()
1216            .expect("should exist");
1217
1218        assert_eq!(prev_epoch, popped_prev_epoch);
1219
1220        let (
1221            create_mview_progress,
1222            list_finished_source_ids,
1223            load_finished_source_ids,
1224            cdc_table_backfill_progress,
1225            cdc_source_offset_updated,
1226            truncate_tables,
1227            refresh_finished_tables,
1228        ) = must_match!(barrier_state.inner, ManagedBarrierStateInner::AllCollected {
1229            create_mview_progress,
1230            list_finished_source_ids,
1231            load_finished_source_ids,
1232            truncate_tables,
1233            refresh_finished_tables,
1234            cdc_table_backfill_progress,
1235            cdc_source_offset_updated,
1236        } => {
1237            (create_mview_progress, list_finished_source_ids, load_finished_source_ids, cdc_table_backfill_progress, cdc_source_offset_updated, truncate_tables, refresh_finished_tables)
1238        });
1239        BarrierToComplete {
1240            barrier: barrier_state.barrier,
1241            table_ids: barrier_state.table_ids,
1242            create_mview_progress,
1243            list_finished_source_ids,
1244            load_finished_source_ids,
1245            truncate_tables,
1246            refresh_finished_tables,
1247            cdc_table_backfill_progress,
1248            cdc_source_offset_updated,
1249        }
1250    }
1251}
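
// Test-only sketch of why `epoch_barrier_state_map` is a `BTreeMap` keyed by `prev_epoch`:
// concurrent barriers may be fully collected out of order, but `may_have_collected_all` and
// `pop_barrier_to_complete` above must hand barriers out in epoch order, which `first_key_value`
// and `pop_first` give for free. The map of counters here is a simplified stand-in.
#[cfg(test)]
mod epoch_map_ordering_sketch {
    use std::collections::BTreeMap;

    #[test]
    fn barriers_complete_in_prev_epoch_order() {
        // prev_epoch -> number of actors still to collect
        let mut epoch_map: BTreeMap<u64, usize> = BTreeMap::new();
        epoch_map.insert(1, 2);
        epoch_map.insert(2, 2);

        // The newer barrier (prev_epoch = 2) happens to finish first.
        *epoch_map.get_mut(&2).unwrap() = 0;
        // The oldest entry still has remaining actors, so nothing can be completed yet.
        assert_eq!(epoch_map.first_key_value(), Some((&1, &2)));

        // Once the oldest barrier is fully collected, entries are popped front-to-back.
        *epoch_map.get_mut(&1).unwrap() = 0;
        assert_eq!(epoch_map.pop_first(), Some((1, 0)));
        assert_eq!(epoch_map.pop_first(), Some((2, 0)));
    }
}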
1252
1253pub(crate) struct BarrierToComplete {
1254    pub barrier: Barrier,
1255    pub table_ids: Option<HashSet<TableId>>,
1256    pub create_mview_progress: Vec<PbCreateMviewProgress>,
1257    pub list_finished_source_ids: Vec<PbListFinishedSource>,
1258    pub load_finished_source_ids: Vec<PbLoadFinishedSource>,
1259    pub truncate_tables: Vec<TableId>,
1260    pub refresh_finished_tables: Vec<TableId>,
1261    pub cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,
1262    pub cdc_source_offset_updated: Vec<PbCdcSourceOffsetUpdated>,
1263}
1264
1265impl PartialGraphManagedBarrierState {
1266    /// Collect a `barrier` from the actor with `actor_id`.
1267    pub(super) fn collect(&mut self, actor_id: impl Into<ActorId>, epoch: EpochPair) {
1268        let actor_id = actor_id.into();
1269        tracing::debug!(
1270            target: "events::stream::barrier::manager::collect",
1271            ?epoch, %actor_id, state = ?self.epoch_barrier_state_map,
1272            "collect_barrier",
1273        );
1274
1275        match self.epoch_barrier_state_map.get_mut(&epoch.prev) {
1276            None => {
1277                // Reaching here means the barrier has not been injected by the barrier manager yet, or the
1278                // barrier message is still blocked at the `RemoteInput` side waiting for injection. In either
1279                // case no actor should be able to attempt a collect yet, so this is a bug.
1280                panic!(
1281                    "cannot collect new actor barrier {:?} at current state: None",
1282                    epoch,
1283                )
1284            }
1285            Some(&mut BarrierState {
1286                ref barrier,
1287                inner:
1288                    ManagedBarrierStateInner::Issued(IssuedState {
1289                        ref mut remaining_actors,
1290                        ..
1291                    }),
1292                ..
1293            }) => {
1294                let exist = remaining_actors.remove(&actor_id);
1295                assert!(
1296                    exist,
1297                    "the actor doesn't exist. actor_id: {:?}, curr_epoch: {:?}",
1298                    actor_id, epoch.curr
1299                );
1300                assert_eq!(barrier.epoch.curr, epoch.curr);
1301            }
1302            Some(BarrierState { inner, .. }) => {
1303                panic!(
1304                    "cannot collect new actor barrier {:?} at current state: {:?}",
1305                    epoch, inner
1306                )
1307            }
1308        }
1309    }
1310
1311    /// When the meta service issues a `send_barrier` request, call this function to transform the
1312    /// state to `Issued` and start collecting the barrier from the given actors.
1313    pub(super) fn transform_to_issued(
1314        &mut self,
1315        barrier: &Barrier,
1316        actor_ids_to_collect: impl IntoIterator<Item = ActorId>,
1317        table_ids: HashSet<TableId>,
1318    ) {
1319        let timer = self
1320            .streaming_metrics
1321            .barrier_inflight_latency
1322            .start_timer();
1323
1324        if let Some(hummock) = self.state_store.as_hummock() {
1325            hummock.start_epoch(barrier.epoch.curr, table_ids.clone());
1326        }
1327
1328        let table_ids = match barrier.kind {
1329            BarrierKind::Unspecified => {
1330                unreachable!()
1331            }
1332            BarrierKind::Initial => {
1333                assert!(
1334                    self.prev_barrier_table_ids.is_none(),
1335                    "non-empty `prev_barrier_table_ids` at initial barrier: {:?}",
1336                    self.prev_barrier_table_ids
1337                );
1338                info!(epoch = ?barrier.epoch, "initialize at Initial barrier");
1339                self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
1340                None
1341            }
1342            BarrierKind::Barrier => {
1343                if let Some((prev_epoch, prev_table_ids)) = self.prev_barrier_table_ids.as_mut() {
1344                    assert_eq!(prev_epoch.curr, barrier.epoch.prev);
1345                    assert_eq!(prev_table_ids, &table_ids);
1346                    *prev_epoch = barrier.epoch;
1347                } else {
1348                    info!(epoch = ?barrier.epoch, "initialize at non-checkpoint barrier");
1349                    self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
1350                }
1351                None
1352            }
1353            BarrierKind::Checkpoint => Some(
1354                if let Some((prev_epoch, prev_table_ids)) = self
1355                    .prev_barrier_table_ids
1356                    .replace((barrier.epoch, table_ids))
1357                    && prev_epoch.curr == barrier.epoch.prev
1358                {
1359                    prev_table_ids
1360                } else {
1361                    debug!(epoch = ?barrier.epoch, "reinitialize at Checkpoint barrier");
1362                    HashSet::new()
1363                },
1364            ),
1365        };
1366
1367        if let Some(&mut BarrierState { ref inner, .. }) =
1368            self.epoch_barrier_state_map.get_mut(&barrier.epoch.prev)
1369        {
1370            {
1371                panic!(
1372                    "barrier epochs {:?} state has already been `Issued`. Current state: {:?}",
1373                    barrier.epoch, inner
1374                );
1375            }
1376        };
1377
1378        self.epoch_barrier_state_map.insert(
1379            barrier.epoch.prev,
1380            BarrierState {
1381                barrier: barrier.clone(),
1382                inner: ManagedBarrierStateInner::Issued(IssuedState {
1383                    remaining_actors: BTreeSet::from_iter(actor_ids_to_collect),
1384                    barrier_inflight_latency: timer,
1385                }),
1386                table_ids,
1387            },
1388        );
1389    }
1390
1391    #[cfg(test)]
1392    async fn pop_next_completed_epoch(&mut self) -> u64 {
1393        if let Some(barrier) = self.may_have_collected_all() {
1394            self.pop_barrier_to_complete(barrier.epoch.prev);
1395            return barrier.epoch.prev;
1396        }
1397        pending().await
1398    }
1399}
1400
1401#[cfg(test)]
1402mod tests {
1403    use std::collections::HashSet;
1404
1405    use risingwave_common::util::epoch::test_epoch;
1406
1407    use crate::executor::Barrier;
1408    use crate::task::barrier_worker::managed_state::PartialGraphManagedBarrierState;
1409
1410    #[tokio::test]
1411    async fn test_managed_state_add_actor() {
1412        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1413        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1414        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1415        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1416        let actor_ids_to_collect1 = HashSet::from([1.into(), 2.into()]);
1417        let actor_ids_to_collect2 = HashSet::from([1.into(), 2.into()]);
1418        let actor_ids_to_collect3 = HashSet::from([1.into(), 2.into(), 3.into()]);
1419        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1420        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1421        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1422        managed_barrier_state.collect(1, barrier1.epoch);
1423        managed_barrier_state.collect(2, barrier1.epoch);
1424        assert_eq!(
1425            managed_barrier_state.pop_next_completed_epoch().await,
1426            test_epoch(0)
1427        );
1428        assert_eq!(
1429            managed_barrier_state
1430                .epoch_barrier_state_map
1431                .first_key_value()
1432                .unwrap()
1433                .0,
1434            &test_epoch(1)
1435        );
1436        managed_barrier_state.collect(1, barrier2.epoch);
1437        managed_barrier_state.collect(1, barrier3.epoch);
1438        managed_barrier_state.collect(2, barrier2.epoch);
1439        assert_eq!(
1440            managed_barrier_state.pop_next_completed_epoch().await,
1441            test_epoch(1)
1442        );
1443        assert_eq!(
1444            managed_barrier_state
1445                .epoch_barrier_state_map
1446                .first_key_value()
1447                .unwrap()
1448                .0,
1449            &test_epoch(2)
1450        );
1451        managed_barrier_state.collect(2, barrier3.epoch);
1452        managed_barrier_state.collect(3, barrier3.epoch);
1453        assert_eq!(
1454            managed_barrier_state.pop_next_completed_epoch().await,
1455            test_epoch(2)
1456        );
1457        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1458    }
1459
1460    #[tokio::test]
1461    async fn test_managed_state_stop_actor() {
1462        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1463        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1464        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1465        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1466        let actor_ids_to_collect1 = HashSet::from([1.into(), 2.into(), 3.into(), 4.into()]);
1467        let actor_ids_to_collect2 = HashSet::from([1.into(), 2.into(), 3.into()]);
1468        let actor_ids_to_collect3 = HashSet::from([1.into(), 2.into()]);
1469        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1470        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1471        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1472
1473        managed_barrier_state.collect(1, barrier1.epoch);
1474        managed_barrier_state.collect(1, barrier2.epoch);
1475        managed_barrier_state.collect(1, barrier3.epoch);
1476        managed_barrier_state.collect(2, barrier1.epoch);
1477        managed_barrier_state.collect(2, barrier2.epoch);
1478        managed_barrier_state.collect(2, barrier3.epoch);
1479        assert_eq!(
1480            managed_barrier_state
1481                .epoch_barrier_state_map
1482                .first_key_value()
1483                .unwrap()
1484                .0,
1485            &0
1486        );
1487        managed_barrier_state.collect(3, barrier1.epoch);
1488        managed_barrier_state.collect(3, barrier2.epoch);
1489        assert_eq!(
1490            managed_barrier_state
1491                .epoch_barrier_state_map
1492                .first_key_value()
1493                .unwrap()
1494                .0,
1495            &0
1496        );
1497        managed_barrier_state.collect(4, barrier1.epoch);
1498        assert_eq!(
1499            managed_barrier_state.pop_next_completed_epoch().await,
1500            test_epoch(0)
1501        );
1502        assert_eq!(
1503            managed_barrier_state.pop_next_completed_epoch().await,
1504            test_epoch(1)
1505        );
1506        assert_eq!(
1507            managed_barrier_state.pop_next_completed_epoch().await,
1508            test_epoch(2)
1509        );
1510        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1511    }
1512}