1use std::collections::hash_map::Entry;
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use assert_matches::assert_matches;
21use await_tree::InstrumentAwait;
22use itertools::Itertools;
23use parking_lot::Mutex;
24use prometheus::HistogramTimer;
25use risingwave_common::catalog::{DatabaseId, TableId};
26use risingwave_common::id::JobId;
27use risingwave_common::metrics::LabelGuardedHistogram;
28use risingwave_hummock_sdk::HummockVersionId;
29use risingwave_pb::catalog::Database;
30use rw_futures_util::pending_on_none;
31use tokio::select;
32use tokio::sync::{oneshot, watch};
33use tokio::time::{Duration, Instant};
34use tokio_stream::wrappers::IntervalStream;
35use tokio_stream::{StreamExt, StreamMap};
36use tracing::{info, warn};
37
38use super::notifier::Notifier;
39use super::{Command, Scheduled};
40use crate::barrier::context::GlobalBarrierWorkerContext;
41use crate::hummock::HummockManagerRef;
42use crate::rpc::metrics::{GLOBAL_META_METRICS, MetaMetrics};
43use crate::{MetaError, MetaResult};
44
/// A barrier newly produced by the scheduler, about to be injected into the
/// stream graph of one database.
pub(super) struct NewBarrier {
    /// The database this barrier belongs to.
    pub database_id: DatabaseId,
    /// The command the barrier carries plus the notifiers to fire on
    /// injection/collection; `None` for a plain periodic barrier.
    pub command: Option<(Command, Vec<Notifier>)>,
    /// Tracing span covering this barrier's lifecycle.
    pub span: tracing::Span,
    /// Whether this barrier is a checkpoint barrier.
    pub checkpoint: bool,
}
51
/// State shared between the producer (`BarrierScheduler`) and the consumer
/// (`ScheduledBarriers`) halves of the scheduling queue.
struct Inner {
    // Scheduled commands, partitioned per database and guarded by a
    // blocked/ready status at both the cluster and database level.
    queue: Mutex<ScheduledQueue>,

    // Wakes the consumer when a queue transitions from empty to non-empty or
    // becomes schedulable again after being blocked.
    changed_tx: watch::Sender<()>,

    // Used to create per-database send-latency histograms lazily.
    metrics: Arc<MetaMetrics>,
}
65
/// Whether a scheduling queue currently accepts new commands.
#[derive(Debug)]
enum QueueStatus {
    /// The queue accepts new commands.
    Ready,
    /// The queue rejects most new commands (e.g. during recovery), with a
    /// human-readable reason included in the returned error.
    Blocked(String),
}
73
74impl QueueStatus {
75 fn is_blocked(&self) -> bool {
76 matches!(self, Self::Blocked(_))
77 }
78}
79
/// A command buffered in a database queue, waiting to be attached to a barrier.
struct ScheduledQueueItem {
    command: Command,
    notifiers: Vec<Notifier>,
    // Measures how long the item waits in the queue; observed when popped by
    // the consumer (`next_scheduled`).
    send_latency_timer: HistogramTimer,
    span: tracing::Span,
}
86
/// A queue payload `T` paired with a blocked/ready [`QueueStatus`].
struct StatusQueue<T> {
    queue: T,
    status: QueueStatus,
}
91
/// The pending scheduled items of a single database.
struct DatabaseQueue {
    inner: VecDeque<ScheduledQueueItem>,
    // Queueing-latency histogram labeled with this database's id.
    send_latency: LabelGuardedHistogram,
}
96
/// Per-database queue with its own blocked/ready status.
type DatabaseScheduledQueue = StatusQueue<DatabaseQueue>;
/// Top-level queue: per-database queues, guarded by a cluster-level status.
type ScheduledQueue = StatusQueue<HashMap<DatabaseId, DatabaseScheduledQueue>>;
99
100impl DatabaseScheduledQueue {
101 fn new(database_id: DatabaseId, metrics: &MetaMetrics, status: QueueStatus) -> Self {
102 Self {
103 queue: DatabaseQueue {
104 inner: Default::default(),
105 send_latency: metrics
106 .barrier_send_latency
107 .with_guarded_label_values(&[database_id.to_string().as_str()]),
108 },
109 status,
110 }
111 }
112}
113
114impl<T> StatusQueue<T> {
115 fn mark_blocked(&mut self, reason: String) {
116 self.status = QueueStatus::Blocked(reason);
117 }
118
119 fn mark_ready(&mut self) -> bool {
120 let prev_blocked = self.status.is_blocked();
121 self.status = QueueStatus::Ready;
122 prev_blocked
123 }
124
125 fn validate_item(&mut self, command: &Command) -> MetaResult<()> {
126 if let QueueStatus::Blocked(reason) = &self.status
132 && !matches!(
133 command,
134 Command::DropStreamingJobs { .. } | Command::DropSubscription { .. }
135 )
136 {
137 return Err(MetaError::unavailable(reason));
138 }
139 Ok(())
140 }
141}
142
143fn tracing_span() -> tracing::Span {
144 if tracing::Span::current().is_none() {
145 tracing::Span::none()
146 } else {
147 tracing::info_span!(
148 "barrier",
149 checkpoint = tracing::field::Empty,
150 epoch = tracing::field::Empty
151 )
152 }
153}
154
/// Producer handle that schedules barriers (with commands) and waits for their
/// completion. Cheap to clone; all clones share the same [`Inner`] queue.
#[derive(Clone)]
pub struct BarrierScheduler {
    inner: Arc<Inner>,

