risingwave_stream/task/barrier_worker/
mod.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::future::{pending, poll_fn};
use std::sync::Arc;
use std::task::Poll;

use anyhow::anyhow;
use await_tree::{InstrumentAwait, SpanExt};
use futures::future::{BoxFuture, join_all};
use futures::stream::{BoxStream, FuturesOrdered};
use futures::{FutureExt, StreamExt, TryFutureExt};
use itertools::Itertools;
use risingwave_pb::stream_plan::barrier::BarrierKind;
use risingwave_pb::stream_service::barrier_complete_response::{
    PbCdcTableBackfillProgress, PbCreateMviewProgress, PbLocalSstableInfo,
};
use risingwave_rpc_client::error::{ToTonicStatus, TonicStatusWrapper};
use risingwave_storage::store_impl::AsHummock;
use thiserror_ext::AsReport;
use tokio::select;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tonic::{Code, Status};
use tracing::{debug, error, info, warn};

use self::managed_state::ManagedBarrierState;
use crate::error::{ScoredStreamError, StreamError, StreamResult};
#[cfg(test)]
use crate::task::LocalBarrierManager;
use crate::task::managed_state::BarrierToComplete;
use crate::task::{
    ActorId, AtomicU64Ref, PartialGraphId, StreamActorManager, StreamEnvironment, UpDownActorIds,
};
pub mod managed_state;
#[cfg(test)]
mod tests;

use risingwave_hummock_sdk::table_stats::to_prost_table_stats_map;
use risingwave_hummock_sdk::{LocalSstableInfo, SyncResult};
use risingwave_pb::stream_service::streaming_control_stream_request::{
    InitRequest, Request, ResetDatabaseRequest,
};
use risingwave_pb::stream_service::streaming_control_stream_response::{
    InitResponse, ReportDatabaseFailureResponse, ResetDatabaseResponse, Response, ShutdownResponse,
};
use risingwave_pb::stream_service::{
    BarrierCompleteResponse, InjectBarrierRequest, PbScoredError, StreamingControlStreamRequest,
    StreamingControlStreamResponse, streaming_control_stream_response,
};

use crate::executor::Barrier;
use crate::executor::exchange::permit::Receiver;
use crate::executor::monitor::StreamingMetrics;
use crate::task::barrier_worker::managed_state::{
    DatabaseManagedBarrierState, DatabaseStatus, ManagedBarrierStateDebugInfo,
    ManagedBarrierStateEvent, PartialGraphManagedBarrierState, ResetDatabaseOutput,
};

/// If enabled, all actors will be grouped in the same tracing span within one epoch.
/// Note that this option will significantly increase the overhead of tracing.
pub const ENABLE_BARRIER_AGGREGATION: bool = false;

/// The collected result of a barrier on the current compute node. Reported to the meta service in [`LocalBarrierWorker::on_epoch_completed`].
#[derive(Debug)]
pub struct BarrierCompleteResult {
    /// The result returned from `sync` of `StateStore`.
    pub sync_result: Option<SyncResult>,

    /// The updated creation progress of materialized views after this barrier.
    pub create_mview_progress: Vec<PbCreateMviewProgress>,

    /// The source IDs that have finished listing data for refreshable batch sources.
    pub list_finished_source_ids: Vec<u32>,

    /// The source IDs that have finished loading data for refreshable batch sources.
    pub load_finished_source_ids: Vec<u32>,

    /// The backfill progress of CDC tables after this barrier.
    pub cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,

    /// The table IDs that should be truncated.
    pub truncate_tables: Vec<u32>,
    /// The table IDs that have finished refreshing.
    pub refresh_finished_tables: Vec<u32>,
}

/// Lives in [`crate::task::barrier_worker::LocalBarrierWorker`] and communicates with the
/// `ControlStreamManager` in meta.
/// Handles [`risingwave_pb::stream_service::streaming_control_stream_request::Request`].
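///
/// The handle pairs the response sender with the request stream of one bidirectional
/// control stream. A minimal wiring sketch, mirroring `barrier_test_utils` at the
/// bottom of this file (`ignore`d here since the involved types are crate-private):
///
/// ```ignore
/// use futures::StreamExt;
/// use tokio_stream::wrappers::UnboundedReceiverStream;
///
/// let (response_tx, _response_rx) = tokio::sync::mpsc::unbounded_channel();
/// let (_request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
/// let handle = ControlStreamHandle::new(
///     response_tx,
///     UnboundedReceiverStream::new(request_rx).boxed(),
/// );
/// assert!(handle.connected());
/// ```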
pub(super) struct ControlStreamHandle {
    #[expect(clippy::type_complexity)]
    pair: Option<(
        UnboundedSender<Result<StreamingControlStreamResponse, Status>>,
        BoxStream<'static, Result<StreamingControlStreamRequest, Status>>,
    )>,
}

impl ControlStreamHandle {
    fn empty() -> Self {
        Self { pair: None }
    }

    pub(super) fn new(
        sender: UnboundedSender<Result<StreamingControlStreamResponse, Status>>,
        request_stream: BoxStream<'static, Result<StreamingControlStreamRequest, Status>>,
    ) -> Self {
        Self {
            pair: Some((sender, request_stream)),
        }
    }

    pub(super) fn connected(&self) -> bool {
        self.pair.is_some()
    }

    fn reset_stream_with_err(&mut self, err: Status) {
        if let Some((sender, _)) = self.pair.take() {
            // Note: `TonicStatusWrapper` provides a better error report.
            let err = TonicStatusWrapper::new(err);
            warn!(error = %err.as_report(), "control stream reset with error");

            let err = err.into_inner();
            if sender.send(Err(err)).is_err() {
                warn!("failed to notify reset of control stream");
            }
        }
    }

    /// Send a `Shutdown` message to the control stream and wait for the stream to be closed
    /// by the meta service.
    async fn shutdown_stream(&mut self) {
        if let Some((sender, _)) = self.pair.take() {
            if sender
                .send(Ok(StreamingControlStreamResponse {
                    response: Some(streaming_control_stream_response::Response::Shutdown(
                        ShutdownResponse::default(),
                    )),
                }))
                .is_err()
            {
                warn!("failed to notify shutdown of control stream");
            } else {
                tracing::info!("waiting for meta service to close control stream...");

                // Wait for the stream to be closed, to ensure that the `Shutdown` message has
                // been acknowledged by the meta service, for a more precise error report.
                //
                // This is because the meta service will reset the control stream manager and
                // drop the connection to us upon recovery. As a result, the receiver part of
                // this sender will also be dropped, causing the stream to close.
                sender.closed().await;
            }
        } else {
            debug!("control stream has been reset, ignore shutdown");
        }
    }
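
    // `sender.closed()` above resolves once every receiver handle has been dropped. A
    // self-contained illustration of that semantics (a sketch, not part of this module):
    //
    // ```
    // async fn closed_demo() {
    //     let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>();
    //     drop(rx);
    //     // Resolves immediately: the only receiver is already gone.
    //     tx.closed().await;
    // }
    // ```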

    pub(super) fn ack_reset_database(
        &mut self,
        database_id: DatabaseId,
        root_err: Option<ScoredStreamError>,
        reset_request_id: u32,
    ) {
        self.send_response(Response::ResetDatabase(ResetDatabaseResponse {
            database_id: database_id.database_id,
            root_err: root_err.map(|err| PbScoredError {
                err_msg: err.error.to_report_string(),
                score: err.score.0,
            }),
            reset_request_id,
        }));
    }

    fn send_response(&mut self, response: streaming_control_stream_response::Response) {
        if let Some((sender, _)) = self.pair.as_ref() {
            if sender
                .send(Ok(StreamingControlStreamResponse {
                    response: Some(response),
                }))
                .is_err()
            {
                self.pair = None;
                warn!("failed to send response; control stream reset");
            }
        } else {
            debug!(?response, "control stream has been reset, ignore response");
        }
    }

    async fn next_request(&mut self) -> StreamingControlStreamRequest {
        if let Some((_, stream)) = &mut self.pair {
            match stream.next().await {
                Some(Ok(request)) => {
                    return request;
                }
                Some(Err(e)) => self.reset_stream_with_err(
                    anyhow!(TonicStatusWrapper::new(e)) // wrap the status to provide a better error report
                        .context("failed to get request")
                        .to_status_unnamed(Code::Internal),
                ),
                None => self.reset_stream_with_err(Status::internal("end of stream")),
            }
        }
        pending().await
    }
}

/// Sent from [`crate::task::stream_manager::LocalStreamManager`] to [`crate::task::barrier_worker::LocalBarrierWorker::run`].
///
/// See [`crate::task`] for architecture overview.
#[derive(strum_macros::Display)]
pub(super) enum LocalActorOperation {
    NewControlStream {
        handle: ControlStreamHandle,
        init_request: InitRequest,
    },
    TakeReceiver {
        database_id: DatabaseId,
        term_id: String,
        ids: UpDownActorIds,
        result_sender: oneshot::Sender<StreamResult<Receiver>>,
    },
    #[cfg(test)]
    GetCurrentLocalBarrierManager(oneshot::Sender<LocalBarrierManager>),
    #[cfg(test)]
    TakePendingNewOutputRequest(ActorId, oneshot::Sender<Vec<(ActorId, NewOutputRequest)>>),
    #[cfg(test)]
    Flush(oneshot::Sender<()>),
    InspectState {
        result_sender: oneshot::Sender<String>,
    },
    Shutdown {
        result_sender: oneshot::Sender<()>,
    },
}
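
// For illustration: an operation is delivered by sending it on the worker's event channel
// and awaiting the packed `oneshot` reply. A hedged sketch (the `actor_op_tx` handle here
// stands in for the `EventSender<LocalActorOperation>` owned by `LocalStreamManager`):
//
// ```ignore
// let (result_sender, rx) = tokio::sync::oneshot::channel();
// actor_op_tx.send_event(LocalActorOperation::InspectState { result_sender });
// let state_dump: String = rx.await?;
// ```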

pub(super) struct LocalBarrierWorkerDebugInfo<'a> {
    managed_barrier_state: HashMap<DatabaseId, (String, Option<ManagedBarrierStateDebugInfo<'a>>)>,
    has_control_stream_connected: bool,
}

impl Display for LocalBarrierWorkerDebugInfo<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(
            f,
            "\nhas_control_stream_connected: {}",
            self.has_control_stream_connected
        )?;

        for (database_id, (status, managed_barrier_state)) in &self.managed_barrier_state {
            writeln!(
                f,
                "database {} status: {} managed_barrier_state:\n{}",
                database_id.database_id,
                status,
                managed_barrier_state
                    .as_ref()
                    .map(ToString::to_string)
                    .unwrap_or_default()
            )?;
        }
        Ok(())
    }
}

/// [`LocalBarrierWorker`] manages barrier control flow.
/// Specifically, [`LocalBarrierWorker`] serves barrier injection from the meta server, sends the
/// barriers to and collects them from all actors, and finally reports the progress.
///
/// Runs the event loop in [`Self::run`]. Handles events sent by [`crate::task::LocalStreamManager`].
///
/// See [`crate::task`] for architecture overview.
pub(super) struct LocalBarrierWorker {
    /// Current barrier collection state.
    pub(super) state: ManagedBarrierState,

    /// Barrier-completion futures, which finish in ascending order of epoch.
    await_epoch_completed_futures: HashMap<DatabaseId, FuturesOrdered<AwaitEpochCompletedFuture>>,

    control_stream_handle: ControlStreamHandle,

    pub(super) actor_manager: Arc<StreamActorManager>,

    pub(super) term_id: String,
}

impl LocalBarrierWorker {
    pub(super) fn new(actor_manager: Arc<StreamActorManager>, term_id: String) -> Self {
        Self {
            state: Default::default(),
            await_epoch_completed_futures: Default::default(),
            control_stream_handle: ControlStreamHandle::empty(),
            actor_manager,
            term_id,
        }
    }

    fn to_debug_info(&self) -> LocalBarrierWorkerDebugInfo<'_> {
        LocalBarrierWorkerDebugInfo {
            managed_barrier_state: self
                .state
                .databases
                .iter()
                .map(|(database_id, status)| {
                    (*database_id, {
                        match status {
                            DatabaseStatus::ReceivedExchangeRequest(_) => {
                                ("ReceivedExchangeRequest".to_owned(), None)
                            }
                            DatabaseStatus::Running(state) => {
                                ("running".to_owned(), Some(state.to_debug_info()))
                            }
                            DatabaseStatus::Suspended(state) => {
                                (format!("suspended: {:?}", state.suspend_time), None)
                            }
                            DatabaseStatus::Resetting(_) => ("resetting".to_owned(), None),
                            DatabaseStatus::Unspecified => {
                                unreachable!()
                            }
                        }
                    })
                })
                .collect(),
            has_control_stream_connected: self.control_stream_handle.connected(),
        }
    }

    async fn next_completed_epoch(
        futures: &mut HashMap<DatabaseId, FuturesOrdered<AwaitEpochCompletedFuture>>,
    ) -> (
        DatabaseId,
        PartialGraphId,
        Barrier,
        StreamResult<BarrierCompleteResult>,
    ) {
        poll_fn(|cx| {
            for (database_id, futures) in &mut *futures {
                if let Poll::Ready(Some((partial_graph_id, barrier, result))) =
                    futures.poll_next_unpin(cx)
                {
                    return Poll::Ready((*database_id, partial_graph_id, barrier, result));
                }
            }
            Poll::Pending
        })
        .await
    }
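
    // `next_completed_epoch` multiplexes the per-database `FuturesOrdered` queues into a
    // single await point: each queue yields its results in epoch order, and `poll_fn`
    // scans the map so whichever database completes first wins. A self-contained sketch
    // of the same pattern (illustrative only; names are hypothetical):
    //
    // ```
    // use std::collections::HashMap;
    // use std::future::{poll_fn, Ready};
    // use std::task::Poll;
    //
    // use futures::stream::{FuturesOrdered, StreamExt};
    //
    // async fn next_ready(queues: &mut HashMap<u32, FuturesOrdered<Ready<u64>>>) -> (u32, u64) {
    //     poll_fn(|cx| {
    //         for (id, queue) in &mut *queues {
    //             if let Poll::Ready(Some(v)) = queue.poll_next_unpin(cx) {
    //                 return Poll::Ready((*id, v));
    //             }
    //         }
    //         Poll::Pending
    //     })
    //     .await
    // }
    // ```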

    async fn run(mut self, mut actor_op_rx: UnboundedReceiver<LocalActorOperation>) {
        loop {
            select! {
                biased;
                (database_id, event) = self.state.next_event() => {
                    match event {
                        ManagedBarrierStateEvent::BarrierCollected{
                            partial_graph_id,
                            barrier,
                        } => {
                            // Push a completion future into `await_epoch_completed_futures`;
                            // its result is handled below via `next_completed_epoch`.
                            self.complete_barrier(database_id, partial_graph_id, barrier.epoch.prev);
                        }
                        ManagedBarrierStateEvent::ActorError{
                            actor_id,
                            err,
                        } => {
                            self.on_database_failure(database_id, Some(actor_id), err, "recv actor failure");
                        }
                        ManagedBarrierStateEvent::DatabaseReset(output, reset_request_id) => {
                            self.ack_database_reset(database_id, Some(output), reset_request_id);
                        }
                    }
                }
                (database_id, partial_graph_id, barrier, result) = Self::next_completed_epoch(&mut self.await_epoch_completed_futures) => {
                    match result {
                        Ok(result) => {
                            self.on_epoch_completed(database_id, partial_graph_id, barrier.epoch.prev, result);
                        }
                        Err(err) => {
                            // TODO: report this as a database failure instead of resetting the stream,
                            // once the HummockUploader supports partial recovery. Currently the
                            // HummockUploader enters an `Err` state and stops working until a global
                            // recovery clears the uploader.
                            self.control_stream_handle.reset_stream_with_err(Status::internal(format!("failed to complete epoch: {} {} {:?} {:?}", database_id, partial_graph_id.0, barrier.epoch, err.as_report())));
                        }
                    }
                },
                actor_op = actor_op_rx.recv() => {
                    if let Some(actor_op) = actor_op {
                        match actor_op {
                            LocalActorOperation::NewControlStream { handle, init_request } => {
                                self.control_stream_handle.reset_stream_with_err(Status::internal("control stream has been reset to a new one"));
                                self.reset(init_request).await;
                                self.control_stream_handle = handle;
                                self.control_stream_handle.send_response(streaming_control_stream_response::Response::Init(InitResponse {}));
                            }
                            LocalActorOperation::Shutdown { result_sender } => {
                                if self.state.databases.values().any(|database| {
                                    match database {
                                        DatabaseStatus::Running(database) => {
                                            !database.actor_states.is_empty()
                                        }
                                        DatabaseStatus::Suspended(_) | DatabaseStatus::Resetting(_) |
                                            DatabaseStatus::ReceivedExchangeRequest(_) => {
                                            false
                                        }
                                        DatabaseStatus::Unspecified => {
                                            unreachable!()
                                        }
                                    }
                                }) {
                                    tracing::warn!(
                                        "shutdown with running actors, scaling or migration will be triggered"
                                    );
                                }
                                self.control_stream_handle.shutdown_stream().await;
                                let _ = result_sender.send(());
                            }
                            actor_op => {
                                self.handle_actor_op(actor_op);
                            }
                        }
                    } else {
                        break;
                    }
                },
                request = self.control_stream_handle.next_request() => {
                    let result = self.handle_streaming_control_request(request.request.expect("non empty"));
                    if let Err((database_id, err)) = result {
                        self.on_database_failure(database_id, None, err, "failed to inject barrier");
                    }
                },
            }
        }
    }

    fn handle_streaming_control_request(
        &mut self,
        request: Request,
    ) -> Result<(), (DatabaseId, StreamError)> {
        match request {
            Request::InjectBarrier(req) => {
                let database_id = DatabaseId::new(req.database_id);
                let result: StreamResult<()> = try {
                    let barrier = Barrier::from_protobuf(req.get_barrier().unwrap())?;
                    self.send_barrier(&barrier, req)?;
                };
                result.map_err(|e| (database_id, e))?;
                Ok(())
            }
            Request::RemovePartialGraph(req) => {
                self.remove_partial_graphs(
                    DatabaseId::new(req.database_id),
                    req.partial_graph_ids.into_iter().map(PartialGraphId::new),
                );
                Ok(())
            }
            Request::CreatePartialGraph(req) => {
                self.add_partial_graph(
                    DatabaseId::new(req.database_id),
                    PartialGraphId::new(req.partial_graph_id),
                );
                Ok(())
            }
            Request::ResetDatabase(req) => {
                self.reset_database(req);
                Ok(())
            }
            Request::Init(_) => {
                unreachable!()
            }
        }
    }
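
    // The `try` block in `Request::InjectBarrier` (unstable `try_blocks` feature) scopes
    // the `?` operator so that both protobuf decoding and barrier sending map their errors
    // to `(database_id, err)` in one place. A stable-Rust sketch of the same shape, with
    // hypothetical types standing in for the real ones:
    //
    // ```
    // fn inject(database_id: u32) -> Result<(), (u32, String)> {
    //     let result: Result<(), String> = (|| {
    //         let _barrier = "decoded barrier"; // stand-in for `Barrier::from_protobuf(..)?`
    //         Ok(()) // stand-in for `self.send_barrier(..)?`
    //     })();
    //     result.map_err(|e| (database_id, e))
    // }
    // ```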

    fn handle_actor_op(&mut self, actor_op: LocalActorOperation) {
        match actor_op {
            LocalActorOperation::NewControlStream { .. } | LocalActorOperation::Shutdown { .. } => {
                unreachable!("event {actor_op} should be handled separately in async context")
            }
            LocalActorOperation::TakeReceiver {
                database_id,
                term_id,
                ids,
                result_sender,
            } => {
                let err = if self.term_id != term_id {
                    warn!(
                        ?ids,
                        term_id,
                        current_term_id = self.term_id,
                        "take receiver on unmatched term_id"
                    );
                    anyhow!(
                        "take receiver {:?} on unmatched term_id {} to current term_id {}",
                        ids,
                        term_id,
                        self.term_id
                    )
                } else {
                    match self.state.databases.entry(database_id) {
                        Entry::Occupied(mut entry) => match entry.get_mut() {
                            DatabaseStatus::ReceivedExchangeRequest(pending_requests) => {
                                pending_requests.push((ids, result_sender));
                                return;
                            }
                            DatabaseStatus::Running(database) => {
                                let (upstream_actor_id, actor_id) = ids;
                                database.new_actor_remote_output_request(
                                    actor_id,
                                    upstream_actor_id,
                                    result_sender,
                                );
                                return;
                            }
                            DatabaseStatus::Suspended(_) => {
                                anyhow!("database suspended")
                            }
                            DatabaseStatus::Resetting(_) => {
                                anyhow!("database resetting")
                            }
                            DatabaseStatus::Unspecified => {
                                unreachable!()
                            }
                        },
                        Entry::Vacant(entry) => {
                            entry.insert(DatabaseStatus::ReceivedExchangeRequest(vec![(
                                ids,
                                result_sender,
                            )]));
                            return;
                        }
                    }
                };
                let _ = result_sender.send(Err(err.into()));
            }
            #[cfg(test)]
            LocalActorOperation::GetCurrentLocalBarrierManager(sender) => {
                let database_status = self
                    .state
                    .databases
                    .get(&crate::task::TEST_DATABASE_ID)
                    .unwrap();
                let database_state = risingwave_common::must_match!(database_status, DatabaseStatus::Running(database_state) => database_state);
                let _ = sender.send(database_state.local_barrier_manager.clone());
            }
            #[cfg(test)]
            LocalActorOperation::TakePendingNewOutputRequest(actor_id, sender) => {
                let database_status = self
                    .state
                    .databases
                    .get_mut(&crate::task::TEST_DATABASE_ID)
                    .unwrap();

                let database_state = risingwave_common::must_match!(database_status, DatabaseStatus::Running(database_state) => database_state);
                assert!(!database_state.actor_states.contains_key(&actor_id));
                let requests = database_state
                    .actor_pending_new_output_requests
                    .remove(&actor_id)
                    .unwrap();
                let _ = sender.send(requests);
            }
            #[cfg(test)]
            LocalActorOperation::Flush(sender) => {
                use futures::FutureExt;
                while let Some(request) = self.control_stream_handle.next_request().now_or_never() {
                    self.handle_streaming_control_request(
                        request.request.expect("should not be empty"),
                    )
                    .unwrap();
                }
                while let Some((database_id, event)) = self.state.next_event().now_or_never() {
                    match event {
                        ManagedBarrierStateEvent::BarrierCollected {
                            partial_graph_id,
                            barrier,
                        } => {
                            self.complete_barrier(
                                database_id,
                                partial_graph_id,
                                barrier.epoch.prev,
                            );
                        }
                        ManagedBarrierStateEvent::ActorError { .. }
                        | ManagedBarrierStateEvent::DatabaseReset(..) => {
                            unreachable!()
                        }
                    }
                }
                sender.send(()).unwrap()
            }
            LocalActorOperation::InspectState { result_sender } => {
                let debug_info = self.to_debug_info();
                let _ = result_sender.send(debug_info.to_string());
            }
        }
    }
}

mod await_epoch_completed_future {
    use std::future::Future;

    use futures::FutureExt;
    use futures::future::BoxFuture;
    use risingwave_hummock_sdk::SyncResult;
    use risingwave_pb::stream_service::barrier_complete_response::{
        PbCdcTableBackfillProgress, PbCreateMviewProgress,
    };

    use crate::error::StreamResult;
    use crate::executor::Barrier;
    use crate::task::{BarrierCompleteResult, PartialGraphId, await_tree_key};

    pub(super) type AwaitEpochCompletedFuture = impl Future<Output = (PartialGraphId, Barrier, StreamResult<BarrierCompleteResult>)>
        + 'static;

    #[define_opaque(AwaitEpochCompletedFuture)]
    #[expect(clippy::too_many_arguments)]
    pub(super) fn instrument_complete_barrier_future(
        partial_graph_id: PartialGraphId,
        complete_barrier_future: Option<BoxFuture<'static, StreamResult<SyncResult>>>,
        barrier: Barrier,
        barrier_await_tree_reg: Option<&await_tree::Registry>,
        create_mview_progress: Vec<PbCreateMviewProgress>,
        list_finished_source_ids: Vec<u32>,
        load_finished_source_ids: Vec<u32>,
        cdc_table_backfill_progress: Vec<PbCdcTableBackfillProgress>,
        truncate_tables: Vec<u32>,
        refresh_finished_tables: Vec<u32>,
    ) -> AwaitEpochCompletedFuture {
        let prev_epoch = barrier.epoch.prev;
        let future = async move {
            if let Some(future) = complete_barrier_future {
                let result = future.await;
                result.map(Some)
            } else {
                Ok(None)
            }
        }
        .map(move |result| {
            (
                partial_graph_id,
                barrier,
                result.map(|sync_result| BarrierCompleteResult {
                    sync_result,
                    create_mview_progress,
                    list_finished_source_ids,
                    load_finished_source_ids,
                    cdc_table_backfill_progress,
                    truncate_tables,
                    refresh_finished_tables,
                }),
            )
        });
        if let Some(reg) = barrier_await_tree_reg {
            reg.register(
                await_tree_key::BarrierAwait { prev_epoch },
                format!("SyncEpoch({})", prev_epoch),
            )
            .instrument(future)
            .left_future()
        } else {
            future.right_future()
        }
    }
}
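
// `instrument_complete_barrier_future` returns one opaque future type from two branches by
// wrapping them in `futures::future::Either` via `left_future`/`right_future`. A
// self-contained sketch of that pattern (illustrative; `instrument` is replaced here by a
// plain `inspect`):
//
// ```
// use futures::FutureExt;
//
// async fn run(traced: bool) -> u32 {
//     let fut = async { 42 };
//     let fut = if traced {
//         fut.inspect(|v| println!("finished with {v}")).left_future()
//     } else {
//         fut.right_future()
//     };
//     fut.await
// }
// ```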

use await_epoch_completed_future::*;
use risingwave_common::catalog::{DatabaseId, TableId};
use risingwave_pb::hummock::vector_index_delta::PbVectorIndexAdds;
use risingwave_storage::{StateStoreImpl, dispatch_state_store};

use crate::executor::exchange::permit;

fn sync_epoch(
    state_store: &StateStoreImpl,
    streaming_metrics: &StreamingMetrics,
    prev_epoch: u64,
    table_ids: HashSet<TableId>,
) -> BoxFuture<'static, StreamResult<SyncResult>> {
    let timer = streaming_metrics.barrier_sync_latency.start_timer();

    let state_store = state_store.clone();
    let future = async move {
        dispatch_state_store!(state_store, hummock, {
            hummock.sync(vec![(prev_epoch, table_ids)]).await
        })
    };

    future
        .instrument_await(await_tree::span!("sync_epoch (epoch {})", prev_epoch))
        .inspect_ok(move |_| {
            timer.observe_duration();
        })
        .map_err(move |e| {
            tracing::error!(
                prev_epoch,
                error = %e.as_report(),
                "Failed to sync state store",
            );
            e.into()
        })
        .boxed()
}
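
// The combinator chain above attaches metrics and error logging without touching the core
// sync future: `inspect_ok` observes the latency timer on success, `map_err` logs and
// converts the error, and `boxed` erases the resulting type. A self-contained sketch of
// the same shape (names hypothetical):
//
// ```
// use futures::future::BoxFuture;
// use futures::{FutureExt, TryFutureExt};
//
// fn timed_sync(
//     sync: impl std::future::Future<Output = Result<u64, String>> + Send + 'static,
// ) -> BoxFuture<'static, Result<u64, String>> {
//     let start = std::time::Instant::now();
//     sync.inspect_ok(move |_| println!("synced in {:?}", start.elapsed()))
//         .map_err(|e| format!("sync failed: {e}"))
//         .boxed()
// }
// ```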

impl LocalBarrierWorker {
    fn complete_barrier(
        &mut self,
        database_id: DatabaseId,
        partial_graph_id: PartialGraphId,
        prev_epoch: u64,
    ) {
        {
            let Some(database_state) = self
                .state
                .databases
                .get_mut(&database_id)
                .expect("should exist")
                .state_for_request()
            else {
                return;
            };
            let BarrierToComplete {
                barrier,
                table_ids,
                create_mview_progress,
                list_finished_source_ids,
                load_finished_source_ids,
                cdc_table_backfill_progress,
                truncate_tables,
                refresh_finished_tables,
            } = database_state.pop_barrier_to_complete(partial_graph_id, prev_epoch);

            let complete_barrier_future = match &barrier.kind {
                BarrierKind::Unspecified => unreachable!(),
                BarrierKind::Initial => {
                    tracing::info!(
                        epoch = prev_epoch,
                        "ignore sealing data for the first barrier"
                    );
                    None
                }
                BarrierKind::Barrier => None,
                BarrierKind::Checkpoint => Some(sync_epoch(
                    &self.actor_manager.env.state_store(),
                    &self.actor_manager.streaming_metrics,
                    prev_epoch,
                    table_ids.expect("should be Some on BarrierKind::Checkpoint"),
                )),
            };

            self.await_epoch_completed_futures
                .entry(database_id)
                .or_default()
                .push_back({
                    instrument_complete_barrier_future(
                        partial_graph_id,
                        complete_barrier_future,
                        barrier,
                        self.actor_manager.await_tree_reg.as_ref(),
                        create_mview_progress,
                        list_finished_source_ids,
                        load_finished_source_ids,
                        cdc_table_backfill_progress,
                        truncate_tables,
                        refresh_finished_tables,
                    )
                });
        }
    }

    fn on_epoch_completed(
        &mut self,
        database_id: DatabaseId,
        partial_graph_id: PartialGraphId,
        epoch: u64,
        result: BarrierCompleteResult,
    ) {
        let BarrierCompleteResult {
            create_mview_progress,
            sync_result,
            list_finished_source_ids,
            load_finished_source_ids,
            cdc_table_backfill_progress,
            truncate_tables,
            refresh_finished_tables,
        } = result;

        let (synced_sstables, table_watermarks, old_value_ssts, vector_index_adds) = sync_result
            .map(|sync_result| {
                (
                    sync_result.uncommitted_ssts,
                    sync_result.table_watermarks,
                    sync_result.old_value_ssts,
                    sync_result.vector_index_adds,
                )
            })
            .unwrap_or_default();

        let result = streaming_control_stream_response::Response::CompleteBarrier(
            BarrierCompleteResponse {
                request_id: "todo".to_owned(),
                partial_graph_id: partial_graph_id.into(),
                epoch,
                status: None,
                create_mview_progress,
                synced_sstables: synced_sstables
                    .into_iter()
                    .map(
                        |LocalSstableInfo {
                             sst_info,
                             table_stats,
                             created_at,
                         }| PbLocalSstableInfo {
                            sst: Some(sst_info.into()),
                            table_stats_map: to_prost_table_stats_map(table_stats),
                            created_at,
                        },
                    )
                    .collect_vec(),
                worker_id: self.actor_manager.env.worker_id(),
                table_watermarks: table_watermarks
                    .into_iter()
                    .map(|(key, value)| (key.table_id, value.into()))
                    .collect(),
                old_value_sstables: old_value_ssts
                    .into_iter()
                    .map(|sst| sst.sst_info.into())
                    .collect(),
                database_id: database_id.database_id,
                list_finished_source_ids,
                load_finished_source_ids,
                vector_index_adds: vector_index_adds
                    .into_iter()
                    .map(|(table_id, adds)| {
                        (
                            table_id.table_id,
                            PbVectorIndexAdds {
                                adds: adds.into_iter().map(|add| add.into()).collect(),
                            },
                        )
                    })
                    .collect(),
                cdc_table_backfill_progress,
                truncate_tables,
                refresh_finished_tables,
            },
        );

        self.control_stream_handle.send_response(result);
    }

    /// Broadcast a barrier to all senders. In managed mode, save a receiver that will be
    /// notified when this barrier finishes.
    ///
    /// Note that the error returned here is typically a [`StreamError::barrier_send`], which is not
    /// the root cause of the failure. The caller should then call `try_find_root_failure`
    /// to find the root cause.
    fn send_barrier(
        &mut self,
        barrier: &Barrier,
        request: InjectBarrierRequest,
    ) -> StreamResult<()> {
        debug!(
            target: "events::stream::barrier::manager::send",
            "send barrier {:?}, actor_ids_to_collect = {:?}",
            barrier,
            request.actor_ids_to_collect
        );

        let database_status = self
            .state
            .databases
            .get_mut(&DatabaseId::new(request.database_id))
            .expect("should exist");
        if let Some(state) = database_status.state_for_request() {
            state.transform_to_issued(barrier, request)?;
        }
        Ok(())
    }

    fn remove_partial_graphs(
        &mut self,
        database_id: DatabaseId,
        partial_graph_ids: impl Iterator<Item = PartialGraphId>,
    ) {
        let Some(database_status) = self.state.databases.get_mut(&database_id) else {
            warn!(
                database_id = database_id.database_id,
                "database to remove partial graphs from does not exist"
            );
            return;
        };
        let Some(database_state) = database_status.state_for_request() else {
            warn!(
                database_id = database_id.database_id,
                "ignore remove partial graph request on errored database",
            );
            return;
        };
        for partial_graph_id in partial_graph_ids {
            if let Some(graph) = database_state.graph_states.remove(&partial_graph_id) {
                assert!(
                    graph.is_empty(),
                    "non-empty graph to be removed: {}",
                    &graph
                );
            } else {
                warn!(
                    partial_graph_id = partial_graph_id.0,
                    "no partial graph to remove"
                );
            }
        }
    }

    fn add_partial_graph(&mut self, database_id: DatabaseId, partial_graph_id: PartialGraphId) {
        let status = match self.state.databases.entry(database_id) {
            Entry::Occupied(entry) => {
                let status = entry.into_mut();
                if let DatabaseStatus::ReceivedExchangeRequest(pending_requests) = status {
                    let mut database = DatabaseManagedBarrierState::new(
                        database_id,
                        self.term_id.clone(),
                        self.actor_manager.clone(),
                    );
                    for ((upstream_actor_id, actor_id), result_sender) in pending_requests.drain(..)
                    {
                        database.new_actor_remote_output_request(
                            actor_id,
                            upstream_actor_id,
                            result_sender,
                        );
                    }
                    *status = DatabaseStatus::Running(database);
                }

                status
            }
            Entry::Vacant(entry) => {
                entry.insert(DatabaseStatus::Running(DatabaseManagedBarrierState::new(
                    database_id,
                    self.term_id.clone(),
                    self.actor_manager.clone(),
                )))
            }
        };
        if let Some(state) = status.state_for_request() {
            assert!(
                state
                    .graph_states
                    .insert(
                        partial_graph_id,
                        PartialGraphManagedBarrierState::new(&self.actor_manager)
                    )
                    .is_none()
            );
        }
    }

    fn reset_database(&mut self, req: ResetDatabaseRequest) {
        let database_id = DatabaseId::new(req.database_id);
        if let Some(database_status) = self.state.databases.get_mut(&database_id) {
            database_status.start_reset(
                database_id,
                self.await_epoch_completed_futures.remove(&database_id),
                req.reset_request_id,
            );
        } else {
            self.ack_database_reset(database_id, None, req.reset_request_id);
        }
    }

    fn ack_database_reset(
        &mut self,
        database_id: DatabaseId,
        reset_output: Option<ResetDatabaseOutput>,
        reset_request_id: u32,
    ) {
        info!(
            database_id = database_id.database_id,
            "database reset successfully"
        );
        if let Some(reset_database) = self.state.databases.remove(&database_id) {
            match reset_database {
                DatabaseStatus::Resetting(_) => {}
                _ => {
                    unreachable!("must be resetting previously")
                }
            }
        }
        self.await_epoch_completed_futures.remove(&database_id);
        self.control_stream_handle.ack_reset_database(
            database_id,
            reset_output.and_then(|output| output.root_err),
            reset_request_id,
        );
    }

    /// When some other failure occurs (like failing to send a barrier), the error is reported via
    /// this function. A message is sent on the control stream to notify the meta service of the
    /// error, and the global barrier worker will later reset and rerun the database.
    fn on_database_failure(
        &mut self,
        database_id: DatabaseId,
        failed_actor: Option<ActorId>,
        err: StreamError,
        message: impl Into<String>,
    ) {
        let message = message.into();
        error!(database_id = database_id.database_id, ?failed_actor, message, err = ?err.as_report(), "suspend database on error");
        let completing_futures = self.await_epoch_completed_futures.remove(&database_id);
        self.state
            .databases
            .get_mut(&database_id)
            .expect("should exist")
            .suspend(failed_actor, err, completing_futures);
        self.control_stream_handle
            .send_response(Response::ReportDatabaseFailure(
                ReportDatabaseFailureResponse {
                    database_id: database_id.database_id,
                },
            ));
    }

    /// Force stop all actors on this worker, and then drop their resources.
    async fn reset(&mut self, init_request: InitRequest) {
        join_all(
            self.state
                .databases
                .values_mut()
                .map(|database| database.abort()),
        )
        .await;
        if let Some(m) = self.actor_manager.await_tree_reg.as_ref() {
            m.clear();
        }

        if let Some(hummock) = self.actor_manager.env.state_store().as_hummock() {
            hummock
                .clear_shared_buffer()
                .instrument_await("store_clear_shared_buffer".verbose())
                .await
        }
        self.actor_manager.env.dml_manager_ref().clear();
        *self = Self::new(self.actor_manager.clone(), init_request.term_id);
        self.actor_manager.env.client_pool().invalidate_all();
    }

    /// Create a [`LocalBarrierWorker`] in managed mode and spawn its event loop.
    pub fn spawn(
        env: StreamEnvironment,
        streaming_metrics: Arc<StreamingMetrics>,
        await_tree_reg: Option<await_tree::Registry>,
        watermark_epoch: AtomicU64Ref,
        actor_op_rx: UnboundedReceiver<LocalActorOperation>,
    ) -> JoinHandle<()> {
        let runtime = {
            let mut builder = tokio::runtime::Builder::new_multi_thread();
            if let Some(worker_threads_num) = env.config().actor_runtime_worker_threads_num {
                builder.worker_threads(worker_threads_num);
            }
            builder
                .thread_name("rw-streaming")
                .enable_all()
                .build()
                .unwrap()
        };

        let actor_manager = Arc::new(StreamActorManager {
            env,
            streaming_metrics,
            watermark_epoch,
            await_tree_reg,
            runtime: runtime.into(),
        });
        let worker = LocalBarrierWorker::new(actor_manager, "uninitialized".into());
        tokio::spawn(worker.run(actor_op_rx))
    }
}
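
// Wiring sketch for `spawn` (argument values hypothetical): the caller keeps the sender
// side of the event channel and hands the receiver to the worker.
//
// ```ignore
// let (actor_op_tx, actor_op_rx) = tokio::sync::mpsc::unbounded_channel();
// let join_handle =
//     LocalBarrierWorker::spawn(env, streaming_metrics, None, watermark_epoch, actor_op_rx);
// ```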

pub(super) struct EventSender<T>(pub(super) UnboundedSender<T>);

impl<T> Clone for EventSender<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<T> EventSender<T> {
    pub(super) fn send_event(&self, event: T) {
        self.0.send(event).expect("should be able to send event")
    }

    pub(super) async fn send_and_await<RSP>(
        &self,
        make_event: impl FnOnce(oneshot::Sender<RSP>) -> T,
    ) -> StreamResult<RSP> {
        let (tx, rx) = oneshot::channel();
        let event = make_event(tx);
        self.send_event(event);
        rx.await
            .map_err(|_| anyhow!("barrier manager maybe reset").into())
    }
}
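
// `send_and_await` implements the request-response-over-channel pattern: the caller packs
// a `oneshot` sender into the event and awaits the reply. For reference, usage from the
// test utilities below:
//
// ```ignore
// let local_barrier_manager = actor_op_tx
//     .send_and_await(LocalActorOperation::GetCurrentLocalBarrierManager)
//     .await?;
// ```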

pub(crate) enum NewOutputRequest {
    Local(permit::Sender),
    Remote(permit::Sender),
}

#[cfg(test)]
pub(crate) mod barrier_test_utils {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use risingwave_pb::stream_service::streaming_control_stream_request::{
        InitRequest, PbCreatePartialGraphRequest,
    };
    use risingwave_pb::stream_service::{
        InjectBarrierRequest, PbStreamingControlStreamRequest, StreamingControlStreamRequest,
        StreamingControlStreamResponse, streaming_control_stream_request,
        streaming_control_stream_response,
    };
    use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
    use tokio::sync::oneshot;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tonic::Status;

    use crate::executor::Barrier;
    use crate::task::barrier_worker::{ControlStreamHandle, EventSender, LocalActorOperation};
    use crate::task::{
        ActorId, LocalBarrierManager, NewOutputRequest, TEST_DATABASE_ID, TEST_PARTIAL_GRAPH_ID,
    };

    pub(crate) struct LocalBarrierTestEnv {
        pub local_barrier_manager: LocalBarrierManager,
        pub(super) actor_op_tx: EventSender<LocalActorOperation>,
        pub request_tx: UnboundedSender<Result<StreamingControlStreamRequest, Status>>,
        pub response_rx: UnboundedReceiver<Result<StreamingControlStreamResponse, Status>>,
    }

    impl LocalBarrierTestEnv {
        pub(crate) async fn for_test() -> Self {
            let actor_op_tx = LocalBarrierManager::spawn_for_test();

            let (request_tx, request_rx) = unbounded_channel();
            let (response_tx, mut response_rx) = unbounded_channel();

            request_tx
                .send(Ok(PbStreamingControlStreamRequest {
                    request: Some(
                        streaming_control_stream_request::Request::CreatePartialGraph(
                            PbCreatePartialGraphRequest {
                                partial_graph_id: TEST_PARTIAL_GRAPH_ID.into(),
                                database_id: TEST_DATABASE_ID.database_id,
                            },
                        ),
                    ),
                }))
                .unwrap();

            actor_op_tx.send_event(LocalActorOperation::NewControlStream {
                handle: ControlStreamHandle::new(
                    response_tx,
                    UnboundedReceiverStream::new(request_rx).boxed(),
                ),
                init_request: InitRequest {
                    term_id: "for_test".into(),
                },
            });

            assert_matches!(
                response_rx.recv().await.unwrap().unwrap().response.unwrap(),
                streaming_control_stream_response::Response::Init(_)
            );

            let local_barrier_manager = actor_op_tx
                .send_and_await(LocalActorOperation::GetCurrentLocalBarrierManager)
                .await
                .unwrap();

            Self {
                local_barrier_manager,
                actor_op_tx,
                request_tx,
                response_rx,
            }
        }

        pub(crate) fn inject_barrier(
            &self,
            barrier: &Barrier,
            actor_to_collect: impl IntoIterator<Item = ActorId>,
        ) {
            self.request_tx
                .send(Ok(StreamingControlStreamRequest {
                    request: Some(streaming_control_stream_request::Request::InjectBarrier(
                        InjectBarrierRequest {
                            request_id: "".to_owned(),
                            barrier: Some(barrier.to_protobuf()),
                            database_id: TEST_DATABASE_ID.database_id,
                            actor_ids_to_collect: actor_to_collect.into_iter().collect(),
                            table_ids_to_sync: vec![],
                            partial_graph_id: TEST_PARTIAL_GRAPH_ID.into(),
                            actors_to_build: vec![],
                        },
                    )),
                }))
                .unwrap();
        }

        pub(crate) async fn flush_all_events(&self) {
            Self::flush_all_events_impl(&self.actor_op_tx).await
        }

        pub(super) async fn flush_all_events_impl(actor_op_tx: &EventSender<LocalActorOperation>) {
            let (tx, rx) = oneshot::channel();
            actor_op_tx.send_event(LocalActorOperation::Flush(tx));
            rx.await.unwrap()
        }

        pub(crate) async fn take_pending_new_output_requests(
            &self,
            actor_id: ActorId,
        ) -> Vec<(ActorId, NewOutputRequest)> {
            self.actor_op_tx
                .send_and_await(|tx| LocalActorOperation::TakePendingNewOutputRequest(actor_id, tx))
                .await
                .unwrap()
        }
    }
}