risingwave_stream/task/barrier_manager/managed_state.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::LazyCell;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::future::{Future, pending, poll_fn};
use std::mem::replace;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

use anyhow::anyhow;
use futures::FutureExt;
use futures::stream::FuturesOrdered;
use prometheus::HistogramTimer;
use risingwave_common::catalog::{DatabaseId, TableId};
use risingwave_common::util::epoch::EpochPair;
use risingwave_pb::stream_plan::barrier::BarrierKind;
use risingwave_storage::StateStoreImpl;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;

use super::progress::BackfillState;
use crate::error::{StreamError, StreamResult};
use crate::executor::Barrier;
use crate::executor::monitor::StreamingMetrics;
use crate::task::{
    ActorId, LocalBarrierManager, NewOutputRequest, PartialGraphId, StreamActorManager,
    UpDownActorIds,
};

struct IssuedState {
    /// Actor ids remaining to be collected.
    pub remaining_actors: BTreeSet<ActorId>,

    pub barrier_inflight_latency: HistogramTimer,
}

impl Debug for IssuedState {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IssuedState")
            .field("remaining_actors", &self.remaining_actors)
            .finish()
    }
}

/// The state machine of the local barrier manager.
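///
/// A barrier starts as `Issued` while collection from the remaining actors is in
/// progress, becomes `AllCollected` once every actor has reported it, and is finally
/// removed from the map by `pop_barrier_to_complete`.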
#[derive(Debug)]
enum ManagedBarrierStateInner {
    /// Meta service has issued a `send_barrier` request. We're collecting barriers now.
    Issued(IssuedState),

    /// The barrier has been collected by all remaining actors
    AllCollected(Vec<PbCreateMviewProgress>),
}

#[derive(Debug)]
struct BarrierState {
    barrier: Barrier,
    /// Only `Some(_)` when `barrier.kind` is `Checkpoint`.
    table_ids: Option<HashSet<TableId>>,
    inner: ManagedBarrierStateInner,
}

use risingwave_common::must_match;
use risingwave_pb::stream_plan::SubscriptionUpstreamInfo;
use risingwave_pb::stream_service::InjectBarrierRequest;
use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress;
use risingwave_pb::stream_service::streaming_control_stream_request::{
    DatabaseInitialPartialGraph, InitialPartialGraph,
};

use crate::executor::exchange::permit;
use crate::executor::exchange::permit::channel_from_config;
use crate::task::barrier_manager::await_epoch_completed_future::AwaitEpochCompletedFuture;
use crate::task::barrier_manager::{LocalBarrierEvent, ScoredStreamError};

pub(super) struct ManagedBarrierStateDebugInfo<'a> {
    running_actors: BTreeSet<ActorId>,
    graph_states: &'a HashMap<PartialGraphId, PartialGraphManagedBarrierState>,
}

impl Display for ManagedBarrierStateDebugInfo<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "running_actors: ")?;
        for actor_id in &self.running_actors {
            write!(f, "{}, ", actor_id)?;
        }
        for (partial_graph_id, graph_states) in self.graph_states {
            writeln!(f, "--- Partial Graph {}", partial_graph_id.0)?;
            write!(f, "{}", graph_states)?;
        }
        Ok(())
    }
}

impl Display for &'_ PartialGraphManagedBarrierState {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut prev_epoch = 0u64;
        for (epoch, barrier_state) in &self.epoch_barrier_state_map {
            write!(f, "> Epoch {}: ", epoch)?;
            match &barrier_state.inner {
                ManagedBarrierStateInner::Issued(state) => {
                    write!(
                        f,
                        "Issued [{:?}]. Remaining actors: [",
                        barrier_state.barrier.kind
                    )?;
                    let mut is_prev_epoch_issued = false;
                    if prev_epoch != 0 {
                        let bs = &self.epoch_barrier_state_map[&prev_epoch];
                        if let ManagedBarrierStateInner::Issued(IssuedState {
                            remaining_actors: remaining_actors_prev,
                            ..
                        }) = &bs.inner
                        {
                            // Only show the actors that are not in the previous epoch.
                            is_prev_epoch_issued = true;
                            let mut duplicates = 0usize;
                            for actor_id in &state.remaining_actors {
                                if !remaining_actors_prev.contains(actor_id) {
                                    write!(f, "{}, ", actor_id)?;
                                } else {
                                    duplicates += 1;
                                }
                            }
                            if duplicates > 0 {
                                write!(f, "...and {} actors in prev epoch", duplicates)?;
                            }
                        }
                    }
                    if !is_prev_epoch_issued {
                        for actor_id in &state.remaining_actors {
                            write!(f, "{}, ", actor_id)?;
                        }
                    }
                    write!(f, "]")?;
                }
                ManagedBarrierStateInner::AllCollected(_) => {
                    write!(f, "AllCollected")?;
                }
            }
            prev_epoch = *epoch;
            writeln!(f)?;
        }

        if !self.create_mview_progress.is_empty() {
            writeln!(f, "Create MView Progress:")?;
            for (epoch, progress) in &self.create_mview_progress {
                write!(f, "> Epoch {}:", epoch)?;
                for (actor_id, state) in progress {
                    write!(f, ">> Actor {}: {}, ", actor_id, state)?;
                }
            }
        }

        Ok(())
    }
}

enum InflightActorStatus {
    /// The actor has been issued some barriers, but has not collected the first barrier
    IssuedFirst(Vec<Barrier>),
    /// The actor has been issued some barriers, and has collected the first barrier
    Running(u64),
}

impl InflightActorStatus {
    fn max_issued_epoch(&self) -> u64 {
        match self {
            InflightActorStatus::Running(epoch) => *epoch,
            InflightActorStatus::IssuedFirst(issued_barriers) => {
                issued_barriers.last().expect("non-empty").epoch.prev
            }
        }
    }
}

pub(crate) struct InflightActorState {
    actor_id: ActorId,
    barrier_senders: Vec<mpsc::UnboundedSender<Barrier>>,
    /// `prev_epoch` -> partial graph id
    pub(super) inflight_barriers: BTreeMap<u64, PartialGraphId>,
    status: InflightActorStatus,
    /// Whether the actor has been issued a stop barrier
    is_stopping: bool,

    new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
    join_handle: JoinHandle<()>,
    monitor_task_handle: Option<JoinHandle<()>>,
}

impl InflightActorState {
    pub(super) fn start(
        actor_id: ActorId,
        initial_partial_graph_id: PartialGraphId,
        initial_barrier: &Barrier,
        new_output_request_tx: UnboundedSender<(ActorId, NewOutputRequest)>,
        join_handle: JoinHandle<()>,
        monitor_task_handle: Option<JoinHandle<()>>,
    ) -> Self {
        Self {
            actor_id,
            barrier_senders: vec![],
            inflight_barriers: BTreeMap::from_iter([(
                initial_barrier.epoch.prev,
                initial_partial_graph_id,
            )]),
            status: InflightActorStatus::IssuedFirst(vec![initial_barrier.clone()]),
            is_stopping: false,
            new_output_request_tx,
            join_handle,
            monitor_task_handle,
        }
    }

    pub(super) fn issue_barrier(
        &mut self,
        partial_graph_id: PartialGraphId,
        barrier: &Barrier,
        is_stop: bool,
    ) -> StreamResult<()> {
        assert!(barrier.epoch.prev > self.status.max_issued_epoch());

        for barrier_sender in &self.barrier_senders {
            barrier_sender.send(barrier.clone()).map_err(|_| {
                StreamError::barrier_send(
                    barrier.clone(),
                    self.actor_id,
                    "failed to send to registered sender",
                )
            })?;
        }

        assert!(
            self.inflight_barriers
                .insert(barrier.epoch.prev, partial_graph_id)
                .is_none()
        );

        match &mut self.status {
            InflightActorStatus::IssuedFirst(pending_barriers) => {
                pending_barriers.push(barrier.clone());
            }
            InflightActorStatus::Running(prev_epoch) => {
                *prev_epoch = barrier.epoch.prev;
            }
        };

        if is_stop {
            assert!(!self.is_stopping, "stopped actor should not issue barrier");
            self.is_stopping = true;
        }
        Ok(())
    }

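    /// Collect the barrier at `epoch` from this actor. Returns the partial graph id
    /// that the collected barrier was issued from, and whether the actor has now
    /// finished, i.e. a stop barrier has been issued and no barriers remain inflight.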
    pub(super) fn collect(&mut self, epoch: EpochPair) -> (PartialGraphId, bool) {
        let (prev_epoch, prev_partial_graph_id) =
            self.inflight_barriers.pop_first().expect("should exist");
        assert_eq!(prev_epoch, epoch.prev);
        match &self.status {
            InflightActorStatus::IssuedFirst(pending_barriers) => {
                assert_eq!(
                    prev_epoch,
                    pending_barriers.first().expect("non-empty").epoch.prev
                );
                self.status = InflightActorStatus::Running(
                    pending_barriers.last().expect("non-empty").epoch.prev,
                );
            }
            InflightActorStatus::Running(_) => {}
        }
        (
            prev_partial_graph_id,
            self.inflight_barriers.is_empty() && self.is_stopping,
        )
    }
}

pub(super) struct PartialGraphManagedBarrierState {
    /// Record the barrier state for each epoch of concurrent checkpoints.
    ///
    /// The key is `prev_epoch`, and the value tracks the state of the barrier
    /// from `prev_epoch` to `curr_epoch`.
    epoch_barrier_state_map: BTreeMap<u64, BarrierState>,

    prev_barrier_table_ids: Option<(EpochPair, HashSet<TableId>)>,

    mv_depended_subscriptions: HashMap<TableId, HashSet<u32>>,

    /// Record the progress updates of creating mviews for each epoch of concurrent checkpoints.
    ///
    /// This is updated by [`super::CreateMviewProgressReporter::update`] and will be reported to meta
    /// in [`crate::task::barrier_manager::BarrierCompleteResult`].
    pub(super) create_mview_progress: HashMap<u64, HashMap<ActorId, BackfillState>>,

    state_store: StateStoreImpl,

    streaming_metrics: Arc<StreamingMetrics>,
}

impl PartialGraphManagedBarrierState {
    pub(super) fn new(actor_manager: &StreamActorManager) -> Self {
        Self::new_inner(
            actor_manager.env.state_store(),
            actor_manager.streaming_metrics.clone(),
        )
    }

    fn new_inner(state_store: StateStoreImpl, streaming_metrics: Arc<StreamingMetrics>) -> Self {
        Self {
            epoch_barrier_state_map: Default::default(),
            prev_barrier_table_ids: None,
            mv_depended_subscriptions: Default::default(),
            create_mview_progress: Default::default(),
            state_store,
            streaming_metrics,
        }
    }

    #[cfg(test)]
    pub(crate) fn for_test() -> Self {
        Self::new_inner(
            StateStoreImpl::for_test(),
            Arc::new(StreamingMetrics::unused()),
        )
    }

    pub(super) fn is_empty(&self) -> bool {
        self.epoch_barrier_state_map.is_empty()
    }
}

pub(crate) struct SuspendedDatabaseState {
    pub(super) suspend_time: Instant,
    inner: DatabaseManagedBarrierState,
    failure: Option<(Option<ActorId>, StreamError)>,
}

impl SuspendedDatabaseState {
    fn new(
        state: DatabaseManagedBarrierState,
        failure: Option<(Option<ActorId>, StreamError)>,
        _completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>, /* discard the completing futures */
    ) -> Self {
        Self {
            suspend_time: Instant::now(),
            inner: state,
            failure,
        }
    }

    async fn reset(mut self) -> ResetDatabaseOutput {
        let root_err = self.inner.try_find_root_actor_failure(self.failure).await;
        self.inner.abort_and_wait_actors().await;
        if let Some(hummock) = self.inner.actor_manager.env.state_store().as_hummock() {
            hummock.clear_tables(self.inner.table_ids).await;
        }
        ResetDatabaseOutput { root_err }
    }
}

pub(crate) struct ResettingDatabaseState {
    join_handle: JoinHandle<ResetDatabaseOutput>,
    reset_request_id: u32,
}

pub(crate) struct ResetDatabaseOutput {
    pub(crate) root_err: Option<ScoredStreamError>,
}

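/// Status of a database's barrier state. `Running` moves to `Suspended` on failure
/// (see `suspend`) and then to `Resetting` (see `start_reset`), while
/// `ReceivedExchangeRequest` buffers exchange requests that arrive before the
/// running state exists.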
pub(crate) enum DatabaseStatus {
    ReceivedExchangeRequest(
        Vec<(
            UpDownActorIds,
            oneshot::Sender<StreamResult<permit::Receiver>>,
        )>,
    ),
    Running(DatabaseManagedBarrierState),
    Suspended(SuspendedDatabaseState),
    Resetting(ResettingDatabaseState),
    /// Temporary placeholder while the status is being replaced.
    Unspecified,
}

impl DatabaseStatus {
    pub(crate) async fn abort(&mut self) {
        match self {
            DatabaseStatus::ReceivedExchangeRequest(pending_requests) => {
                for (_, sender) in pending_requests.drain(..) {
                    let _ = sender.send(Err(anyhow!("database aborted").into()));
                }
            }
            DatabaseStatus::Running(state) => {
                state.abort_and_wait_actors().await;
            }
            DatabaseStatus::Suspended(SuspendedDatabaseState { inner: state, .. }) => {
                state.abort_and_wait_actors().await;
            }
            DatabaseStatus::Resetting(state) => {
                (&mut state.join_handle)
                    .await
                    .expect("failed to join reset database join handle");
            }
            DatabaseStatus::Unspecified => {
                unreachable!()
            }
        }
    }

    pub(crate) fn state_for_request(&mut self) -> Option<&mut DatabaseManagedBarrierState> {
        match self {
            DatabaseStatus::ReceivedExchangeRequest(_) => {
                unreachable!("should not handle request")
            }
            DatabaseStatus::Running(state) => Some(state),
            DatabaseStatus::Suspended(_) => None,
            DatabaseStatus::Resetting(_) => {
                unreachable!("should not receive further request during cleaning")
            }
            DatabaseStatus::Unspecified => {
                unreachable!()
            }
        }
    }

    pub(super) fn poll_next_event(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ManagedBarrierStateEvent> {
        match self {
            DatabaseStatus::ReceivedExchangeRequest(_) => Poll::Pending,
            DatabaseStatus::Running(state) => state.poll_next_event(cx),
            DatabaseStatus::Suspended(_) => Poll::Pending,
            DatabaseStatus::Resetting(state) => state.join_handle.poll_unpin(cx).map(|result| {
                let output = result.expect("should be able to join");
                ManagedBarrierStateEvent::DatabaseReset(output, state.reset_request_id)
            }),
            DatabaseStatus::Unspecified => {
                unreachable!()
            }
        }
    }

    pub(super) fn suspend(
        &mut self,
        failed_actor: Option<ActorId>,
        err: StreamError,
        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
    ) {
        let state = must_match!(replace(self, DatabaseStatus::Unspecified), DatabaseStatus::Running(state) => state);
        *self = DatabaseStatus::Suspended(SuspendedDatabaseState::new(
            state,
            Some((failed_actor, err)),
            completing_futures,
        ));
    }

    pub(super) fn start_reset(
        &mut self,
        database_id: DatabaseId,
        completing_futures: Option<FuturesOrdered<AwaitEpochCompletedFuture>>,
        reset_request_id: u32,
    ) {
        let join_handle = match replace(self, DatabaseStatus::Unspecified) {
            DatabaseStatus::ReceivedExchangeRequest(pending_requests) => {
                for (_, sender) in pending_requests {
                    let _ = sender.send(Err(anyhow!("database reset").into()));
                }
                tokio::spawn(async move { ResetDatabaseOutput { root_err: None } })
            }
            DatabaseStatus::Running(state) => {
                assert_eq!(database_id, state.database_id);
                info!(
                    database_id = database_id.database_id,
                    reset_request_id, "start database reset from Running"
                );
                tokio::spawn(SuspendedDatabaseState::new(state, None, completing_futures).reset())
            }
            DatabaseStatus::Suspended(state) => {
                assert!(
                    completing_futures.is_none(),
                    "should have been cleared when suspended"
                );
                assert_eq!(database_id, state.inner.database_id);
                info!(
                    database_id = database_id.database_id,
                    reset_request_id,
                    suspend_elapsed = ?state.suspend_time.elapsed(),
                    "start database reset after suspended"
                );
                tokio::spawn(state.reset())
            }
            DatabaseStatus::Resetting(state) => {
                let prev_request_id = state.reset_request_id;
                info!(
                    database_id = database_id.database_id,
                    reset_request_id, prev_request_id, "receive duplicate reset request"
                );
                assert!(reset_request_id > prev_request_id);
                state.join_handle
            }
            DatabaseStatus::Unspecified => {
                unreachable!()
            }
        };
        *self = DatabaseStatus::Resetting(ResettingDatabaseState {
            join_handle,
            reset_request_id,
        });
    }
}

pub(crate) struct ManagedBarrierState {
    pub(crate) databases: HashMap<DatabaseId, DatabaseStatus>,
}

pub(super) enum ManagedBarrierStateEvent {
    BarrierCollected {
        partial_graph_id: PartialGraphId,
        barrier: Barrier,
    },
    ActorError {
        actor_id: ActorId,
        err: StreamError,
    },
    DatabaseReset(ResetDatabaseOutput, u32),
}

impl ManagedBarrierState {
    pub(super) fn new(
        actor_manager: Arc<StreamActorManager>,
        initial_partial_graphs: Vec<DatabaseInitialPartialGraph>,
        term_id: String,
    ) -> Self {
        let mut databases = HashMap::new();
        for database in initial_partial_graphs {
            let database_id = DatabaseId::new(database.database_id);
            assert!(!databases.contains_key(&database_id));
            let state = DatabaseManagedBarrierState::new(
                database_id,
                term_id.clone(),
                actor_manager.clone(),
                database.graphs,
            );
            databases.insert(database_id, DatabaseStatus::Running(state));
        }

        Self { databases }
    }

    pub(super) fn next_event(
        &mut self,
    ) -> impl Future<Output = (DatabaseId, ManagedBarrierStateEvent)> + '_ {
        poll_fn(|cx| {
            for (database_id, database) in &mut self.databases {
                if let Poll::Ready(event) = database.poll_next_event(cx) {
                    return Poll::Ready((*database_id, event));
                }
            }
            Poll::Pending
        })
    }
}

pub(crate) struct DatabaseManagedBarrierState {
    database_id: DatabaseId,
    pub(super) actor_states: HashMap<ActorId, InflightActorState>,
    pub(super) actor_pending_new_output_requests:
        HashMap<ActorId, Vec<(ActorId, NewOutputRequest)>>,

    pub(super) graph_states: HashMap<PartialGraphId, PartialGraphManagedBarrierState>,

    table_ids: HashSet<TableId>,

    actor_manager: Arc<StreamActorManager>,

    pub(super) local_barrier_manager: LocalBarrierManager,

    barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,
    pub(super) actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
}

impl DatabaseManagedBarrierState {
    /// Create a barrier manager state. This will be called only once.
    pub(super) fn new(
        database_id: DatabaseId,
        term_id: String,
        actor_manager: Arc<StreamActorManager>,
        initial_partial_graphs: Vec<InitialPartialGraph>,
    ) -> Self {
        let (local_barrier_manager, barrier_event_rx, actor_failure_rx) =
            LocalBarrierManager::new(database_id, term_id, actor_manager.env.clone());
        Self {
            database_id,
            actor_states: Default::default(),
            actor_pending_new_output_requests: Default::default(),
            graph_states: initial_partial_graphs
                .into_iter()
                .map(|graph| {
                    let mut state = PartialGraphManagedBarrierState::new(&actor_manager);
                    state.add_subscriptions(graph.subscriptions);
                    (PartialGraphId::new(graph.partial_graph_id), state)
                })
                .collect(),
            table_ids: Default::default(),
            actor_manager,
            local_barrier_manager,
            barrier_event_rx,
            actor_failure_rx,
        }
    }

    pub(super) fn to_debug_info(&self) -> ManagedBarrierStateDebugInfo<'_> {
        ManagedBarrierStateDebugInfo {
            running_actors: self.actor_states.keys().cloned().collect(),
            graph_states: &self.graph_states,
        }
    }

    async fn abort_and_wait_actors(&mut self) {
        for (actor_id, state) in &self.actor_states {
            tracing::debug!("force stopping actor {}", actor_id);
            state.join_handle.abort();
            if let Some(monitor_task_handle) = &state.monitor_task_handle {
                monitor_task_handle.abort();
            }
        }

        for (actor_id, state) in self.actor_states.drain() {
            tracing::debug!("join actor {}", actor_id);
            let result = state.join_handle.await;
            assert!(result.is_ok() || result.unwrap_err().is_cancelled());
        }
    }
}

impl InflightActorState {
    pub(super) fn register_barrier_sender(
        &mut self,
        tx: mpsc::UnboundedSender<Barrier>,
    ) -> StreamResult<()> {
        match &self.status {
            InflightActorStatus::IssuedFirst(pending_barriers) => {
                for barrier in pending_barriers {
                    tx.send(barrier.clone()).map_err(|_| {
                        StreamError::barrier_send(
                            barrier.clone(),
                            self.actor_id,
                            "failed to send pending barriers to newly registered sender",
                        )
                    })?;
                }
                self.barrier_senders.push(tx);
            }
            InflightActorStatus::Running(_) => {
                unreachable!("should not register barrier sender when entering Running status")
            }
        }
        Ok(())
    }
}

impl DatabaseManagedBarrierState {
    pub(super) fn register_barrier_sender(
        &mut self,
        actor_id: ActorId,
        tx: mpsc::UnboundedSender<Barrier>,
    ) -> StreamResult<()> {
        self.actor_states
            .get_mut(&actor_id)
            .expect("should exist")
            .register_barrier_sender(tx)
    }
}

impl PartialGraphManagedBarrierState {
    pub(super) fn add_subscriptions(&mut self, subscriptions: Vec<SubscriptionUpstreamInfo>) {
        for subscription_to_add in subscriptions {
            if !self
                .mv_depended_subscriptions
                .entry(TableId::new(subscription_to_add.upstream_mv_table_id))
                .or_default()
                .insert(subscription_to_add.subscriber_id)
            {
                if cfg!(debug_assertions) {
                    panic!("add an existing subscription: {:?}", subscription_to_add);
                }
                warn!(?subscription_to_add, "add an existing subscription");
            }
        }
    }

    pub(super) fn remove_subscriptions(&mut self, subscriptions: Vec<SubscriptionUpstreamInfo>) {
        for subscription_to_remove in subscriptions {
            let upstream_table_id = TableId::new(subscription_to_remove.upstream_mv_table_id);
            let Some(subscribers) = self.mv_depended_subscriptions.get_mut(&upstream_table_id)
            else {
                if cfg!(debug_assertions) {
                    panic!(
                        "unable to find upstream mv table to remove: {:?}",
                        subscription_to_remove
                    );
                }
                warn!(
                    ?subscription_to_remove,
                    "unable to find upstream mv table to remove"
                );
                continue;
            };
            if !subscribers.remove(&subscription_to_remove.subscriber_id) {
                if cfg!(debug_assertions) {
                    panic!(
                        "unable to find subscriber to remove: {:?}",
                        subscription_to_remove
                    );
                }
                warn!(
                    ?subscription_to_remove,
                    "unable to find subscriber to remove"
                );
            }
            if subscribers.is_empty() {
                self.mv_depended_subscriptions.remove(&upstream_table_id);
            }
        }
    }
}

impl DatabaseManagedBarrierState {
    pub(super) fn transform_to_issued(
        &mut self,
        barrier: &Barrier,
        request: InjectBarrierRequest,
    ) -> StreamResult<()> {
        let partial_graph_id = PartialGraphId::new(request.partial_graph_id);
        let actor_to_stop = barrier.all_stop_actors();
        let is_stop_actor = |actor_id| {
            actor_to_stop
                .map(|actors| actors.contains(&actor_id))
                .unwrap_or(false)
        };
        let graph_state = self
            .graph_states
            .get_mut(&partial_graph_id)
            .expect("should exist");

        graph_state.add_subscriptions(request.subscriptions_to_add);
        graph_state.remove_subscriptions(request.subscriptions_to_remove);

        let table_ids =
            HashSet::from_iter(request.table_ids_to_sync.iter().cloned().map(TableId::new));
        self.table_ids.extend(table_ids.iter().cloned());

        graph_state.transform_to_issued(
            barrier,
            request.actor_ids_to_collect.iter().cloned(),
            table_ids,
        );

        let mut new_actors = HashSet::new();
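        // The subscription map is cloned lazily: `LazyCell` materializes it only if at
        // least one new actor is actually built below.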
        let subscriptions =
            LazyCell::new(|| Arc::new(graph_state.mv_depended_subscriptions.clone()));
        for (node, fragment_id, actor) in
            request
                .actors_to_build
                .into_iter()
                .flat_map(|fragment_actors| {
                    let node = Arc::new(fragment_actors.node.unwrap());
                    fragment_actors
                        .actors
                        .into_iter()
                        .map(move |actor| (node.clone(), fragment_actors.fragment_id, actor))
                })
        {
            let actor_id = actor.actor_id;
            assert!(!is_stop_actor(actor_id));
            assert!(new_actors.insert(actor_id));
            assert!(request.actor_ids_to_collect.contains(&actor_id));
            let (new_output_request_tx, new_output_request_rx) = unbounded_channel();
            if let Some(pending_requests) = self.actor_pending_new_output_requests.remove(&actor_id)
            {
                for request in pending_requests {
                    let _ = new_output_request_tx.send(request);
                }
            }
            let (join_handle, monitor_join_handle) = self.actor_manager.spawn_actor(
                actor,
                fragment_id,
                node,
                (*subscriptions).clone(),
                self.local_barrier_manager.clone(),
                new_output_request_rx,
            );
            assert!(
                self.actor_states
                    .try_insert(
                        actor_id,
                        InflightActorState::start(
                            actor_id,
                            partial_graph_id,
                            barrier,
                            new_output_request_tx,
                            join_handle,
                            monitor_join_handle
                        )
                    )
                    .is_ok()
            );
        }

        // Spawn a trivial join handle to be compatible with the unit tests. In unit tests that
        // involve the local barrier manager, actors are spawned by the local test logic, but we
        // assume that there is an entry for each spawned actor in `actor_states`, so under
        // cfg!(test) we add a dummy entry for each new actor.
        if cfg!(test) {
            for actor_id in &request.actor_ids_to_collect {
                if !self.actor_states.contains_key(actor_id) {
                    let (tx, rx) = unbounded_channel();
                    let join_handle = self.actor_manager.runtime.spawn(async move {
                        // Keep the rx alive in the spawned task so that tx.send() will not fail.
                        let _ = rx;
                        pending().await
                    });
                    assert!(
                        self.actor_states
                            .try_insert(
                                *actor_id,
                                InflightActorState::start(
                                    *actor_id,
                                    partial_graph_id,
                                    barrier,
                                    tx,
                                    join_handle,
                                    None,
                                )
                            )
                            .is_ok()
                    );
                    new_actors.insert(*actor_id);
                }
            }
        }

        // Note: it's important to issue the barrier to the actors only after issuing it to the
        // graph, to ensure that `start_epoch` is called on the graph before the actors receive
        // the barrier.
        for actor_id in &request.actor_ids_to_collect {
            if new_actors.contains(actor_id) {
                continue;
            }
            self.actor_states
                .get_mut(actor_id)
                .unwrap_or_else(|| {
                    panic!(
                        "should exist: {} {:?}",
                        actor_id, request.actor_ids_to_collect
                    );
                })
                .issue_barrier(partial_graph_id, barrier, is_stop_actor(*actor_id))?;
        }

        Ok(())
    }

    pub(super) fn new_actor_remote_output_request(
        &mut self,
        actor_id: ActorId,
        upstream_actor_id: ActorId,
        result_sender: oneshot::Sender<StreamResult<permit::Receiver>>,
    ) {
        let (tx, rx) = channel_from_config(self.local_barrier_manager.env.config());
        self.new_actor_output_request(actor_id, upstream_actor_id, NewOutputRequest::Remote(tx));
        let _ = result_sender.send(Ok(rx));
    }

    pub(super) fn new_actor_output_request(
        &mut self,
        actor_id: ActorId,
        upstream_actor_id: ActorId,
        request: NewOutputRequest,
    ) {
        if let Some(actor) = self.actor_states.get_mut(&upstream_actor_id) {
            let _ = actor.new_output_request_tx.send((actor_id, request));
        } else {
            self.actor_pending_new_output_requests
                .entry(upstream_actor_id)
                .or_default()
                .push((actor_id, request));
        }
    }

    pub(super) fn poll_next_event(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ManagedBarrierStateEvent> {
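        // Event priority: report actor failures first, then any barrier that has
        // already been collected from all actors, and only then drain newly arrived
        // local barrier events.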
        if let Poll::Ready(option) = self.actor_failure_rx.poll_recv(cx) {
            let (actor_id, err) = option.expect("non-empty when tx in local_barrier_manager");
            return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
        }
        // yield some pending collected epochs
        for (partial_graph_id, graph_state) in &mut self.graph_states {
            if let Some(barrier) = graph_state.may_have_collected_all() {
                return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
                    partial_graph_id: *partial_graph_id,
                    barrier,
                });
            }
        }
        while let Poll::Ready(event) = self.barrier_event_rx.poll_recv(cx) {
            match event.expect("non-empty when tx in local_barrier_manager") {
                LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => {
                    if let Some((partial_graph_id, barrier)) = self.collect(actor_id, epoch) {
                        return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
                            partial_graph_id,
                            barrier,
                        });
                    }
                }
                LocalBarrierEvent::ReportCreateProgress {
                    epoch,
                    actor,
                    state,
                } => {
                    self.update_create_mview_progress(epoch, actor, state);
                }
                LocalBarrierEvent::RegisterBarrierSender {
                    actor_id,
                    barrier_sender,
                } => {
                    if let Err(err) = self.register_barrier_sender(actor_id, barrier_sender) {
                        return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
                    }
                }
                LocalBarrierEvent::RegisterLocalUpstreamOutput {
                    actor_id,
                    upstream_actor_id,
                    tx,
                } => {
                    self.new_actor_output_request(
                        actor_id,
                        upstream_actor_id,
                        NewOutputRequest::Local(tx),
                    );
                }
            }
        }

        debug_assert!(
            self.graph_states
                .values_mut()
                .all(|graph_state| graph_state.may_have_collected_all().is_none())
        );
        Poll::Pending
    }
}

impl DatabaseManagedBarrierState {
    #[must_use]
    pub(super) fn collect(
        &mut self,
        actor_id: ActorId,
        epoch: EpochPair,
    ) -> Option<(PartialGraphId, Barrier)> {
        let (prev_partial_graph_id, is_finished) = self
            .actor_states
            .get_mut(&actor_id)
            .expect("should exist")
            .collect(epoch);
        if is_finished {
            let state = self.actor_states.remove(&actor_id).expect("should exist");
            if let Some(monitor_task_handle) = state.monitor_task_handle {
                monitor_task_handle.abort();
            }
        }
        let prev_graph_state = self
            .graph_states
            .get_mut(&prev_partial_graph_id)
            .expect("should exist");
        prev_graph_state.collect(actor_id, epoch);
        prev_graph_state
            .may_have_collected_all()
            .map(|barrier| (prev_partial_graph_id, barrier))
    }

    pub(super) fn pop_barrier_to_complete(
        &mut self,
        partial_graph_id: PartialGraphId,
        prev_epoch: u64,
    ) -> (
        Barrier,
        Option<HashSet<TableId>>,
        Vec<PbCreateMviewProgress>,
    ) {
        self.graph_states
            .get_mut(&partial_graph_id)
            .expect("should exist")
            .pop_barrier_to_complete(prev_epoch)
    }
}

impl PartialGraphManagedBarrierState {
    /// This method is called whenever the barrier state may have changed, to check whether
    /// the earliest `Issued` barrier has been collected from all of its actors, and if so
    /// to transform its state to `AllCollected` and return the barrier.
    fn may_have_collected_all(&mut self) -> Option<Barrier> {
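        // `epoch_barrier_state_map` is iterated in ascending `prev_epoch` order. Stop at
        // the first barrier still awaiting some actors: since barriers are collected in
        // epoch order per actor, no later barrier can complete before it.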
        for barrier_state in self.epoch_barrier_state_map.values_mut() {
            match &barrier_state.inner {
                ManagedBarrierStateInner::Issued(IssuedState {
                    remaining_actors, ..
                }) if remaining_actors.is_empty() => {}
                ManagedBarrierStateInner::AllCollected(_) => {
                    continue;
                }
                ManagedBarrierStateInner::Issued(_) => {
                    break;
                }
            }

            self.streaming_metrics.barrier_manager_progress.inc();

            let create_mview_progress = self
                .create_mview_progress
                .remove(&barrier_state.barrier.epoch.curr)
                .unwrap_or_default()
                .into_iter()
                .map(|(actor, state)| state.to_pb(actor))
                .collect();

            let prev_state = replace(
                &mut barrier_state.inner,
                ManagedBarrierStateInner::AllCollected(create_mview_progress),
            );

            must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState {
                barrier_inflight_latency: timer,
                ..
            }) => {
                timer.observe_duration();
            });

            return Some(barrier_state.barrier.clone());
        }
        None
    }

    fn pop_barrier_to_complete(
        &mut self,
        prev_epoch: u64,
    ) -> (
        Barrier,
        Option<HashSet<TableId>>,
        Vec<PbCreateMviewProgress>,
    ) {
        let (popped_prev_epoch, barrier_state) = self
            .epoch_barrier_state_map
            .pop_first()
            .expect("should exist");

        assert_eq!(prev_epoch, popped_prev_epoch);

        let create_mview_progress = must_match!(barrier_state.inner, ManagedBarrierStateInner::AllCollected(create_mview_progress) => {
            create_mview_progress
        });
        (
            barrier_state.barrier,
            barrier_state.table_ids,
            create_mview_progress,
        )
    }
}

impl PartialGraphManagedBarrierState {
    /// Collect a `barrier` from the actor with `actor_id`.
    pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) {
        tracing::debug!(
            target: "events::stream::barrier::manager::collect",
            ?epoch, actor_id, state = ?self.epoch_barrier_state_map,
            "collect_barrier",
        );

        match self.epoch_barrier_state_map.get_mut(&epoch.prev) {
            None => {
                // An absent entry means the barrier has not been injected by the barrier
                // manager, or the barrier message is still blocked at the `RemoteInput`
                // side waiting for injection. In either case, no actor can have started
                // collecting it, so reaching here is a bug.
                panic!(
                    "cannot collect new actor barrier {:?} at current state: None",
                    epoch,
                )
            }
            Some(&mut BarrierState {
                ref barrier,
                inner:
                    ManagedBarrierStateInner::Issued(IssuedState {
                        ref mut remaining_actors,
                        ..
                    }),
                ..
            }) => {
                let exist = remaining_actors.remove(&actor_id);
                assert!(
                    exist,
                    "the actor doesn't exist. actor_id: {:?}, curr_epoch: {:?}",
                    actor_id, epoch.curr
                );
                assert_eq!(barrier.epoch.curr, epoch.curr);
            }
            Some(BarrierState { inner, .. }) => {
                panic!(
                    "cannot collect new actor barrier {:?} at current state: {:?}",
                    epoch, inner
                )
            }
        }
    }

    /// When the meta service issues a `send_barrier` request, call this function to transition
    /// to `Issued` and start collecting the barrier from the given actors.
    pub(super) fn transform_to_issued(
        &mut self,
        barrier: &Barrier,
        actor_ids_to_collect: impl IntoIterator<Item = ActorId>,
        table_ids: HashSet<TableId>,
    ) {
        let timer = self
            .streaming_metrics
            .barrier_inflight_latency
            .start_timer();

        if let Some(hummock) = self.state_store.as_hummock() {
            hummock.start_epoch(barrier.epoch.curr, table_ids.clone());
        }

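        // Determine the `table_ids` recorded for this barrier: only `Checkpoint`
        // barriers carry `Some(_)`, holding the table ids tracked up to the previous
        // barrier, while the other kinds only validate and advance
        // `prev_barrier_table_ids`.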
        let table_ids = match barrier.kind {
            BarrierKind::Unspecified => {
                unreachable!()
            }
            BarrierKind::Initial => {
                assert!(
                    self.prev_barrier_table_ids.is_none(),
                    "non-empty table_ids at initial barrier: {:?}",
                    self.prev_barrier_table_ids
                );
                info!(epoch = ?barrier.epoch, "initialize at Initial barrier");
                self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
                None
            }
            BarrierKind::Barrier => {
                if let Some((prev_epoch, prev_table_ids)) = self.prev_barrier_table_ids.as_mut() {
                    assert_eq!(prev_epoch.curr, barrier.epoch.prev);
                    assert_eq!(prev_table_ids, &table_ids);
                    *prev_epoch = barrier.epoch;
                } else {
                    info!(epoch = ?barrier.epoch, "initialize at non-checkpoint barrier");
                    self.prev_barrier_table_ids = Some((barrier.epoch, table_ids));
                }
                None
            }
            BarrierKind::Checkpoint => Some(
                if let Some((prev_epoch, prev_table_ids)) = self
                    .prev_barrier_table_ids
                    .replace((barrier.epoch, table_ids))
                    && prev_epoch.curr == barrier.epoch.prev
                {
                    prev_table_ids
                } else {
                    debug!(epoch = ?barrier.epoch, "reinitialize at Checkpoint barrier");
                    HashSet::new()
                },
            ),
        };

        if let Some(&mut BarrierState { ref inner, .. }) =
            self.epoch_barrier_state_map.get_mut(&barrier.epoch.prev)
        {
            panic!(
                "barrier epochs {:?} state has already been `Issued`. Current state: {:?}",
                barrier.epoch, inner
            );
        };
1194
1195        self.epoch_barrier_state_map.insert(
1196            barrier.epoch.prev,
1197            BarrierState {
1198                barrier: barrier.clone(),
1199                inner: ManagedBarrierStateInner::Issued(IssuedState {
1200                    remaining_actors: BTreeSet::from_iter(actor_ids_to_collect),
1201                    barrier_inflight_latency: timer,
1202                }),
1203                table_ids,
1204            },
1205        );
1206    }
1207
1208    #[cfg(test)]
1209    async fn pop_next_completed_epoch(&mut self) -> u64 {
1210        if let Some(barrier) = self.may_have_collected_all() {
1211            self.pop_barrier_to_complete(barrier.epoch.prev);
1212            return barrier.epoch.prev;
1213        }
1214        pending().await
1215    }
1216}
1217
1218#[cfg(test)]
1219mod tests {
1220    use std::collections::HashSet;
1221
1222    use risingwave_common::util::epoch::test_epoch;
1223
1224    use crate::executor::Barrier;
1225    use crate::task::barrier_manager::managed_state::PartialGraphManagedBarrierState;
1226
1227    #[tokio::test]
1228    async fn test_managed_state_add_actor() {
1229        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1230        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1231        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1232        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1233        let actor_ids_to_collect1 = HashSet::from([1, 2]);
1234        let actor_ids_to_collect2 = HashSet::from([1, 2]);
1235        let actor_ids_to_collect3 = HashSet::from([1, 2, 3]);
1236        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1237        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1238        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1239        managed_barrier_state.collect(1, barrier1.epoch);
1240        managed_barrier_state.collect(2, barrier1.epoch);
1241        assert_eq!(
1242            managed_barrier_state.pop_next_completed_epoch().await,
1243            test_epoch(0)
1244        );
1245        assert_eq!(
1246            managed_barrier_state
1247                .epoch_barrier_state_map
1248                .first_key_value()
1249                .unwrap()
1250                .0,
1251            &test_epoch(1)
1252        );
1253        managed_barrier_state.collect(1, barrier2.epoch);
1254        managed_barrier_state.collect(1, barrier3.epoch);
1255        managed_barrier_state.collect(2, barrier2.epoch);
1256        assert_eq!(
1257            managed_barrier_state.pop_next_completed_epoch().await,
1258            test_epoch(1)
1259        );
1260        assert_eq!(
1261            managed_barrier_state
1262                .epoch_barrier_state_map
1263                .first_key_value()
1264                .unwrap()
1265                .0,
1266            &test_epoch(2)
1267        );
1268        managed_barrier_state.collect(2, barrier3.epoch);
1269        managed_barrier_state.collect(3, barrier3.epoch);
1270        assert_eq!(
1271            managed_barrier_state.pop_next_completed_epoch().await,
1272            test_epoch(2)
1273        );
1274        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1275    }
1276
1277    #[tokio::test]
1278    async fn test_managed_state_stop_actor() {
1279        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
1280        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
1281        let barrier2 = Barrier::new_test_barrier(test_epoch(2));
1282        let barrier3 = Barrier::new_test_barrier(test_epoch(3));
1283        let actor_ids_to_collect1 = HashSet::from([1, 2, 3, 4]);
1284        let actor_ids_to_collect2 = HashSet::from([1, 2, 3]);
1285        let actor_ids_to_collect3 = HashSet::from([1, 2]);
1286        managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new());
1287        managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new());
1288        managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new());
1289
1290        managed_barrier_state.collect(1, barrier1.epoch);
1291        managed_barrier_state.collect(1, barrier2.epoch);
1292        managed_barrier_state.collect(1, barrier3.epoch);
1293        managed_barrier_state.collect(2, barrier1.epoch);
1294        managed_barrier_state.collect(2, barrier2.epoch);
1295        managed_barrier_state.collect(2, barrier3.epoch);
1296        assert_eq!(
1297            managed_barrier_state
1298                .epoch_barrier_state_map
1299                .first_key_value()
1300                .unwrap()
1301                .0,
1302            &0
1303        );
1304        managed_barrier_state.collect(3, barrier1.epoch);
1305        managed_barrier_state.collect(3, barrier2.epoch);
1306        assert_eq!(
1307            managed_barrier_state
1308                .epoch_barrier_state_map
1309                .first_key_value()
1310                .unwrap()
1311                .0,
1312            &0
1313        );
1314        managed_barrier_state.collect(4, barrier1.epoch);
1315        assert_eq!(
1316            managed_barrier_state.pop_next_completed_epoch().await,
1317            test_epoch(0)
1318        );
1319        assert_eq!(
1320            managed_barrier_state.pop_next_completed_epoch().await,
1321            test_epoch(1)
1322        );
1323        assert_eq!(
1324            managed_barrier_state.pop_next_completed_epoch().await,
1325            test_epoch(2)
1326        );
1327        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
1328    }
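
    // A minimal sketch of the single-actor happy path, assuming the same test helpers
    // as above (`for_test`, `new_test_barrier`, `test_epoch`): one issued barrier
    // completes as soon as its only actor reports collection.
    #[tokio::test]
    async fn test_managed_state_single_actor() {
        let mut managed_barrier_state = PartialGraphManagedBarrierState::for_test();
        let barrier1 = Barrier::new_test_barrier(test_epoch(1));
        managed_barrier_state.transform_to_issued(&barrier1, HashSet::from([1]), HashSet::new());
        // Nothing can complete before actor 1 reports the barrier.
        managed_barrier_state.collect(1, barrier1.epoch);
        // `pop_next_completed_epoch` returns the completed barrier's `prev_epoch`.
        assert_eq!(
            managed_barrier_state.pop_next_completed_epoch().await,
            test_epoch(0)
        );
        assert!(managed_barrier_state.epoch_barrier_state_map.is_empty());
    }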
}