    /// Used to read the hummock version id after a flush completes.
    hummock_manager: HummockManagerRef,
}
164
165impl BarrierScheduler {
166 pub fn new_pair(
169 hummock_manager: HummockManagerRef,
170 metrics: Arc<MetaMetrics>,
171 ) -> (Self, ScheduledBarriers) {
172 let inner = Arc::new(Inner {
173 queue: Mutex::new(ScheduledQueue {
174 queue: Default::default(),
175 status: QueueStatus::Ready,
176 }),
177 changed_tx: watch::channel(()).0,
178 metrics,
179 });
180
181 (
182 Self {
183 inner: inner.clone(),
184 hummock_manager,
185 },
186 ScheduledBarriers { inner },
187 )
188 }
189
190 fn push(
192 &self,
193 database_id: DatabaseId,
194 scheduleds: impl IntoIterator<Item = (Command, Notifier)>,
195 ) -> MetaResult<()> {
196 let mut queue = self.inner.queue.lock();
197 let scheduleds = scheduleds.into_iter().collect_vec();
198 scheduleds
199 .iter()
200 .try_for_each(|(command, _)| queue.validate_item(command))?;
201 let queue = queue.queue.entry(database_id).or_insert_with(|| {
202 DatabaseScheduledQueue::new(database_id, &self.inner.metrics, QueueStatus::Ready)
203 });
204 scheduleds
205 .iter()
206 .try_for_each(|(command, _)| queue.validate_item(command))?;
207 for (command, notifier) in scheduleds {
208 queue.queue.inner.push_back(ScheduledQueueItem {
209 command,
210 notifiers: vec![notifier],
211 send_latency_timer: queue.queue.send_latency.start_timer(),
212 span: tracing_span(),
213 });
214 if queue.queue.inner.len() == 1 {
215 self.inner.changed_tx.send(()).ok();
216 }
217 }
218 Ok(())
219 }
220
221 pub fn try_cancel_scheduled_create(&self, database_id: DatabaseId, job_id: JobId) -> bool {
223 let queue = &mut self.inner.queue.lock();
224 let Some(queue) = queue.queue.get_mut(&database_id) else {
225 return false;
226 };
227
228 if let Some(idx) = queue.queue.inner.iter().position(|scheduled| {
229 if let Command::CreateStreamingJob { info, .. } = &scheduled.command
230 && info.stream_job_fragments.stream_job_id() == job_id
231 {
232 true
233 } else {
234 false
235 }
236 }) {
237 queue.queue.inner.remove(idx).unwrap();
238 true
239 } else {
240 false
241 }
242 }
243
244 #[await_tree::instrument("run_commands({})", commands.iter().join(", "))]
251 async fn run_multiple_commands(
252 &self,
253 database_id: DatabaseId,
254 commands: Vec<Command>,
255 ) -> MetaResult<()> {
256 let mut contexts = Vec::with_capacity(commands.len());
257 let mut scheduleds = Vec::with_capacity(commands.len());
258
259 for command in commands {
260 let (started_tx, started_rx) = oneshot::channel();
261 let (collect_tx, collect_rx) = oneshot::channel();
262
263 contexts.push((started_rx, collect_rx));
264 scheduleds.push((
265 command,
266 Notifier {
267 started: Some(started_tx),
268 collected: Some(collect_tx),
269 },
270 ));
271 }
272
273 self.push(database_id, scheduleds)?;
274
275 for (injected_rx, collect_rx) in contexts {
276 tracing::trace!("waiting for injected_rx");
278 injected_rx
279 .instrument_await("wait_injected")
280 .await
281 .ok()
282 .context("failed to inject barrier")??;
283
284 tracing::trace!("waiting for collect_rx");
285 collect_rx
287 .instrument_await("wait_collected")
288 .await
289 .ok()
290 .context("failed to collect barrier")??;
291 }
292
293 Ok(())
294 }
295
296 pub async fn run_command(&self, database_id: DatabaseId, command: Command) -> MetaResult<()> {
300 tracing::trace!("run_command: {:?}", command);
301 let ret = self.run_multiple_commands(database_id, vec![command]).await;
302 tracing::trace!("run_command finished");
303 ret
304 }
305
306 pub fn run_command_no_wait(&self, database_id: DatabaseId, command: Command) -> MetaResult<()> {
308 tracing::trace!("run_command_no_wait: {:?}", command);
309 self.push(database_id, vec![(command, Notifier::default())])
310 }
311
312 pub async fn flush(&self, database_id: DatabaseId) -> MetaResult<HummockVersionId> {
314 let start = Instant::now();
315
316 tracing::debug!("start barrier flush");
317 self.run_multiple_commands(database_id, vec![Command::Flush])
318 .await?;
319
320 let elapsed = Instant::now().duration_since(start);
321 tracing::debug!("barrier flushed in {:?}", elapsed);
322
323 let version_id = self.hummock_manager.get_version_id().await;
324 Ok(version_id)
325 }
326}
327
/// Consumer handle over the scheduling queue, held by the barrier worker.
pub struct ScheduledBarriers {
    inner: Arc<Inner>,
}
332
/// Per-database barrier pacing configuration and checkpoint progress.
#[derive(Debug)]
pub struct DatabaseBarrierState {
    // Custom barrier interval; `None` inherits the system-wide default.
    barrier_interval: Option<Duration>,
    // Custom checkpoint frequency; `None` inherits the system-wide default.
    checkpoint_frequency: Option<u64>,
    // Barriers issued since the last checkpoint barrier.
    num_uncheckpointed_barrier: u64,
}
341
342impl DatabaseBarrierState {
343 fn new(barrier_interval_ms: Option<u32>, checkpoint_frequency: Option<u64>) -> Self {
344 Self {
345 barrier_interval: barrier_interval_ms.map(|ms| Duration::from_millis(ms as u64)),
346 checkpoint_frequency,
347 num_uncheckpointed_barrier: 0,
348 }
349 }
350}
351
/// Drives periodic barrier generation, one interval timer per database.
#[derive(Default, Debug)]
pub struct PeriodicBarriers {
    // System-wide default barrier interval.
    sys_barrier_interval: Duration,
    // System-wide default checkpoint frequency.
    sys_checkpoint_frequency: u64,
    // Pacing state per database.
    databases: HashMap<DatabaseId, DatabaseBarrierState>,
    // Per-database interval timers merged into one stream.
    timer_streams: StreamMap<DatabaseId, IntervalStream>,
    // Databases whose next barrier must be forced to be a checkpoint.
    force_checkpoint_databases: HashSet<DatabaseId>,
}
365
366impl PeriodicBarriers {
367 pub(super) fn new(
368 sys_barrier_interval: Duration,
369 sys_checkpoint_frequency: u64,
370 database_infos: Vec<Database>,
371 ) -> Self {
372 let mut databases = HashMap::with_capacity(database_infos.len());
373 let mut timer_streams = StreamMap::with_capacity(database_infos.len());
374 database_infos.into_iter().for_each(|database| {
375 let database_id: DatabaseId = database.id;
376 let barrier_interval_ms = database.barrier_interval_ms;
377 let checkpoint_frequency = database.checkpoint_frequency;
378 databases.insert(
379 database_id,
380 DatabaseBarrierState::new(barrier_interval_ms, checkpoint_frequency),
381 );
382 let duration = if let Some(ms) = barrier_interval_ms {
383 Duration::from_millis(ms as u64)
384 } else {
385 sys_barrier_interval
386 };
387
388 let interval_stream = Self::new_interval_stream(duration, &database_id);
390 timer_streams.insert(database_id, interval_stream);
391 });
392 Self {
393 sys_barrier_interval,
394 sys_checkpoint_frequency,
395 databases,
396 timer_streams,
397 force_checkpoint_databases: Default::default(),
398 }
399 }
400
401 fn new_interval_stream(duration: Duration, database_id: &DatabaseId) -> IntervalStream {
403 GLOBAL_META_METRICS
404 .barrier_interval_by_database
405 .with_label_values(&[&database_id.to_string()])
406 .set(duration.as_millis_f64());
407 let mut interval = tokio::time::interval(duration);
408 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
409 IntervalStream::new(interval)
410 }
411
412 pub(super) fn set_sys_barrier_interval(&mut self, duration: Duration) {
414 if self.sys_barrier_interval == duration {
415 return;
416 }
417 self.sys_barrier_interval = duration;
418 for (db_id, db_state) in &mut self.databases {
420 if db_state.barrier_interval.is_none() {
421 let interval_stream = Self::new_interval_stream(duration, db_id);
422 self.timer_streams.insert(*db_id, interval_stream);
423 }
424 }
425 }
426
427 pub fn set_sys_checkpoint_frequency(&mut self, frequency: u64) {
429 if self.sys_checkpoint_frequency == frequency {
430 return;
431 }
432 self.sys_checkpoint_frequency = frequency;
433 for db_state in self.databases.values_mut() {
435 if db_state.checkpoint_frequency.is_none() {
436 db_state.num_uncheckpointed_barrier = 0;
437 }
438 }
439 }
440
441 pub(super) fn update_database_barrier(
442 &mut self,
443 database_id: DatabaseId,
444 barrier_interval_ms: Option<u32>,
445 checkpoint_frequency: Option<u64>,
446 ) {
447 match self.databases.entry(database_id) {
448 Entry::Occupied(mut entry) => {
449 let db_state = entry.get_mut();
450 db_state.barrier_interval =
451 barrier_interval_ms.map(|ms| Duration::from_millis(ms as u64));
452 db_state.checkpoint_frequency = checkpoint_frequency;
453 db_state.num_uncheckpointed_barrier = 0;
455 }
456 Entry::Vacant(entry) => {
457 entry.insert(DatabaseBarrierState::new(
458 barrier_interval_ms,
459 checkpoint_frequency,
460 ));
461 }
462 }
463
464 let duration = if let Some(ms) = barrier_interval_ms {
466 Duration::from_millis(ms as u64)
467 } else {
468 self.sys_barrier_interval
469 };
470
471 let interval_stream = Self::new_interval_stream(duration, &database_id);
472 self.timer_streams.insert(database_id, interval_stream);
473 }
474
475 pub fn force_checkpoint_in_next_barrier(&mut self, database_id: DatabaseId) {
477 if self.databases.contains_key(&database_id) {
478 self.force_checkpoint_databases.insert(database_id);
479 } else {
480 warn!(
481 ?database_id,
482 "force checkpoint in next barrier for non-existing database"
483 );
484 }
485 }
486
487 fn reset_database_timer(&mut self, database_id: DatabaseId) {
488 assert!(
490 self.databases.contains_key(&database_id),
491 "database {} not found in scheduled barriers",
492 database_id
493 );
494 assert!(
495 self.timer_streams.contains_key(&database_id),
496 "timer stream for database {} not found in scheduled barriers",
497 database_id
498 );
499 for (db_id, timer_stream) in self.timer_streams.iter_mut() {
501 if *db_id == database_id {
502 timer_stream.as_mut().reset();
503 }
504 }
505 }
506
507 #[await_tree::instrument]
508 pub(super) async fn next_barrier(
509 &mut self,
510 context: &impl GlobalBarrierWorkerContext,
511 ) -> NewBarrier {
512 let force_checkpoint_database = self.force_checkpoint_databases.drain().next();
513 let new_barrier = if let Some(database_id) = force_checkpoint_database {
514 self.reset_database_timer(database_id);
515 NewBarrier {
516 database_id,
517 command: None,
518 span: tracing_span(),
519 checkpoint: true,
520 }
521 } else {
522 select! {
523 biased;
524 scheduled = context.next_scheduled() => {
525 let database_id = scheduled.database_id;
526 self.reset_database_timer(database_id);
527 let checkpoint = scheduled.command.need_checkpoint() || self.try_get_checkpoint(database_id);
528 NewBarrier {
529 database_id: scheduled.database_id,
530 command: Some((scheduled.command, scheduled.notifiers)),
531 span: scheduled.span,
532 checkpoint,
533 }
534 },
535 (database_id, _instant) = pending_on_none(self.timer_streams.next()) => {
538 let checkpoint = self.try_get_checkpoint(database_id);
539 NewBarrier {
540 database_id,
541 command: None,
542 span: tracing_span(),
543 checkpoint,
544 }
545 }
546 }
547 };
548 self.update_num_uncheckpointed_barrier(new_barrier.database_id, new_barrier.checkpoint);
549
550 new_barrier
551 }
552
553 fn try_get_checkpoint(&self, database_id: DatabaseId) -> bool {
555 let db_state = self.databases.get(&database_id).unwrap();
556 let checkpoint_frequency = db_state
557 .checkpoint_frequency
558 .unwrap_or(self.sys_checkpoint_frequency);
559 db_state.num_uncheckpointed_barrier + 1 >= checkpoint_frequency
560 }
561
562 fn update_num_uncheckpointed_barrier(&mut self, database_id: DatabaseId, checkpoint: bool) {
564 let db_state = self.databases.get_mut(&database_id).unwrap();
565 if checkpoint {
566 db_state.num_uncheckpointed_barrier = 0;
567 } else {
568 db_state.num_uncheckpointed_barrier += 1;
569 }
570 }
571}
572
impl ScheduledBarriers {
    /// Wait until an item is available from any non-blocked database queue and
    /// pop it, recording its queueing latency.
    pub(super) async fn next_scheduled(&self) -> Scheduled {
        'outer: loop {
            // Subscribe *before* inspecting the queue so a wake-up sent between
            // the check and the `await` below is not lost.
            let mut rx = self.inner.changed_tx.subscribe();
            {
                let mut queue = self.inner.queue.lock();
                // NOTE(review): while the cluster-level queue is blocked this
                // `continue` skips the `rx.changed().await`, so the loop spins
                // without yielding; presumably the worker does not poll
                // `next_scheduled` during global recovery — confirm.
                if queue.status.is_blocked() {
                    continue;
                }
                for (database_id, queue) in &mut queue.queue {
                    // Skip databases that are individually blocked.
                    if queue.status.is_blocked() {
                        continue;
                    }
                    if let Some(item) = queue.queue.inner.pop_front() {
                        item.send_latency_timer.observe_duration();
                        break 'outer Scheduled {
                            database_id: *database_id,
                            command: item.command,
                            notifiers: item.notifiers,
                            span: item.span,
                        };
                    }
                }
            }
            // Nothing schedulable; wait for the producer to signal a change.
            rx.changed().await.unwrap();
        }
    }
}
601
/// Scope of a `mark_ready` call after recovery.
pub(super) enum MarkReadyOptions {
    /// Unblock a single database's queue.
    Database(DatabaseId),
    /// Unblock the cluster-level queue, keeping the listed databases blocked
    /// (they failed to recover during global recovery).
    Global {
        blocked_databases: HashSet<DatabaseId>,
    },
}
608
/// Ids collected from buffered drop/cancel commands that were applied eagerly
/// before recovery.
pub(super) struct PreApplyDropCancel {
    pub streaming_job_ids: Vec<JobId>,
    pub dropped_state_table_ids: Vec<TableId>,
}
613
impl ScheduledBarriers {
    /// Drain buffered drop/cancel commands, returning the affected job ids and
    /// dropped state-table ids. Thin delegating wrapper kept for the
    /// crate-internal API surface.
    pub(super) fn pre_apply_drop_cancel(
        &self,
        database_id: Option<DatabaseId>,
    ) -> PreApplyDropCancel {
        self.pre_apply_drop_cancel_scheduled(database_id)
    }

