risingwave_hummock_test/
mock_notification_client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use risingwave_common::util::addr::HostAddr;
19use risingwave_common_service::{Channel, NotificationClient, ObserverError};
20use risingwave_meta::controller::cluster::ClusterControllerRef;
21use risingwave_meta::hummock::{HummockManager, HummockManagerRef};
22use risingwave_meta::manager::{MessageStatus, MetaSrvEnv, NotificationManagerRef, WorkerKey};
23use risingwave_pb::backup_service::MetaBackupManifestId;
24use risingwave_pb::hummock::WriteLimits;
25use risingwave_pb::meta::{MetaSnapshot, SubscribeResponse, SubscribeType};
26use tokio::sync::mpsc::UnboundedReceiver;
27
28pub struct MockNotificationClient {
29    addr: HostAddr,
30    notification_manager: NotificationManagerRef,
31    hummock_manager: HummockManagerRef,
32}
33
34impl MockNotificationClient {
35    pub fn new(
36        addr: HostAddr,
37        notification_manager: NotificationManagerRef,
38        hummock_manager: HummockManagerRef,
39    ) -> Self {
40        Self {
41            addr,
42            notification_manager,
43            hummock_manager,
44        }
45    }
46}
47
48#[async_trait::async_trait]
49impl NotificationClient for MockNotificationClient {
50    type Channel = TestChannel<SubscribeResponse>;
51
52    async fn subscribe(
53        &self,
54        subscribe_type: SubscribeType,
55    ) -> Result<Self::Channel, ObserverError> {
56        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
57
58        let worker_key = WorkerKey(self.addr.to_protobuf());
59        self.notification_manager
60            .insert_sender(subscribe_type, worker_key.clone(), tx.clone());
61
62        let hummock_version = self.hummock_manager.get_current_version().await;
63        let meta_snapshot = MetaSnapshot {
64            hummock_version: Some(hummock_version.into()),
65            version: Some(Default::default()),
66            meta_backup_manifest_id: Some(MetaBackupManifestId { id: 0 }),
67            hummock_write_limits: Some(WriteLimits {
68                write_limits: HashMap::new(),
69            }),
70            ..Default::default()
71        };
72
73        self.notification_manager
74            .notify_snapshot(worker_key, subscribe_type, meta_snapshot);
75
76        Ok(TestChannel(rx))
77    }
78}
79
80pub async fn get_notification_client_for_test(
81    env: MetaSrvEnv,
82    hummock_manager_ref: Arc<HummockManager>,
83    cluster_controller_ref: ClusterControllerRef,
84    worker_id: i32,
85) -> MockNotificationClient {
86    let worker_node = cluster_controller_ref
87        .get_worker_by_id(worker_id)
88        .await
89        .unwrap()
90        .unwrap();
91
92    MockNotificationClient::new(
93        worker_node.get_host().unwrap().into(),
94        env.notification_manager_ref(),
95        hummock_manager_ref,
96    )
97}
98
99pub struct TestChannel<T>(UnboundedReceiver<Result<T, MessageStatus>>);
100
101#[async_trait::async_trait]
102impl<T: Send + 'static> Channel for TestChannel<T> {
103    type Item = T;
104
105    async fn message(&mut self) -> Result<Option<T>, MessageStatus> {
106        match self.0.recv().await {
107            None => Ok(None),
108            Some(result) => result.map(|r| Some(r)),
109        }
110    }
111}