use std::cmp;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use itertools::Itertools;
use risingwave_common::RW_VERSION;
use risingwave_common::hash::WorkerSlotId;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::resource_util::cpu::total_cpu_available;
use risingwave_common::util::resource_util::memory::system_memory_available_bytes;
use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
use risingwave_license::LicenseManager;
use risingwave_meta_model::prelude::{Worker, WorkerProperty};
use risingwave_meta_model::worker::{WorkerStatus, WorkerType};
use risingwave_meta_model::{TransactionId, WorkerId, worker, worker_property};
use risingwave_pb::common::worker_node::{
    PbProperty, PbProperty as AddNodeProperty, PbResource, PbState,
};
use risingwave_pb::common::{HostAddress, PbHostAddress, PbWorkerNode, PbWorkerType, WorkerNode};
use risingwave_pb::meta::subscribe_response::{Info, Operation};
use risingwave_pb::meta::update_worker_node_schedulability_request::Schedulability;
use sea_orm::ActiveValue::Set;
use sea_orm::prelude::Expr;
use sea_orm::{
    ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, QuerySelect,
    TransactionTrait,
};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
use tokio::sync::oneshot::Sender;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::task::JoinHandle;

use crate::controller::utils::filter_workers_by_resource_group;
use crate::manager::{LocalNotification, META_NODE_ID, MetaSrvEnv, WorkerKey};
use crate::model::ClusterId;
use crate::{MetaError, MetaResult};

pub type ClusterControllerRef = Arc<ClusterController>;

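/// In-memory front for cluster metadata: wraps the SQL-backed
/// [`ClusterControllerInner`] behind an async `RwLock`, and keeps the meta
/// environment, the heartbeat TTL, and the controller start time
/// (in seconds since the UNIX epoch).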
pub struct ClusterController {
    env: MetaSrvEnv,
    max_heartbeat_interval: Duration,
    inner: RwLock<ClusterControllerInner>,
    started_at: u64,
}

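/// Bundles a `worker` row, its optional `worker_property` row, and the
/// in-memory [`WorkerExtraInfo`] so the trio can be converted into a
/// [`PbWorkerNode`] in one step (see the `From` impl below).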
struct WorkerInfo(
    worker::Model,
    Option<worker_property::Model>,
    WorkerExtraInfo,
);

impl From<WorkerInfo> for PbWorkerNode {
    fn from(info: WorkerInfo) -> Self {
        Self {
            id: info.0.worker_id as _,
            r#type: PbWorkerType::from(info.0.worker_type) as _,
            host: Some(PbHostAddress {
                host: info.0.host,
                port: info.0.port,
            }),
            state: PbState::from(info.0.status) as _,
            property: info.1.as_ref().map(|p| PbProperty {
                is_streaming: p.is_streaming,
                is_serving: p.is_serving,
                is_unschedulable: p.is_unschedulable,
                internal_rpc_host_addr: p.internal_rpc_host_addr.clone().unwrap_or_default(),
                resource_group: p.resource_group.clone(),
                parallelism: info.1.as_ref().map(|p| p.parallelism).unwrap_or_default() as u32,
            }),
            transactional_id: info.0.transaction_id.map(|id| id as _),
            resource: Some(info.2.resource),
            started_at: info.2.started_at,
        }
    }
}

impl ClusterController {
    pub async fn new(env: MetaSrvEnv, max_heartbeat_interval: Duration) -> MetaResult<Self> {
        let inner = ClusterControllerInner::new(
            env.meta_store_ref().conn.clone(),
            env.opts.disable_automatic_parallelism_control,
        )
        .await?;
        Ok(Self {
            env,
            max_heartbeat_interval,
            inner: RwLock::new(inner),
            started_at: timestamp_now_sec(),
        })
    }

    pub async fn get_inner_read_guard(&self) -> RwLockReadGuard<'_, ClusterControllerInner> {
        self.inner.read().await
    }

    pub async fn count_worker_by_type(&self) -> MetaResult<HashMap<WorkerType, i64>> {
        self.inner.read().await.count_worker_by_type().await
    }

    pub async fn compute_node_total_cpu_count(&self) -> usize {
        self.inner.read().await.compute_node_total_cpu_count()
    }

    async fn update_compute_node_total_cpu_count(&self) -> MetaResult<()> {
        let total_cpu_cores = self.compute_node_total_cpu_count().await;

        LicenseManager::get().update_cpu_core_count(total_cpu_cores);
        self.env.notification_manager().notify_all_without_version(
            Operation::Update,
            Info::ComputeNodeTotalCpuCount(total_cpu_cores as _),
        );

        Ok(())
    }

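    /// Registers a worker (or refreshes an existing registration for the same
    /// `host:port`) and returns its id. For compute nodes, the cluster-wide CPU
    /// count reported to the license manager is refreshed afterwards.
    ///
    /// A minimal usage sketch, modeled on the test at the bottom of this file
    /// (the concrete property values are illustrative only):
    ///
    /// ```ignore
    /// let worker_id = cluster_ctl
    ///     .add_worker(
    ///         PbWorkerType::ComputeNode,
    ///         HostAddress { host: "localhost".to_owned(), port: 5000 },
    ///         AddNodeProperty {
    ///             parallelism: 4,
    ///             is_streaming: true,
    ///             is_serving: true,
    ///             ..Default::default()
    ///         },
    ///         PbResource::default(),
    ///     )
    ///     .await?;
    /// ```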
    pub async fn add_worker(
        &self,
        r#type: PbWorkerType,
        host_address: HostAddress,
        property: AddNodeProperty,
        resource: PbResource,
    ) -> MetaResult<WorkerId> {
        let worker_id = self
            .inner
            .write()
            .await
            .add_worker(
                r#type,
                host_address,
                property,
                resource,
                self.max_heartbeat_interval,
            )
            .await?;

        if r#type == PbWorkerType::ComputeNode {
            self.update_compute_node_total_cpu_count().await?;
        }

        Ok(worker_id)
    }

    pub async fn activate_worker(&self, worker_id: WorkerId) -> MetaResult<()> {
        let inner = self.inner.write().await;
        let worker = inner.activate_worker(worker_id).await?;

        if worker.r#type() == PbWorkerType::ComputeNode || worker.r#type() == PbWorkerType::Frontend
        {
            self.env
                .notification_manager()
                .notify_frontend(Operation::Add, Info::Node(worker.clone()))
                .await;
        }
        self.env
            .notification_manager()
            .notify_local_subscribers(LocalNotification::WorkerNodeActivated(worker));

        Ok(())
    }

    pub async fn delete_worker(&self, host_address: HostAddress) -> MetaResult<WorkerNode> {
        let worker = self.inner.write().await.delete_worker(host_address).await?;

        if worker.r#type() == PbWorkerType::ComputeNode || worker.r#type() == PbWorkerType::Frontend
        {
            self.env
                .notification_manager()
                .notify_frontend(Operation::Delete, Info::Node(worker.clone()))
                .await;
            if worker.r#type() == PbWorkerType::ComputeNode {
                self.update_compute_node_total_cpu_count().await?;
            }
        }

        self.env
            .notification_manager()
            .notify_local_subscribers(LocalNotification::WorkerNodeDeleted(worker.clone()));

        Ok(worker)
    }

    pub async fn update_schedulability(
        &self,
        worker_ids: Vec<WorkerId>,
        schedulability: Schedulability,
    ) -> MetaResult<()> {
        self.inner
            .write()
            .await
            .update_schedulability(worker_ids, schedulability)
            .await
    }

    pub async fn heartbeat(&self, worker_id: WorkerId) -> MetaResult<()> {
        tracing::trace!(target: "events::meta::server_heartbeat", worker_id = worker_id, "receive heartbeat");
        self.inner
            .write()
            .await
            .heartbeat(worker_id, self.max_heartbeat_interval)
    }

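    /// Spawns a background task that periodically scans `worker_extra_info`
    /// and deletes workers whose heartbeat TTL has expired. Returns the task's
    /// `JoinHandle` and a oneshot `Sender` used to request shutdown.
    ///
    /// A hedged sketch of how a caller might wire it up (names are
    /// illustrative, not taken from this file):
    ///
    /// ```ignore
    /// let (join_handle, shutdown_tx) = ClusterController::start_heartbeat_checker(
    ///     cluster_controller.clone(),
    ///     Duration::from_secs(1),
    /// );
    /// // ... later, on meta node shutdown:
    /// let _ = shutdown_tx.send(());
    /// join_handle.await.unwrap();
    /// ```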
    pub fn start_heartbeat_checker(
        cluster_controller: ClusterControllerRef,
        check_interval: Duration,
    ) -> (JoinHandle<()>, Sender<()>) {
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            let mut min_interval = tokio::time::interval(check_interval);
            min_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            loop {
                tokio::select! {
                    _ = min_interval.tick() => {},
                    _ = &mut shutdown_rx => {
                        tracing::info!("Heartbeat checker is stopped");
                        return;
                    }
                }

                let mut inner = cluster_controller.inner.write().await;
                for worker in inner
                    .worker_extra_info
                    .values_mut()
                    .filter(|worker| worker.expire_at.is_none())
                {
                    worker.update_ttl(cluster_controller.max_heartbeat_interval);
                }

                let now = timestamp_now_sec();
                let worker_to_delete = inner
                    .worker_extra_info
                    .iter()
                    .filter(|(_, info)| info.expire_at.unwrap() < now)
                    .map(|(id, _)| *id)
                    .collect_vec();

                let worker_infos = match Worker::find()
                    .select_only()
                    .column(worker::Column::WorkerId)
                    .column(worker::Column::WorkerType)
                    .column(worker::Column::Host)
                    .column(worker::Column::Port)
                    .filter(worker::Column::WorkerId.is_in(worker_to_delete.clone()))
                    .into_tuple::<(WorkerId, WorkerType, String, i32)>()
                    .all(&inner.db)
                    .await
                {
                    Ok(keys) => keys,
                    Err(err) => {
                        tracing::warn!(error = %err.as_report(), "Failed to load expired worker info from db");
                        continue;
                    }
                };
                drop(inner);

                for (worker_id, worker_type, host, port) in worker_infos {
                    let host_addr = PbHostAddress { host, port };
                    match cluster_controller.delete_worker(host_addr.clone()).await {
                        Ok(_) => {
                            tracing::warn!(
                                worker_id,
                                ?host_addr,
                                %now,
                                "Deleted expired worker"
                            );
                            match worker_type {
                                WorkerType::Frontend
                                | WorkerType::ComputeNode
                                | WorkerType::Compactor
                                | WorkerType::RiseCtl => cluster_controller
                                    .env
                                    .notification_manager()
                                    .delete_sender(worker_type.into(), WorkerKey(host_addr)),
                                _ => {}
                            };
                        }
                        Err(err) => {
                            tracing::warn!(error = %err.as_report(), "Failed to delete expired worker from db");
                        }
                    }
                }
            }
        });

        (join_handle, shutdown_tx)
    }

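    /// Lists workers, optionally filtered by type and status. When no type
    /// filter is given, a synthetic entry for this meta node itself is
    /// prepended to the result.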
    pub async fn list_workers(
        &self,
        worker_type: Option<WorkerType>,
        worker_status: Option<WorkerStatus>,
    ) -> MetaResult<Vec<PbWorkerNode>> {
        let mut workers = vec![];
        if worker_type.is_none() {
            workers.push(meta_node_info(
                &self.env.opts.advertise_addr,
                Some(self.started_at),
            ));
        }
        workers.extend(
            self.inner
                .read()
                .await
                .list_workers(worker_type, worker_status)
                .await?,
        );
        Ok(workers)
    }

    pub(crate) async fn subscribe_active_streaming_compute_nodes(
        &self,
    ) -> MetaResult<(Vec<WorkerNode>, UnboundedReceiver<LocalNotification>)> {
        let inner = self.inner.read().await;
        let worker_nodes = inner.list_active_streaming_workers().await?;
        let (tx, rx) = unbounded_channel();

        self.env.notification_manager().insert_local_sender(tx);
        drop(inner);
        Ok((worker_nodes, rx))
    }

    pub async fn list_active_streaming_workers(&self) -> MetaResult<Vec<PbWorkerNode>> {
        self.inner
            .read()
            .await
            .list_active_streaming_workers()
            .await
    }

    pub async fn list_active_worker_slots(&self) -> MetaResult<Vec<WorkerSlotId>> {
        self.inner.read().await.list_active_worker_slots().await
    }

    pub async fn list_active_serving_workers(&self) -> MetaResult<Vec<PbWorkerNode>> {
        self.inner.read().await.list_active_serving_workers().await
    }

    pub async fn get_streaming_cluster_info(&self) -> MetaResult<StreamingClusterInfo> {
        self.inner.read().await.get_streaming_cluster_info().await
    }

    pub async fn get_worker_by_id(&self, worker_id: WorkerId) -> MetaResult<Option<PbWorkerNode>> {
        self.inner.read().await.get_worker_by_id(worker_id).await
    }

    pub async fn get_worker_info_by_id(&self, worker_id: WorkerId) -> Option<WorkerExtraInfo> {
        self.inner
            .read()
            .await
            .get_worker_extra_info_by_id(worker_id)
    }

    pub fn cluster_id(&self) -> &ClusterId {
        self.env.cluster_id()
    }

    pub fn meta_store_endpoint(&self) -> String {
        self.env.meta_store_ref().endpoint.clone()
    }
}

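/// Snapshot of the active streaming compute nodes, split into schedulable and
/// unschedulable worker id sets alongside the full `worker_nodes` map keyed by
/// worker id.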
#[derive(Debug, Clone)]
pub struct StreamingClusterInfo {
    pub worker_nodes: HashMap<u32, WorkerNode>,

    pub schedulable_workers: HashSet<u32>,

    pub unschedulable_workers: HashSet<u32>,
}

impl StreamingClusterInfo {
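    /// Total parallelism of the workers that belong to `resource_group`, i.e.
    /// the sum of `compute_node_parallelism()` over the matching nodes.
    ///
    /// Illustrative sketch, assuming the workers were registered under
    /// `DEFAULT_RESOURCE_GROUP`:
    ///
    /// ```ignore
    /// let info = cluster_ctl.get_streaming_cluster_info().await?;
    /// // Sum of per-worker parallelism within the given resource group.
    /// let total = info.parallelism(DEFAULT_RESOURCE_GROUP);
    /// ```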
    pub fn parallelism(&self, resource_group: &str) -> usize {
        let available_worker_ids =
            filter_workers_by_resource_group(&self.worker_nodes, resource_group);

        self.worker_nodes
            .values()
            .filter(|worker| available_worker_ids.contains(&(worker.id as WorkerId)))
            .map(|worker| worker.compute_node_parallelism())
            .sum()
    }

    pub fn filter_schedulable_workers_by_resource_group(
        &self,
        resource_group: &str,
    ) -> HashMap<u32, WorkerNode> {
        let worker_ids = filter_workers_by_resource_group(&self.worker_nodes, resource_group);
        self.worker_nodes
            .iter()
            .filter(|(id, _)| worker_ids.contains(&(**id as WorkerId)))
            .map(|(id, worker)| (*id, worker.clone()))
            .collect()
    }
}

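/// Volatile, meta-local bookkeeping for a worker that is not persisted in the
/// database: heartbeat expiry (`expire_at`) and start time in epoch seconds,
/// plus the reported resources and worker type.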
#[derive(Default, Clone)]
pub struct WorkerExtraInfo {
    expire_at: Option<u64>,
    started_at: Option<u64>,
    resource: PbResource,
    r#type: PbWorkerType,
}

impl WorkerExtraInfo {
    fn update_ttl(&mut self, ttl: Duration) {
        let expire = cmp::max(
            self.expire_at.unwrap_or_default(),
            SystemTime::now()
                .add(ttl)
                .duration_since(SystemTime::UNIX_EPOCH)
                .expect("Clock may have gone backwards")
                .as_secs(),
        );
        self.expire_at = Some(expire);
    }

    fn update_started_at(&mut self) {
        self.started_at = Some(timestamp_now_sec());
    }
}

fn timestamp_now_sec() -> u64 {
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("Clock may have gone backwards")
        .as_secs()
}

fn meta_node_info(host: &str, started_at: Option<u64>) -> PbWorkerNode {
    PbWorkerNode {
        id: META_NODE_ID,
        r#type: PbWorkerType::Meta.into(),
        host: HostAddr::try_from(host)
            .as_ref()
            .map(HostAddr::to_protobuf)
            .ok(),
        state: PbState::Running as _,
        property: None,
        transactional_id: None,
        resource: Some(risingwave_pb::common::worker_node::Resource {
            rw_version: RW_VERSION.to_owned(),
            total_memory_bytes: system_memory_available_bytes() as _,
            total_cpu_cores: total_cpu_available() as _,
        }),
        started_at,
    }
}

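/// The SQL-backed core of the cluster controller. It owns the database
/// connection, the pool of reusable transactional ids, and the per-worker
/// [`WorkerExtraInfo`] map.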
pub struct ClusterControllerInner {
    db: DatabaseConnection,
    available_transactional_ids: VecDeque<TransactionId>,
    worker_extra_info: HashMap<WorkerId, WorkerExtraInfo>,
    disable_automatic_parallelism_control: bool,
}

impl ClusterControllerInner {
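    /// Transactional ids are drawn from a fixed pool of `1 << 10` (1024)
    /// reusable ids; ids already held by registered workers are excluded when
    /// the controller starts (see [`Self::new`]).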
    pub const MAX_WORKER_REUSABLE_ID_BITS: usize = 10;
    pub const MAX_WORKER_REUSABLE_ID_COUNT: usize = 1 << Self::MAX_WORKER_REUSABLE_ID_BITS;

    pub async fn new(
        db: DatabaseConnection,
        disable_automatic_parallelism_control: bool,
    ) -> MetaResult<Self> {
        let workers: Vec<(WorkerId, Option<TransactionId>)> = Worker::find()
            .select_only()
            .column(worker::Column::WorkerId)
            .column(worker::Column::TransactionId)
            .into_tuple()
            .all(&db)
            .await?;
        let inuse_txn_ids: HashSet<_> = workers
            .iter()
            .cloned()
            .filter_map(|(_, txn_id)| txn_id)
            .collect();
        let available_transactional_ids = (0..Self::MAX_WORKER_REUSABLE_ID_COUNT as TransactionId)
            .filter(|id| !inuse_txn_ids.contains(id))
            .collect();

        let worker_extra_info = workers
            .into_iter()
            .map(|(w, _)| (w, WorkerExtraInfo::default()))
            .collect();

        Ok(Self {
            db,
            available_transactional_ids,
            worker_extra_info,
            disable_automatic_parallelism_control,
        })
    }

    pub async fn count_worker_by_type(&self) -> MetaResult<HashMap<WorkerType, i64>> {
        let workers: Vec<(WorkerType, i64)> = Worker::find()
            .select_only()
            .column(worker::Column::WorkerType)
            .column_as(worker::Column::WorkerId.count(), "count")
            .group_by(worker::Column::WorkerType)
            .into_tuple()
            .all(&self.db)
            .await?;

        Ok(workers.into_iter().collect())
    }

    pub fn update_worker_ttl(&mut self, worker_id: WorkerId, ttl: Duration) -> MetaResult<()> {
        if let Some(info) = self.worker_extra_info.get_mut(&worker_id) {
            let expire = cmp::max(
                info.expire_at.unwrap_or_default(),
                SystemTime::now()
                    .add(ttl)
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .expect("Clock may have gone backwards")
                    .as_secs(),
            );
            info.expire_at = Some(expire);
            Ok(())
        } else {
            Err(MetaError::invalid_worker(worker_id, "worker not found"))
        }
    }

    fn update_resource_and_started_at(
        &mut self,
        worker_id: WorkerId,
        resource: PbResource,
    ) -> MetaResult<()> {
        if let Some(info) = self.worker_extra_info.get_mut(&worker_id) {
            info.resource = resource;
            info.update_started_at();
            Ok(())
        } else {
            Err(MetaError::invalid_worker(worker_id, "worker not found"))
        }
    }

    fn get_extra_info_checked(&self, worker_id: WorkerId) -> MetaResult<WorkerExtraInfo> {
        self.worker_extra_info
            .get(&worker_id)
            .cloned()
            .ok_or_else(|| MetaError::invalid_worker(worker_id, "worker not found"))
    }

    fn apply_transaction_id(&self, r#type: PbWorkerType) -> MetaResult<Option<TransactionId>> {
        match (self.available_transactional_ids.front(), r#type) {
            (None, _) => Err(MetaError::unavailable("no available reusable machine id")),
            (Some(id), PbWorkerType::ComputeNode | PbWorkerType::Frontend) => Ok(Some(*id)),
            _ => Ok(None),
        }
    }

    fn compute_node_total_cpu_count(&self) -> usize {
        self.worker_extra_info
            .values()
            .filter(|info| info.r#type == PbWorkerType::ComputeNode)
            .map(|info| info.resource.total_cpu_cores as usize)
            .sum()
    }

    pub async fn add_worker(
        &mut self,
        r#type: PbWorkerType,
        host_address: HostAddress,
        add_property: AddNodeProperty,
        resource: PbResource,
        ttl: Duration,
    ) -> MetaResult<WorkerId> {
        let txn = self.db.begin().await?;

        let worker = Worker::find()
            .filter(
                worker::Column::Host
                    .eq(host_address.host.clone())
                    .and(worker::Column::Port.eq(host_address.port)),
            )
            .find_also_related(WorkerProperty)
            .one(&txn)
            .await?;
        if let Some((worker, property)) = worker {
            assert_eq!(worker.worker_type, r#type.into());
            return if worker.worker_type == WorkerType::ComputeNode {
                let property = property.unwrap();
                let mut current_parallelism = property.parallelism as usize;
                let new_parallelism = add_property.parallelism as usize;
                match new_parallelism.cmp(&current_parallelism) {
                    Ordering::Less => {
                        if !self.disable_automatic_parallelism_control {
                            tracing::info!(
                                "worker {} parallelism reduced from {} to {}",
                                worker.worker_id,
                                current_parallelism,
                                new_parallelism
                            );
                            current_parallelism = new_parallelism;
                        } else {
                            tracing::warn!(
                                "worker {} parallelism is less than current, current is {}, but received {}",
                                worker.worker_id,
                                current_parallelism,
                                new_parallelism
                            );
                        }
                    }
                    Ordering::Greater => {
                        tracing::info!(
                            "worker {} parallelism updated from {} to {}",
                            worker.worker_id,
                            current_parallelism,
                            new_parallelism
                        );
                        current_parallelism = new_parallelism;
                    }
                    Ordering::Equal => {}
                }
                let mut property: worker_property::ActiveModel = property.into();

                property.is_streaming = Set(add_property.is_streaming);
                property.is_serving = Set(add_property.is_serving);
                property.parallelism = Set(current_parallelism as _);
                property.resource_group =
                    Set(Some(add_property.resource_group.unwrap_or_else(|| {
                        tracing::warn!(
                            "resource_group is not set for worker {}, fallback to `default`",
                            worker.worker_id
                        );
                        DEFAULT_RESOURCE_GROUP.to_owned()
                    })));

                WorkerProperty::update(property).exec(&txn).await?;
                txn.commit().await?;
                self.update_worker_ttl(worker.worker_id, ttl)?;
                self.update_resource_and_started_at(worker.worker_id, resource)?;
                Ok(worker.worker_id)
            } else if worker.worker_type == WorkerType::Frontend && property.is_none() {
                let worker_property = worker_property::ActiveModel {
                    worker_id: Set(worker.worker_id),
                    parallelism: Set(add_property
                        .parallelism
                        .try_into()
                        .expect("invalid parallelism")),
                    is_streaming: Set(add_property.is_streaming),
                    is_serving: Set(add_property.is_serving),
                    is_unschedulable: Set(add_property.is_unschedulable),
                    internal_rpc_host_addr: Set(Some(add_property.internal_rpc_host_addr)),
                    resource_group: Set(None),
                };
                WorkerProperty::insert(worker_property).exec(&txn).await?;
                txn.commit().await?;
                self.update_worker_ttl(worker.worker_id, ttl)?;
                self.update_resource_and_started_at(worker.worker_id, resource)?;
                Ok(worker.worker_id)
            } else {
                self.update_worker_ttl(worker.worker_id, ttl)?;
                self.update_resource_and_started_at(worker.worker_id, resource)?;
                Ok(worker.worker_id)
            };
        }
        let txn_id = self.apply_transaction_id(r#type)?;

        let worker = worker::ActiveModel {
            worker_id: Default::default(),
            worker_type: Set(r#type.into()),
            host: Set(host_address.host),
            port: Set(host_address.port),
            status: Set(WorkerStatus::Starting),
            transaction_id: Set(txn_id),
        };
        let insert_res = Worker::insert(worker).exec(&txn).await?;
        let worker_id = insert_res.last_insert_id as WorkerId;
        if r#type == PbWorkerType::ComputeNode || r#type == PbWorkerType::Frontend {
            let property = worker_property::ActiveModel {
                worker_id: Set(worker_id),
                parallelism: Set(add_property
                    .parallelism
                    .try_into()
                    .expect("invalid parallelism")),
                is_streaming: Set(add_property.is_streaming),
                is_serving: Set(add_property.is_serving),
                is_unschedulable: Set(add_property.is_unschedulable),
                internal_rpc_host_addr: Set(Some(add_property.internal_rpc_host_addr)),
                resource_group: if r#type == PbWorkerType::ComputeNode {
                    Set(add_property.resource_group.clone())
                } else {
                    Set(None)
                },
            };
            WorkerProperty::insert(property).exec(&txn).await?;
        }

        txn.commit().await?;
        if let Some(txn_id) = txn_id {
            self.available_transactional_ids.retain(|id| *id != txn_id);
        }
        let extra_info = WorkerExtraInfo {
            started_at: Some(timestamp_now_sec()),
            expire_at: None,
            resource,
            r#type,
        };
        self.worker_extra_info.insert(worker_id, extra_info);

        Ok(worker_id)
    }

    pub async fn activate_worker(&self, worker_id: WorkerId) -> MetaResult<PbWorkerNode> {
        let worker = worker::ActiveModel {
            worker_id: Set(worker_id),
            status: Set(WorkerStatus::Running),
            ..Default::default()
        };

        let worker = worker.update(&self.db).await?;
        let worker_property = WorkerProperty::find_by_id(worker.worker_id)
            .one(&self.db)
            .await?;
        let extra_info = self.get_extra_info_checked(worker_id)?;
        Ok(WorkerInfo(worker, worker_property, extra_info).into())
    }

    pub async fn update_schedulability(
        &self,
        worker_ids: Vec<WorkerId>,
        schedulability: Schedulability,
    ) -> MetaResult<()> {
        let is_unschedulable = schedulability == Schedulability::Unschedulable;
        WorkerProperty::update_many()
            .col_expr(
                worker_property::Column::IsUnschedulable,
                Expr::value(is_unschedulable),
            )
            .filter(worker_property::Column::WorkerId.is_in(worker_ids))
            .exec(&self.db)
            .await?;

        Ok(())
    }

    pub async fn delete_worker(&mut self, host_addr: HostAddress) -> MetaResult<PbWorkerNode> {
        let worker = Worker::find()
            .filter(
                worker::Column::Host
                    .eq(host_addr.host)
                    .and(worker::Column::Port.eq(host_addr.port)),
            )
            .find_also_related(WorkerProperty)
            .one(&self.db)
            .await?;
        let Some((worker, property)) = worker else {
            return Err(MetaError::invalid_parameter("worker not found!"));
        };

        let res = Worker::delete_by_id(worker.worker_id)
            .exec(&self.db)
            .await?;
        if res.rows_affected == 0 {
            return Err(MetaError::invalid_parameter("worker not found!"));
        }

        let extra_info = self.worker_extra_info.remove(&worker.worker_id).unwrap();
        if let Some(txn_id) = &worker.transaction_id {
            self.available_transactional_ids.push_back(*txn_id);
        }
        let worker: PbWorkerNode = WorkerInfo(worker, property, extra_info).into();

        Ok(worker)
    }

    pub fn heartbeat(&mut self, worker_id: WorkerId, ttl: Duration) -> MetaResult<()> {
        if let Some(worker_info) = self.worker_extra_info.get_mut(&worker_id) {
            worker_info.update_ttl(ttl);
            Ok(())
        } else {
            Err(MetaError::invalid_worker(worker_id, "worker not found"))
        }
    }

    pub async fn list_workers(
        &self,
        worker_type: Option<WorkerType>,
        worker_status: Option<WorkerStatus>,
    ) -> MetaResult<Vec<PbWorkerNode>> {
        let mut find = Worker::find();
        if let Some(worker_type) = worker_type {
            find = find.filter(worker::Column::WorkerType.eq(worker_type));
        }
        if let Some(worker_status) = worker_status {
            find = find.filter(worker::Column::Status.eq(worker_status));
        }
        let workers = find.find_also_related(WorkerProperty).all(&self.db).await?;
        Ok(workers
            .into_iter()
            .map(|(worker, property)| {
                let extra_info = self.get_extra_info_checked(worker.worker_id).unwrap();
                WorkerInfo(worker, property, extra_info).into()
            })
            .collect_vec())
    }

    pub async fn list_active_streaming_workers(&self) -> MetaResult<Vec<PbWorkerNode>> {
        let workers = Worker::find()
            .filter(
                worker::Column::WorkerType
                    .eq(WorkerType::ComputeNode)
                    .and(worker::Column::Status.eq(WorkerStatus::Running)),
            )
            .inner_join(WorkerProperty)
            .select_also(WorkerProperty)
            .filter(worker_property::Column::IsStreaming.eq(true))
            .all(&self.db)
            .await?;

        Ok(workers
            .into_iter()
            .map(|(worker, property)| {
                let extra_info = self.get_extra_info_checked(worker.worker_id).unwrap();
                WorkerInfo(worker, property, extra_info).into()
            })
            .collect_vec())
    }

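    /// Expands each running worker into one [`WorkerSlotId`] per unit of
    /// parallelism.
    ///
    /// Illustrative sketch, mirroring the assertions in the test below
    /// (worker ids and parallelisms are made up):
    ///
    /// ```ignore
    /// // A running worker with id 7 and parallelism 3 contributes:
    /// // WorkerSlotId::new(7, 0), WorkerSlotId::new(7, 1), WorkerSlotId::new(7, 2)
    /// let slots = cluster_ctl.list_active_worker_slots().await?;
    /// assert!(slots.iter().all_unique());
    /// ```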
    pub async fn list_active_worker_slots(&self) -> MetaResult<Vec<WorkerSlotId>> {
        let worker_parallelisms: Vec<(WorkerId, i32)> = WorkerProperty::find()
            .select_only()
            .column(worker_property::Column::WorkerId)
            .column(worker_property::Column::Parallelism)
            .inner_join(Worker)
            .filter(worker::Column::Status.eq(WorkerStatus::Running))
            .into_tuple()
            .all(&self.db)
            .await?;
        Ok(worker_parallelisms
            .into_iter()
            .flat_map(|(worker_id, parallelism)| {
                (0..parallelism).map(move |idx| WorkerSlotId::new(worker_id as u32, idx as usize))
            })
            .collect_vec())
    }

    pub async fn list_active_serving_workers(&self) -> MetaResult<Vec<PbWorkerNode>> {
        let workers = Worker::find()
            .filter(
                worker::Column::WorkerType
                    .eq(WorkerType::ComputeNode)
                    .and(worker::Column::Status.eq(WorkerStatus::Running)),
            )
            .inner_join(WorkerProperty)
            .select_also(WorkerProperty)
            .filter(worker_property::Column::IsServing.eq(true))
            .all(&self.db)
            .await?;

        Ok(workers
            .into_iter()
            .map(|(worker, property)| {
                let extra_info = self.get_extra_info_checked(worker.worker_id).unwrap();
                WorkerInfo(worker, property, extra_info).into()
            })
            .collect_vec())
    }

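    /// Builds a [`StreamingClusterInfo`] snapshot by partitioning the active
    /// streaming compute nodes into schedulable and unschedulable sets based
    /// on the `is_unschedulable` property.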
    pub async fn get_streaming_cluster_info(&self) -> MetaResult<StreamingClusterInfo> {
        let mut streaming_workers = self.list_active_streaming_workers().await?;

        let unschedulable_workers: HashSet<_> = streaming_workers
            .extract_if(.., |worker| {
                worker.property.as_ref().is_some_and(|p| p.is_unschedulable)
            })
            .map(|w| w.id)
            .collect();

        let schedulable_workers = streaming_workers
            .iter()
            .map(|worker| worker.id)
            .filter(|id| !unschedulable_workers.contains(id))
            .collect();

        let active_workers: HashMap<_, _> =
            streaming_workers.into_iter().map(|w| (w.id, w)).collect();

        Ok(StreamingClusterInfo {
            worker_nodes: active_workers,
            schedulable_workers,
            unschedulable_workers,
        })
    }

    pub async fn get_worker_by_id(&self, worker_id: WorkerId) -> MetaResult<Option<PbWorkerNode>> {
        let worker = Worker::find_by_id(worker_id)
            .find_also_related(WorkerProperty)
            .one(&self.db)
            .await?;
        if worker.is_none() {
            return Ok(None);
        }
        let extra_info = self.get_extra_info_checked(worker_id)?;
        Ok(worker.map(|(w, p)| WorkerInfo(w, p, extra_info).into()))
    }

    pub fn get_worker_extra_info_by_id(&self, worker_id: WorkerId) -> Option<WorkerExtraInfo> {
        self.worker_extra_info.get(&worker_id).cloned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn mock_worker_hosts_for_test(count: usize) -> Vec<HostAddress> {
        (0..count)
            .map(|i| HostAddress {
                host: "localhost".to_owned(),
                port: 5000 + i as i32,
            })
            .collect_vec()
    }

    #[tokio::test]
    async fn test_cluster_controller() -> MetaResult<()> {
        let env = MetaSrvEnv::for_test().await;
        let cluster_ctl = ClusterController::new(env, Duration::from_secs(1)).await?;

        let parallelism_num = 4_usize;
        let worker_count = 5_usize;
        let property = AddNodeProperty {
            parallelism: parallelism_num as _,
            is_streaming: true,
            is_serving: true,
            is_unschedulable: false,
            ..Default::default()
        };
        let hosts = mock_worker_hosts_for_test(worker_count);
        let mut worker_ids = vec![];
        for host in &hosts {
            worker_ids.push(
                cluster_ctl
                    .add_worker(
                        PbWorkerType::ComputeNode,
                        host.clone(),
                        property.clone(),
                        PbResource::default(),
                    )
                    .await?,
            );
        }

        assert_eq!(cluster_ctl.list_active_worker_slots().await?.len(), 0);

        for id in &worker_ids {
            cluster_ctl.activate_worker(*id).await?;
        }
        let worker_cnt_map = cluster_ctl.count_worker_by_type().await?;
        assert_eq!(
            *worker_cnt_map.get(&WorkerType::ComputeNode).unwrap() as usize,
            worker_count
        );
        assert_eq!(
            cluster_ctl.list_active_streaming_workers().await?.len(),
            worker_count
        );
        assert_eq!(
            cluster_ctl.list_active_serving_workers().await?.len(),
            worker_count
        );
        assert_eq!(
            cluster_ctl.list_active_worker_slots().await?.len(),
            parallelism_num * worker_count
        );

        let mut new_property = property.clone();
        new_property.parallelism = (parallelism_num * 2) as _;
        new_property.is_serving = false;
        cluster_ctl
            .add_worker(
                PbWorkerType::ComputeNode,
                hosts[0].clone(),
                new_property,
                PbResource::default(),
            )
            .await?;

        assert_eq!(
            cluster_ctl.list_active_streaming_workers().await?.len(),
            worker_count
        );
        assert_eq!(
            cluster_ctl.list_active_serving_workers().await?.len(),
            worker_count - 1
        );
        let worker_slots = cluster_ctl.list_active_worker_slots().await?;
        assert!(worker_slots.iter().all_unique());
        assert_eq!(worker_slots.len(), parallelism_num * (worker_count + 1));

        for host in hosts {
            cluster_ctl.delete_worker(host).await?;
        }
        assert_eq!(cluster_ctl.list_active_streaming_workers().await?.len(), 0);
        assert_eq!(cluster_ctl.list_active_serving_workers().await?.len(), 0);
        assert_eq!(cluster_ctl.list_active_worker_slots().await?.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_update_schedulability() -> MetaResult<()> {
        let env = MetaSrvEnv::for_test().await;
        let cluster_ctl = ClusterController::new(env, Duration::from_secs(1)).await?;

        let host = HostAddress {
            host: "localhost".to_owned(),
            port: 5001,
        };
        let mut property = AddNodeProperty {
            is_streaming: true,
            is_serving: true,
            is_unschedulable: false,
            parallelism: 4,
            ..Default::default()
        };
        let worker_id = cluster_ctl
            .add_worker(
                PbWorkerType::ComputeNode,
                host.clone(),
                property.clone(),
                PbResource::default(),
            )
            .await?;

        cluster_ctl.activate_worker(worker_id).await?;
        cluster_ctl
            .update_schedulability(vec![worker_id], Schedulability::Unschedulable)
            .await?;

        let workers = cluster_ctl.list_active_streaming_workers().await?;
        assert_eq!(workers.len(), 1);
        assert!(workers[0].property.as_ref().unwrap().is_unschedulable);

        property.is_unschedulable = false;
        property.is_serving = false;
        let new_worker_id = cluster_ctl
            .add_worker(
                PbWorkerType::ComputeNode,
                host.clone(),
                property,
                PbResource::default(),
            )
            .await?;
        assert_eq!(worker_id, new_worker_id);

        let workers = cluster_ctl.list_active_streaming_workers().await?;
        assert_eq!(workers.len(), 1);
        assert!(workers[0].property.as_ref().unwrap().is_unschedulable);

        cluster_ctl.delete_worker(host).await?;

        Ok(())
    }
}