    /// Mark one database queue (or the whole cluster when `database_id` is
    /// `None`) as blocked, failing every command already buffered in the
    /// affected queue(s) with an "unavailable" error.
    pub(super) fn abort_and_mark_blocked(
        &self,
        database_id: Option<DatabaseId>,
        reason: impl Into<String>,
    ) {
        let mut queue = self.inner.queue.lock();
        // Human-readable reason attached to the errors of rejected commands.
        fn database_blocked_reason(database_id: DatabaseId, reason: &String) -> String {
            format!("database {} unavailable {}", database_id, reason)
        }
        // Block the queue, then fail every pending item with the same error.
        fn mark_blocked_and_notify_failed(
            database_id: DatabaseId,
            queue: &mut DatabaseScheduledQueue,
            reason: &String,
        ) {
            let reason = database_blocked_reason(database_id, reason);
            let err: MetaError = anyhow!("{}", reason).into();
            queue.mark_blocked(reason);
            while let Some(ScheduledQueueItem { notifiers, .. }) = queue.queue.inner.pop_front() {
                notifiers
                    .into_iter()
                    .for_each(|notify| notify.notify_collection_failed(err.clone()))
            }
        }
        if let Some(database_id) = database_id {
            let reason = reason.into();
            match queue.queue.entry(database_id) {
                Entry::Occupied(entry) => {
                    let queue = entry.into_mut();
                    // Double-blocking indicates a state-machine bug: panic in
                    // debug builds, tolerate with a warning in release.
                    if queue.status.is_blocked() {
                        if cfg!(debug_assertions) {
                            panic!("database {} marked as blocked twice", database_id);
                        } else {
                            warn!(?database_id, "database marked as blocked twice");
                        }
                    }
                    info!(?database_id, "database marked as blocked");
                    mark_blocked_and_notify_failed(database_id, queue, &reason);
                }
                Entry::Vacant(entry) => {
                    // No queue yet: create it directly in the blocked state so
                    // later pushes are rejected.
                    entry.insert(DatabaseScheduledQueue::new(
                        database_id,
                        &self.inner.metrics,
                        QueueStatus::Blocked(database_blocked_reason(database_id, &reason)),
                    ));
                }
            }
        } else {
            let reason = reason.into();
            if queue.status.is_blocked() {
                if cfg!(debug_assertions) {
                    panic!("cluster marked as blocked twice");
                } else {
                    warn!("cluster marked as blocked twice");
                }
            }
            info!("cluster marked as blocked");
            // Block the cluster-level queue, then every database queue.
            queue.mark_blocked(reason.clone());
            for (database_id, queue) in &mut queue.queue {
                mark_blocked_and_notify_failed(*database_id, queue, &reason);
            }
        }
    }

