risingwave_hummock_test/
mock_notification_client.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use risingwave_common::id::WorkerId;
19use risingwave_common::util::addr::HostAddr;
20use risingwave_common_service::{Channel, NotificationClient, ObserverError};
21use risingwave_meta::controller::cluster::ClusterControllerRef;
22use risingwave_meta::hummock::{HummockManager, HummockManagerRef};
23use risingwave_meta::manager::{MessageStatus, MetaSrvEnv, NotificationManagerRef, WorkerKey};
24use risingwave_pb::backup_service::MetaBackupManifestId;
25use risingwave_pb::hummock::WriteLimits;
26use risingwave_pb::meta::{MetaSnapshot, SubscribeResponse, SubscribeType};
27use tokio::sync::mpsc::UnboundedReceiver;
28
29pub struct MockNotificationClient {
30    addr: HostAddr,
31    notification_manager: NotificationManagerRef,
32    hummock_manager: HummockManagerRef,
33}
34
35impl MockNotificationClient {
36    pub fn new(
37        addr: HostAddr,
38        notification_manager: NotificationManagerRef,
39        hummock_manager: HummockManagerRef,
40    ) -> Self {
41        Self {
42            addr,
43            notification_manager,
44            hummock_manager,
45        }
46    }
47}
48
49#[async_trait::async_trait]
50impl NotificationClient for MockNotificationClient {
51    type Channel = TestChannel<SubscribeResponse>;
52
53    async fn subscribe(
54        &self,
55        subscribe_type: SubscribeType,
56    ) -> Result<Self::Channel, ObserverError> {
57        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
58
59        let worker_key = WorkerKey(self.addr.to_protobuf());
60        self.notification_manager
61            .insert_sender(subscribe_type, worker_key.clone(), tx.clone());
62
63        let hummock_version = self.hummock_manager.get_current_version().await;
64        let meta_snapshot = MetaSnapshot {
65            hummock_version: Some(hummock_version.into()),
66            version: Some(Default::default()),
67            meta_backup_manifest_id: Some(MetaBackupManifestId { id: 0 }),
68            hummock_write_limits: Some(WriteLimits {
69                write_limits: HashMap::new(),
70            }),
71            cluster_resource: Some(Default::default()),
72            ..Default::default()
73        };
74
75        self.notification_manager
76            .notify_snapshot(worker_key, subscribe_type, meta_snapshot);
77
78        Ok(TestChannel(rx))
79    }
80}
81
82pub async fn get_notification_client_for_test(
83    env: MetaSrvEnv,
84    hummock_manager_ref: Arc<HummockManager>,
85    cluster_controller_ref: ClusterControllerRef,
86    worker_id: WorkerId,
87) -> MockNotificationClient {
88    let worker_node = cluster_controller_ref
89        .get_worker_by_id(worker_id)
90        .await
91        .unwrap()
92        .unwrap();
93
94    MockNotificationClient::new(
95        worker_node.get_host().unwrap().into(),
96        env.notification_manager_ref(),
97        hummock_manager_ref,
98    )
99}
100
101pub struct TestChannel<T>(UnboundedReceiver<Result<T, MessageStatus>>);
102
103#[async_trait::async_trait]
104impl<T: Send + 'static> Channel for TestChannel<T> {
105    type Item = T;
106
107    async fn message(&mut self) -> Result<Option<T>, MessageStatus> {
108        match self.0.recv().await {
109            None => Ok(None),
110            Some(result) => result.map(|r| Some(r)),
111        }
112    }
113}