risingwave_connector/source/mqtt/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use async_trait::async_trait;
20use risingwave_common::bail;
21use rumqttc::Outgoing;
22use rumqttc::v5::{ConnectionError, Event, Incoming};
23use thiserror_ext::AsReport;
24use tokio::sync::RwLock;
25
26use super::MqttProperties;
27use super::source::MqttSplit;
28use crate::error::ConnectorResult;
29use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
30
31pub struct MqttSplitEnumerator {
32    #[expect(dead_code)]
33    topic: String,
34    #[expect(dead_code)]
35    client: rumqttc::v5::AsyncClient,
36    topics: Arc<RwLock<HashSet<String>>>,
37    connected: Arc<AtomicBool>,
38    stopped: Arc<AtomicBool>,
39}
40
41#[async_trait]
42impl SplitEnumerator for MqttSplitEnumerator {
43    type Properties = MqttProperties;
44    type Split = MqttSplit;
45
46    async fn new(
47        properties: Self::Properties,
48        context: SourceEnumeratorContextRef,
49    ) -> ConnectorResult<MqttSplitEnumerator> {
50        let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
51
52        let topic = properties.topic.clone();
53        let mut topics = HashSet::new();
54        if !topic.contains('#') && !topic.contains('+') {
55            topics.insert(topic.clone());
56        }
57
58        client
59            .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
60            .await?;
61
62        let cloned_client = client.clone();
63
64        let topics = Arc::new(RwLock::new(topics));
65
66        let connected = Arc::new(AtomicBool::new(false));
67        let connected_clone = connected.clone();
68
69        let stopped = Arc::new(AtomicBool::new(false));
70        let stopped_clone = stopped.clone();
71
72        let topics_clone = topics.clone();
73        tokio::spawn(async move {
74            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
75                match eventloop.poll().await {
76                    Ok(Event::Outgoing(Outgoing::Subscribe(_))) => {
77                        connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
78                    }
79                    Ok(Event::Incoming(Incoming::Publish(p))) => {
80                        let topic = String::from_utf8_lossy(&p.topic).to_string();
81                        let exist = {
82                            let topics = topics_clone.read().await;
83                            topics.contains(&topic)
84                        };
85
86                        if !exist {
87                            let mut topics = topics_clone.write().await;
88                            topics.insert(topic);
89                        }
90                    }
91                    Ok(_) => {}
92                    Err(err) => {
93                        if let ConnectionError::Timeout(_) = err {
94                            continue;
95                        }
96                        tracing::error!(
97                            "Failed to subscribe to topic {}: {}",
98                            topic,
99                            err.as_report(),
100                        );
101                        connected_clone.store(false, std::sync::atomic::Ordering::Relaxed);
102                        cloned_client
103                            .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
104                            .await
105                            .unwrap();
106                    }
107                }
108            }
109        });
110
111        Ok(Self {
112            client,
113            topics,
114            topic: properties.topic,
115            connected,
116            stopped,
117        })
118    }
119
120    async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
121        if !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
122            let start = std::time::Instant::now();
123            loop {
124                if self.connected.load(std::sync::atomic::Ordering::Relaxed) {
125                    break;
126                }
127
128                if start.elapsed().as_secs() > 10 {
129                    bail!("Failed to connect to mqtt broker");
130                }
131
132                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
133            }
134        }
135
136        let topics = self.topics.read().await;
137        Ok(topics.iter().cloned().map(MqttSplit::new).collect())
138    }
139}
140
141impl Drop for MqttSplitEnumerator {
142    fn drop(&mut self) {
143        self.stopped
144            .store(true, std::sync::atomic::Ordering::Relaxed);
145    }
146}