    /// Unblock a database queue or the whole cluster after recovery, waking
    /// the consumer if previously-buffered items became schedulable again.
    pub(super) fn mark_ready(&self, options: MarkReadyOptions) {
        let mut queue = self.inner.queue.lock();
        let queue = &mut *queue;
        match options {
            MarkReadyOptions::Database(database_id) => {
                info!(?database_id, "database marked as ready");
                let database_queue = queue.queue.entry(database_id).or_insert_with(|| {
                    DatabaseScheduledQueue::new(
                        database_id,
                        &self.inner.metrics,
                        QueueStatus::Ready,
                    )
                });
                if !database_queue.status.is_blocked() {
                    if cfg!(debug_assertions) {
                        panic!("database {} marked as ready twice", database_id);
                    } else {
                        warn!(?database_id, "database marked as ready twice");
                    }
                }
                // Wake the consumer only if this queue actually transitioned
                // from blocked to ready, the cluster itself is not blocked,
                // and there is something to pop.
                if database_queue.mark_ready()
                    && !queue.status.is_blocked()
                    && !database_queue.queue.inner.is_empty()
                {
                    self.inner.changed_tx.send(()).ok();
                }
            }
            MarkReadyOptions::Global { blocked_databases } => {
                if !queue.status.is_blocked() {
                    if cfg!(debug_assertions) {
                        panic!("cluster marked as ready twice");
                    } else {
                        warn!("cluster marked as ready twice");
                    }
                }
                info!(?blocked_databases, "cluster marked as ready");
                let prev_blocked = queue.mark_ready();
                // Databases that failed global recovery stay (or become) blocked.
                for database_id in &blocked_databases {
                    queue.queue.entry(*database_id).or_insert_with(|| {
                        DatabaseScheduledQueue::new(
                            *database_id,
                            &self.inner.metrics,
                            QueueStatus::Blocked(format!(
                                "database {} failed to recover in global recovery",
                                database_id
                            )),
                        )
                    });
                }
                // All other databases become ready together with the cluster.
                for (database_id, queue) in &mut queue.queue {
                    if !blocked_databases.contains(database_id) {
                        queue.mark_ready();
                    }
                }
                // Wake the consumer if the cluster just unblocked and any
                // database still has buffered items.
                if prev_blocked
                    && queue
                        .queue
                        .values()
                        .any(|database_queue| !database_queue.queue.inner.is_empty())
                {
                    self.inner.changed_tx.send(()).ok();
                }
            }
        }
    }

