risingwave_hummock_test/
mock_notification_client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use risingwave_common::util::addr::HostAddr;
19use risingwave_common_service::{Channel, NotificationClient, ObserverError};
20use risingwave_meta::controller::cluster::ClusterControllerRef;
21use risingwave_meta::hummock::{HummockManager, HummockManagerRef};
22use risingwave_meta::manager::{MessageStatus, MetaSrvEnv, NotificationManagerRef, WorkerKey};
23use risingwave_pb::backup_service::MetaBackupManifestId;
24use risingwave_pb::hummock::WriteLimits;
25use risingwave_pb::meta::{MetaSnapshot, SubscribeResponse, SubscribeType};
26use tokio::sync::mpsc::UnboundedReceiver;
27
28pub struct MockNotificationClient {
29    addr: HostAddr,
30    notification_manager: NotificationManagerRef,
31    hummock_manager: HummockManagerRef,
32}
33
34impl MockNotificationClient {
35    pub fn new(
36        addr: HostAddr,
37        notification_manager: NotificationManagerRef,
38        hummock_manager: HummockManagerRef,
39    ) -> Self {
40        Self {
41            addr,
42            notification_manager,
43            hummock_manager,
44        }
45    }
46}
47
48#[async_trait::async_trait]
49impl NotificationClient for MockNotificationClient {
50    type Channel = TestChannel<SubscribeResponse>;
51
52    async fn subscribe(
53        &self,
54        subscribe_type: SubscribeType,
55    ) -> Result<Self::Channel, ObserverError> {
56        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
57
58        let worker_key = WorkerKey(self.addr.to_protobuf());
59        self.notification_manager
60            .insert_sender(subscribe_type, worker_key.clone(), tx.clone())
61            .await;
62
63        let hummock_version = self.hummock_manager.get_current_version().await;
64        let meta_snapshot = MetaSnapshot {
65            hummock_version: Some(hummock_version.into()),
66            version: Some(Default::default()),
67            meta_backup_manifest_id: Some(MetaBackupManifestId { id: 0 }),
68            hummock_write_limits: Some(WriteLimits {
69                write_limits: HashMap::new(),
70            }),
71            ..Default::default()
72        };
73
74        self.notification_manager
75            .notify_snapshot(worker_key, subscribe_type, meta_snapshot);
76
77        Ok(TestChannel(rx))
78    }
79}
80
81pub async fn get_notification_client_for_test(
82    env: MetaSrvEnv,
83    hummock_manager_ref: Arc<HummockManager>,
84    cluster_controller_ref: ClusterControllerRef,
85    worker_id: i32,
86) -> MockNotificationClient {
87    let worker_node = cluster_controller_ref
88        .get_worker_by_id(worker_id)
89        .await
90        .unwrap()
91        .unwrap();
92
93    MockNotificationClient::new(
94        worker_node.get_host().unwrap().into(),
95        env.notification_manager_ref(),
96        hummock_manager_ref,
97    )
98}
99
100pub struct TestChannel<T>(UnboundedReceiver<Result<T, MessageStatus>>);
101
102#[async_trait::async_trait]
103impl<T: Send + 'static> Channel for TestChannel<T> {
104    type Item = T;
105
106    async fn message(&mut self) -> Result<Option<T>, MessageStatus> {
107        match self.0.recv().await {
108            None => Ok(None),
109            Some(result) => result.map(|r| Some(r)),
110        }
111    }
112}