risingwave_connector/source/mqtt/enumerator/
mod.rs1use std::collections::HashSet;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use async_trait::async_trait;
20use risingwave_common::bail;
21use rumqttc::Outgoing;
22use rumqttc::v5::{ConnectionError, Event, Incoming};
23use thiserror_ext::AsReport;
24use tokio::sync::RwLock;
25
26use super::MqttProperties;
27use super::source::MqttSplit;
28use crate::error::ConnectorResult;
29use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
30
31pub struct MqttSplitEnumerator {
32 #[expect(dead_code)]
33 topic: String,
34 #[expect(dead_code)]
35 client: rumqttc::v5::AsyncClient,
36 topics: Arc<RwLock<HashSet<String>>>,
37 connected: Arc<AtomicBool>,
38 stopped: Arc<AtomicBool>,
39}
40
41#[async_trait]
42impl SplitEnumerator for MqttSplitEnumerator {
43 type Properties = MqttProperties;
44 type Split = MqttSplit;
45
46 async fn new(
47 properties: Self::Properties,
48 context: SourceEnumeratorContextRef,
49 ) -> ConnectorResult<MqttSplitEnumerator> {
50 let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
51
52 let topic = properties.topic.clone();
53 let mut topics = HashSet::new();
54 if !topic.contains('#') && !topic.contains('+') {
55 topics.insert(topic.clone());
56 }
57
58 client
59 .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
60 .await?;
61
62 let cloned_client = client.clone();
63
64 let topics = Arc::new(RwLock::new(topics));
65
66 let connected = Arc::new(AtomicBool::new(false));
67 let connected_clone = connected.clone();
68
69 let stopped = Arc::new(AtomicBool::new(false));
70 let stopped_clone = stopped.clone();
71
72 let topics_clone = topics.clone();
73 tokio::spawn(async move {
74 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
75 match eventloop.poll().await {
76 Ok(Event::Outgoing(Outgoing::Subscribe(_))) => {
77 connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
78 }
79 Ok(Event::Incoming(Incoming::Publish(p))) => {
80 let topic = String::from_utf8_lossy(&p.topic).to_string();
81 let exist = {
82 let topics = topics_clone.read().await;
83 topics.contains(&topic)
84 };
85
86 if !exist {
87 let mut topics = topics_clone.write().await;
88 topics.insert(topic);
89 }
90 }
91 Ok(_) => {}
92 Err(err) => {
93 if let ConnectionError::Timeout(_) = err {
94 continue;
95 }
96 tracing::error!(
97 "Failed to subscribe to topic {}: {}",
98 topic,
99 err.as_report(),
100 );
101 connected_clone.store(false, std::sync::atomic::Ordering::Relaxed);
102 cloned_client
103 .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
104 .await
105 .unwrap();
106 }
107 }
108 }
109 });
110
111 Ok(Self {
112 client,
113 topics,
114 topic: properties.topic,
115 connected,
116 stopped,
117 })
118 }
119
120 async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
121 if !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
122 let start = std::time::Instant::now();
123 loop {
124 if self.connected.load(std::sync::atomic::Ordering::Relaxed) {
125 break;
126 }
127
128 if start.elapsed().as_secs() > 10 {
129 bail!("Failed to connect to mqtt broker");
130 }
131
132 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
133 }
134 }
135
136 let topics = self.topics.read().await;
137 Ok(topics.iter().cloned().map(MqttSplit::new).collect())
138 }
139}
140
141impl Drop for MqttSplitEnumerator {
142 fn drop(&mut self) {
143 self.stopped
144 .store(true, std::sync::atomic::Ordering::Relaxed);
145 }
146}