    /// While a queue is blocked only drop/cancel commands may be buffered (see
    /// `validate_item`). Pop them all, collect the job / state-table ids they
    /// refer to, and notify their callers as collected.
    pub(super) fn pre_apply_drop_cancel_scheduled(
        &self,
        database_id: Option<DatabaseId>,
    ) -> PreApplyDropCancel {
        let mut queue = self.inner.queue.lock();
        let mut drop_cancel = PreApplyDropCancel {
            streaming_job_ids: vec![],
            dropped_state_table_ids: vec![],
        };

        let mut pre_apply_drop_cancel = |queue: &mut DatabaseScheduledQueue| {
            while let Some(ScheduledQueueItem {
                notifiers, command, ..
            }) = queue.queue.inner.pop_front()
            {
                match command {
                    Command::DropStreamingJobs {
                        streaming_job_ids,
                        unregistered_state_table_ids,
                        ..
                    } => {
                        drop_cancel.streaming_job_ids.extend(streaming_job_ids);
                        drop_cancel
                            .dropped_state_table_ids
                            .extend(unregistered_state_table_ids);
                    }
                    Command::DropSubscription { .. } => {}
                    _ => {
                        unreachable!("only drop and cancel streaming jobs should be buffered");
                    }
                }
                // The jobs are being dropped anyway; report the command as
                // collected to unblock any waiting callers.
                notifiers.into_iter().for_each(|notify| {
                    notify.notify_collected();
                });
            }
        };

        if let Some(database_id) = database_id {
            // Database-level recovery: the cluster must be ready and this
            // database must be blocked.
            assert_matches!(queue.status, QueueStatus::Ready);
            if let Some(queue) = queue.queue.get_mut(&database_id) {
                assert_matches!(queue.status, QueueStatus::Blocked(_));
                pre_apply_drop_cancel(queue);
            }
        } else {
            // Global recovery: the cluster-level queue must be blocked.
            assert_matches!(queue.status, QueueStatus::Blocked(_));
            for queue in queue.queue.values_mut() {
                pre_apply_drop_cancel(queue);
            }
        }

        drop_cancel
    }
}
810
#[cfg(test)]
mod tests {
    use futures::FutureExt;

