risingwave_connector/source/mqtt/enumerator/
mod.rs1use std::collections::HashSet;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use async_trait::async_trait;
20use risingwave_common::bail;
21use rumqttc::Outgoing;
22use rumqttc::v5::{ConnectionError, Event, Incoming};
23use thiserror_ext::AsReport;
24use tokio::sync::RwLock;
25
26use super::MqttProperties;
27use super::source::MqttSplit;
28use crate::error::ConnectorResult;
29use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
30
31pub struct MqttSplitEnumerator {
32 #[expect(dead_code)]
33 topic: String,
34 #[expect(dead_code)]
35 client: rumqttc::v5::AsyncClient,
36 topics: Arc<RwLock<HashSet<String>>>,
37 connected: Arc<AtomicBool>,
38 stopped: Arc<AtomicBool>,
39}
40
41#[async_trait]
42impl SplitEnumerator for MqttSplitEnumerator {
43 type Properties = MqttProperties;
44 type Split = MqttSplit;
45
46 async fn new(
47 properties: Self::Properties,
48 context: SourceEnumeratorContextRef,
49 ) -> ConnectorResult<MqttSplitEnumerator> {
50 let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
51
52 let topic = properties.topic.clone();
53 let mut topics = HashSet::new();
54 if !topic.contains('#') && !topic.contains('+') {
55 topics.insert(topic.clone());
56 }
57
58 client
59 .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
60 .await?;
61
62 let topics = Arc::new(RwLock::new(topics));
63
64 let connected = Arc::new(AtomicBool::new(false));
65 let connected_clone = connected.clone();
66
67 let stopped = Arc::new(AtomicBool::new(false));
68 let stopped_clone = stopped.clone();
69
70 let topics_clone = topics.clone();
71 tokio::spawn(async move {
72 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
73 match eventloop.poll().await {
74 Ok(Event::Outgoing(Outgoing::Subscribe(_))) => {
75 connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
76 }
77 Ok(Event::Incoming(Incoming::Publish(p))) => {
78 let topic = String::from_utf8_lossy(&p.topic).to_string();
79 let exist = {
80 let topics = topics_clone.read().await;
81 topics.contains(&topic)
82 };
83
84 if !exist {
85 let mut topics = topics_clone.write().await;
86 topics.insert(topic);
87 }
88 }
89 Ok(_)
90 | Err(ConnectionError::Timeout(_))
91 | Err(ConnectionError::RequestsDone) => {}
92 Err(err) => {
93 tracing::error!(
94 "Failed to fetch splits to topic {}: {}",
95 topic,
96 err.as_report(),
97 );
98 tokio::time::sleep(std::time::Duration::from_millis(500)).await
99 }
100 }
101 }
102 });
103
104 Ok(Self {
105 client,
106 topics,
107 topic: properties.topic,
108 connected,
109 stopped,
110 })
111 }
112
113 async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
114 if !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
115 let start = std::time::Instant::now();
116 loop {
117 if self.connected.load(std::sync::atomic::Ordering::Relaxed) {
118 break;
119 }
120
121 if start.elapsed().as_secs() > 10 {
122 bail!("Failed to connect to mqtt broker");
123 }
124
125 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
126 }
127 }
128
129 let topics = self.topics.read().await;
130 Ok(topics.iter().cloned().map(MqttSplit::new).collect())
131 }
132}
133
134impl Drop for MqttSplitEnumerator {
135 fn drop(&mut self) {
136 self.stopped
137 .store(true, std::sync::atomic::Ordering::Relaxed);
138 }
139}