// risingwave_connector/source/mqtt/enumerator/mod.rs
use std::collections::HashSet;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use async_trait::async_trait;
use risingwave_common::bail;
use rumqttc::v5::{ConnectionError, Event, Incoming};
use rumqttc::Outgoing;
use thiserror_ext::AsReport;
use tokio::sync::RwLock;
use super::source::MqttSplit;
use super::MqttProperties;
use crate::error::ConnectorResult;
use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
/// Split enumerator for the MQTT source.
///
/// Concrete topics are collected by a background task spawned in `new`; every
/// collected topic is reported as one split by `list_splits`.
pub struct MqttSplitEnumerator {
/// The topic (possibly a wildcard filter) from the source properties; stored
/// but not read in this file.
#[expect(dead_code)]
topic: String,
/// Client handle returned by `build_client`; stored but not read in this file.
#[expect(dead_code)]
client: rumqttc::v5::AsyncClient,
/// Concrete topics discovered so far. Written by the background event-loop
/// task, read by `list_splits`.
topics: Arc<RwLock<HashSet<String>>>,
/// Set by the background task once a SUBSCRIBE packet has been sent; cleared
/// on a non-timeout connection error.
connected: Arc<AtomicBool>,
/// Set by `Drop` to ask the background task to leave its poll loop.
stopped: Arc<AtomicBool>,
}
#[async_trait]
impl SplitEnumerator for MqttSplitEnumerator {
    type Properties = MqttProperties;
    type Split = MqttSplit;

    /// Builds an MQTT client, subscribes to the configured topic, and spawns a
    /// background task that drives the client's event loop.
    ///
    /// The background task:
    /// - sets `connected` once a SUBSCRIBE packet has been sent;
    /// - records the topic of every incoming PUBLISH in the shared `topics` set
    ///   (this is how concrete topics are discovered for wildcard filters);
    /// - on a non-timeout connection error, clears `connected`, backs off
    ///   briefly, and queues a fresh subscription to the configured topic;
    /// - checks `stopped` before every poll and exits once it is set.
    ///
    /// # Errors
    /// Fails if the client cannot be built or the initial subscription request
    /// cannot be queued.
    async fn new(
        properties: Self::Properties,
        context: SourceEnumeratorContextRef,
    ) -> ConnectorResult<MqttSplitEnumerator> {
        // Pause after a non-timeout connection error so that a broker refusing
        // connections does not turn the poll loop into a busy spin.
        const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);

        let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
        let topic = properties.topic.clone();

        // A filter containing wildcards is not itself a concrete topic; concrete
        // topics are learned from incoming publishes instead.
        let mut topics = HashSet::new();
        if !topic.contains('#') && !topic.contains('+') {
            topics.insert(topic.clone());
        }

        client
            .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
            .await?;

        let cloned_client = client.clone();
        let topics = Arc::new(RwLock::new(topics));
        let connected = Arc::new(AtomicBool::new(false));
        let connected_clone = connected.clone();
        let stopped = Arc::new(AtomicBool::new(false));
        let stopped_clone = stopped.clone();
        let topics_clone = topics.clone();

        tokio::spawn(async move {
            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
                match eventloop.poll().await {
                    Ok(Event::Outgoing(Outgoing::Subscribe(_))) => {
                        connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
                    }
                    Ok(Event::Incoming(Incoming::Publish(p))) => {
                        let topic = String::from_utf8_lossy(&p.topic).into_owned();
                        // NOTE(review): a v5 publish relying on a topic alias may
                        // carry an empty topic name; never record "" as a split.
                        if topic.is_empty() {
                            continue;
                        }
                        // Check under the read lock first so the common case
                        // (topic already known) does not contend with readers.
                        let exist = topics_clone.read().await.contains(&topic);
                        if !exist {
                            topics_clone.write().await.insert(topic);
                        }
                    }
                    Ok(_) => {}
                    Err(ConnectionError::Timeout(_)) => continue,
                    Err(err) => {
                        tracing::error!(
                            "Failed to subscribe to topic {}: {}",
                            topic,
                            err.as_report(),
                        );
                        connected_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                        tokio::time::sleep(RECONNECT_BACKOFF).await;
                        // Use the non-blocking `try_subscribe`. The request channel is
                        // drained only by `eventloop.poll()` in this very task. An
                        // awaiting `subscribe` on a full channel would therefore
                        // deadlock, and `unwrap` would panic the task.
                        if let Err(e) = cloned_client
                            .try_subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce)
                        {
                            tracing::warn!(
                                "Failed to queue re-subscription to topic {}: {}",
                                topic,
                                e.as_report(),
                            );
                        }
                    }
                }
            }
        });

        Ok(Self {
            client,
            topics,
            topic: properties.topic,
            connected,
            stopped,
        })
    }

    /// Returns one split per topic discovered so far.
    ///
    /// First waits, polling every 500 ms, until the background task reports that
    /// the subscription has been sent.
    ///
    /// # Errors
    /// Fails with "Failed to connect to mqtt broker" if that does not happen
    /// within 10 seconds.
    async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
        const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
        const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(500);

        let start = std::time::Instant::now();
        while !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
            // Compare full `Duration`s: `as_secs()` truncates, which stretched
            // the effective timeout to about 11 seconds.
            if start.elapsed() > CONNECT_TIMEOUT {
                bail!("Failed to connect to mqtt broker");
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }

        let topics = self.topics.read().await;
        Ok(topics.iter().cloned().map(MqttSplit::new).collect())
    }
}
impl Drop for MqttSplitEnumerator {
    /// Signals the background event-loop task spawned in `new` to stop.
    ///
    /// The task checks this flag before each `eventloop.poll()`, so it exits on
    /// its next loop iteration rather than immediately.
    fn drop(&mut self) {
        let stop_signal: &AtomicBool = &self.stopped;
        stop_signal.store(true, std::sync::atomic::Ordering::Relaxed);
    }
}