    use super::*;

    /// Build a `Database` catalog entry with the given pacing overrides.
    fn create_test_database(
        id: u32,
        barrier_interval_ms: Option<u32>,
        checkpoint_frequency: Option<u64>,
    ) -> Database {
        Database {
            id: id.into(),
            name: format!("test_db_{}", id),
            barrier_interval_ms,
            checkpoint_frequency,
            ..Default::default()
        }
    }

    /// Mock context serving scheduled commands from an in-memory channel.
    /// Every other trait method is unreachable in these tests.
    struct MockGlobalBarrierWorkerContext {
        scheduled_rx: tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<Scheduled>>,
    }

    impl MockGlobalBarrierWorkerContext {
        fn new() -> (Self, tokio::sync::mpsc::UnboundedSender<Scheduled>) {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            (
                Self {
                    scheduled_rx: tokio::sync::Mutex::new(rx),
                },
                tx,
            )
        }
    }

    impl GlobalBarrierWorkerContext for MockGlobalBarrierWorkerContext {
        async fn next_scheduled(&self) -> Scheduled {
            self.scheduled_rx.lock().await.recv().await.unwrap()
        }

        async fn commit_epoch(
            &self,
            _commit_info: crate::hummock::CommitEpochInfo,
        ) -> MetaResult<risingwave_pb::hummock::HummockVersionStats> {
            unimplemented!()
        }

