risingwave_connector/source/mqtt/enumerator/mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use async_trait::async_trait;
20use risingwave_common::bail;
21use rumqttc::Outgoing;
22use rumqttc::v5::{ConnectionError, Event, Incoming};
23use thiserror_ext::AsReport;
24use tokio::sync::RwLock;
25
26use super::MqttProperties;
27use super::source::MqttSplit;
28use crate::error::ConnectorResult;
29use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
30
31pub struct MqttSplitEnumerator {
32    #[expect(dead_code)]
33    topic: String,
34    #[expect(dead_code)]
35    client: rumqttc::v5::AsyncClient,
36    topics: Arc<RwLock<HashSet<String>>>,
37    connected: Arc<AtomicBool>,
38    stopped: Arc<AtomicBool>,
39}
40
41#[async_trait]
42impl SplitEnumerator for MqttSplitEnumerator {
43    type Properties = MqttProperties;
44    type Split = MqttSplit;
45
46    async fn new(
47        properties: Self::Properties,
48        context: SourceEnumeratorContextRef,
49    ) -> ConnectorResult<MqttSplitEnumerator> {
50        let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
51
52        let topic = properties.topic.clone();
53        let mut topics = HashSet::new();
54        if !topic.contains('#') && !topic.contains('+') {
55            topics.insert(topic.clone());
56        }
57
58        client
59            .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
60            .await?;
61
62        let topics = Arc::new(RwLock::new(topics));
63
64        let connected = Arc::new(AtomicBool::new(false));
65        let connected_clone = connected.clone();
66
67        let stopped = Arc::new(AtomicBool::new(false));
68        let stopped_clone = stopped.clone();
69
70        let topics_clone = topics.clone();
71        tokio::spawn(async move {
72            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
73                match eventloop.poll().await {
74                    Ok(Event::Outgoing(Outgoing::Subscribe(_))) => {
75                        connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
76                    }
77                    Ok(Event::Incoming(Incoming::Publish(p))) => {
78                        let topic = String::from_utf8_lossy(&p.topic).to_string();
79                        let exist = {
80                            let topics = topics_clone.read().await;
81                            topics.contains(&topic)
82                        };
83
84                        if !exist {
85                            let mut topics = topics_clone.write().await;
86                            topics.insert(topic);
87                        }
88                    }
89                    Ok(_)
90                    | Err(ConnectionError::Timeout(_))
91                    | Err(ConnectionError::RequestsDone) => {}
92                    Err(err) => {
93                        tracing::error!(
94                            "Failed to fetch splits to topic {}: {}",
95                            topic,
96                            err.as_report(),
97                        );
98                        tokio::time::sleep(std::time::Duration::from_millis(500)).await
99                    }
100                }
101            }
102        });
103
104        Ok(Self {
105            client,
106            topics,
107            topic: properties.topic,
108            connected,
109            stopped,
110        })
111    }
112
113    async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
114        if !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
115            let start = std::time::Instant::now();
116            loop {
117                if self.connected.load(std::sync::atomic::Ordering::Relaxed) {
118                    break;
119                }
120
121                if start.elapsed().as_secs() > 10 {
122                    bail!("Failed to connect to mqtt broker");
123                }
124
125                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
126            }
127        }
128
129        let topics = self.topics.read().await;
130        Ok(topics.iter().cloned().map(MqttSplit::new).collect())
131    }
132}
133
134impl Drop for MqttSplitEnumerator {
135    fn drop(&mut self) {
136        self.stopped
137            .store(true, std::sync::atomic::Ordering::Relaxed);
138    }
139}