        fn abort_and_mark_blocked(
            &self,
            _database_id: Option<DatabaseId>,
            _recovery_reason: crate::barrier::RecoveryReason,
        ) {
            unimplemented!()
        }

        fn mark_ready(&self, _options: MarkReadyOptions) {
            unimplemented!()
        }

        async fn post_collect_command(
            &self,
            _command: crate::barrier::command::PostCollectCommand,
        ) -> MetaResult<()> {
            unimplemented!()
        }

        async fn notify_creating_job_failed(&self, _database_id: Option<DatabaseId>, _err: String) {
            unimplemented!()
        }

        async fn finish_creating_job(
            &self,
            _job: crate::barrier::progress::TrackingJob,
        ) -> MetaResult<()> {
            unimplemented!()
        }

        async fn new_control_stream(
            &self,
            _node: &risingwave_pb::common::WorkerNode,
            _init_request: &risingwave_pb::stream_service::streaming_control_stream_request::PbInitRequest,
        ) -> MetaResult<risingwave_rpc_client::StreamingControlHandle> {
            unimplemented!()
        }

        async fn reload_runtime_info(
            &self,
        ) -> MetaResult<crate::barrier::BarrierWorkerRuntimeInfoSnapshot> {
            unimplemented!()
        }

        async fn reload_database_runtime_info(
            &self,
            _database_id: DatabaseId,
        ) -> MetaResult<crate::barrier::DatabaseRuntimeInfoSnapshot> {
            unimplemented!()
        }

        async fn handle_list_finished_source_ids(
            &self,
            _list_finished_source_ids: Vec<
                risingwave_pb::stream_service::barrier_complete_response::PbListFinishedSource,
            >,
        ) -> MetaResult<()> {
            unimplemented!()
        }

        async fn handle_load_finished_source_ids(
            &self,
            _load_finished_source_ids: Vec<
                risingwave_pb::stream_service::barrier_complete_response::PbLoadFinishedSource,
            >,
        ) -> MetaResult<()> {
            unimplemented!()
        }

        async fn finish_cdc_table_backfill(&self, _job_id: JobId) -> MetaResult<()> {
            unimplemented!()
        }

        async fn handle_refresh_finished_table_ids(
            &self,
            _refresh_finished_table_ids: Vec<JobId>,
        ) -> MetaResult<()> {
            unimplemented!()
        }
    }

    /// With paused (virtual) time, per-database timers fire in interval order
    /// and checkpoint frequency is honored per database.
    #[tokio::test(start_paused = true)]
    async fn test_next_barrier_with_different_intervals() {
        let databases = vec![
            create_test_database(1, Some(50), Some(2)),
            create_test_database(2, Some(100), Some(3)),
            create_test_database(3, None, Some(5)),
        ];

        let mut periodic = PeriodicBarriers::new(
            Duration::from_millis(200),
            10,
            databases,
        );

        let (context, _tx) = MockGlobalBarrierWorkerContext::new();

        // Each timer's first tick fires immediately: one barrier per database.
        for _ in 0..3 {
            let barrier = periodic.next_barrier(&context).await;
            assert!(barrier.command.is_none());
            assert!(!barrier.checkpoint);
        }

        let start_time = Instant::now();
        let barrier = periodic.next_barrier(&context).await;
        let mut elapsed = start_time.elapsed();

        // db1 has the shortest interval (50ms), so its timer fires next.
        assert_eq!(barrier.database_id, DatabaseId::from(1));
        assert!(barrier.command.is_none());
        // db1's checkpoint frequency is 2, so its second barrier checkpoints.
        assert!(barrier.checkpoint);
        assert_eq!(
            elapsed,
            Duration::from_millis(50),
            "Elapsed time exceeded: {:?}",
            elapsed
        );

        let db1_id = DatabaseId::from(1);
        let db1_state = periodic.databases.get_mut(&db1_id).unwrap();
        // The counter resets after a checkpoint barrier.
        assert_eq!(db1_state.num_uncheckpointed_barrier, 0);

        for _ in 0..2 {
            let barrier = periodic.next_barrier(&context).await;
            assert!(barrier.command.is_none());
            assert!(!barrier.checkpoint);
        }

        elapsed = start_time.elapsed();

        assert_eq!(
            elapsed,
            Duration::from_millis(100),
            "Elapsed time exceeded: {:?}",
            elapsed
        );
    }

    /// A command arriving via the context wins over timers (`biased` select).
    #[tokio::test]
    async fn test_next_barrier_with_scheduled_command() {
        let databases = vec![
            create_test_database(1, Some(1000), Some(2)),
        ];

        let mut periodic = PeriodicBarriers::new(Duration::from_millis(1000), 10, databases);

        let (context, tx) = MockGlobalBarrierWorkerContext::new();

        // Consume the immediate first tick.
        periodic.next_barrier(&context).await;

        let scheduled_command = Scheduled {
            database_id: DatabaseId::from(1),
            command: Command::Flush,
            notifiers: vec![],
            span: tracing::Span::none(),
        };

        let tx_clone = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            tx_clone.send(scheduled_command).unwrap();
        });

        let barrier = periodic.next_barrier(&context).await;

        assert!(barrier.command.is_some());
        assert_eq!(barrier.database_id, DatabaseId::from(1));

        if let Some((command, _)) = barrier.command {
            assert!(matches!(command, Command::Flush));
        }
    }

    /// Faster databases produce proportionally more barriers.
    #[tokio::test(start_paused = true)]
    async fn test_next_barrier_multiple_databases_timing() {
        let databases = vec![
            create_test_database(1, Some(30), Some(10)),
            create_test_database(2, Some(100), Some(10)),
        ];

        let mut periodic = PeriodicBarriers::new(Duration::from_millis(500), 10, databases);

        let (context, _tx) = MockGlobalBarrierWorkerContext::new();

        // Consume the immediate first tick of both timers.
        for _ in 0..2 {
            periodic.next_barrier(&context).await;
        }

        let mut barrier_counts = HashMap::new();

        let mut barriers = Vec::new();
        for _ in 0..5 {
            let barrier = periodic.next_barrier(&context).await;
            barriers.push(barrier);
        }

        for barrier in barriers {
            *barrier_counts.entry(barrier.database_id).or_insert(0) += 1;
        }

        let db1_count = barrier_counts.get(&DatabaseId::from(1)).unwrap_or(&0);
        let db2_count = barrier_counts.get(&DatabaseId::from(2)).unwrap_or(&0);

        // db1 ticks at 30/60/90/120ms of virtual time, db2 once at 100ms.
        assert_eq!(*db1_count, 4);
        assert_eq!(*db2_count, 1);
    }

    /// A forced checkpoint is served immediately, without awaiting any timer
    /// (`now_or_never` proves no await was needed).
    #[tokio::test]
    async fn test_next_barrier_force_checkpoint() {
        let databases = vec![create_test_database(1, Some(100), Some(10))];

        let mut periodic = PeriodicBarriers::new(Duration::from_millis(100), 10, databases);

        let (context, _tx) = MockGlobalBarrierWorkerContext::new();

        periodic.force_checkpoint_in_next_barrier(DatabaseId::from(1));

        let barrier = periodic.next_barrier(&context).now_or_never().unwrap();

        assert!(barrier.checkpoint);
        assert_eq!(barrier.database_id, DatabaseId::from(1));
        assert!(barrier.command.is_none());
    }

    /// With checkpoint frequency 2, every second barrier is a checkpoint.
    #[tokio::test]
    async fn test_next_barrier_checkpoint_frequency() {
        let databases = vec![create_test_database(1, Some(50), Some(2))];
        let mut periodic = PeriodicBarriers::new(Duration::from_millis(50), 10, databases);

        let (context, _tx) = MockGlobalBarrierWorkerContext::new();

        let barrier1 = periodic.next_barrier(&context).await;
        assert!(!barrier1.checkpoint);

        let barrier2 = periodic.next_barrier(&context).await;
        assert!(barrier2.checkpoint);

        let barrier3 = periodic.next_barrier(&context).await;
        assert!(!barrier3.checkpoint);
    }

    /// Updating overrides an existing database's pacing and upserts an
    /// unknown database with default (inherit-system) settings.
    #[tokio::test]
    async fn test_update_database_barrier() {
        let databases = vec![create_test_database(1, Some(1000), Some(10))];

        let mut periodic = PeriodicBarriers::new(Duration::from_millis(500), 20, databases);

        let database_id = DatabaseId::new(1);

        periodic.update_database_barrier(database_id, Some(2000), Some(15));

        let db_state = periodic.databases.get(&database_id).unwrap();
        assert_eq!(db_state.barrier_interval, Some(Duration::from_millis(2000)));
        assert_eq!(db_state.checkpoint_frequency, Some(15));
        assert_eq!(db_state.num_uncheckpointed_barrier, 0);
        assert!(!periodic.force_checkpoint_databases.contains(&database_id));

        periodic.update_database_barrier(DatabaseId::from(2), None, None);

        assert!(periodic.databases.contains_key(&DatabaseId::from(2)));
        let db2_state = periodic.databases.get(&DatabaseId::from(2)).unwrap();
        assert_eq!(db2_state.barrier_interval, None);
        assert_eq!(db2_state.checkpoint_frequency, None);
    }
}
1148}