risingwave_connector/source/mqtt/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicBool, Ordering};
16use std::sync::{Arc, LazyLock, Weak};
17
18use async_trait::async_trait;
19use moka::future::Cache as MokaCache;
20use moka::ops::compute::Op;
21use risingwave_common::bail;
22use rumqttc::v5::{AsyncClient, ConnectionError, Event, EventLoop, Incoming};
23use thiserror_ext::AsReport;
24
25use super::MqttProperties;
26use super::source::MqttSplit;
27use crate::error::{ConnectorError, ConnectorResult};
28use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
29
30/// Consumer client is shared, and the cache doesn't manage the lifecycle, so we store `Weak` and no eviction.
31static SHARED_MQTT_CLIENT: LazyLock<MokaCache<String, Weak<MqttConnectionCheck>>> =
32    LazyLock::new(|| moka::future::Cache::builder().build());
33
34pub struct MqttSplitEnumerator {
35    topic: String,
36    broker: String,
37    connection_check: Arc<MqttConnectionCheck>,
38}
39
40struct MqttConnectionCheck {
41    #[expect(dead_code)]
42    client: AsyncClient,
43    connected: Arc<AtomicBool>,
44    stopped: Arc<AtomicBool>,
45}
46
47impl MqttConnectionCheck {
48    fn new(client: AsyncClient, event_loop: EventLoop, topic: String) -> Self {
49        let this = Self {
50            client,
51            connected: Arc::new(AtomicBool::new(false)),
52            stopped: Arc::new(AtomicBool::new(false)),
53        };
54        this.spawn_client_loop(event_loop, topic);
55        this
56    }
57
58    fn is_connected(&self) -> bool {
59        self.connected.load(Ordering::Relaxed)
60    }
61
62    fn spawn_client_loop(&self, mut event_loop: EventLoop, topic: String) {
63        let connected_clone = self.connected.clone();
64        let stopped_clone = self.stopped.clone();
65        tokio::spawn(async move {
66            while !stopped_clone.load(Ordering::Relaxed) {
67                match event_loop.poll().await {
68                    Ok(Event::Incoming(Incoming::ConnAck(_))) => {
69                        // Atomic operation that sets connected to true if it is currently false
70                        connected_clone
71                            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
72                            .ok();
73                    }
74                    Ok(_)
75                    | Err(ConnectionError::Timeout(_))
76                    | Err(ConnectionError::RequestsDone) => {}
77                    Err(err) => {
78                        tracing::error!(
79                            "Failed to fetch splits to topic {}: {}",
80                            topic,
81                            err.as_report(),
82                        );
83                        tokio::time::sleep(std::time::Duration::from_millis(500)).await
84                    }
85                }
86            }
87        });
88    }
89}
90
91impl Drop for MqttConnectionCheck {
92    fn drop(&mut self) {
93        self.stopped
94            .store(true, std::sync::atomic::Ordering::Relaxed);
95    }
96}
97
98#[async_trait]
99impl SplitEnumerator for MqttSplitEnumerator {
100    type Properties = MqttProperties;
101    type Split = MqttSplit;
102
103    async fn new(
104        properties: Self::Properties,
105        context: SourceEnumeratorContextRef,
106    ) -> ConnectorResult<MqttSplitEnumerator> {
107        let broker_url = properties.common.url.clone();
108        let mut connection_check: Option<Arc<MqttConnectionCheck>> = None;
109
110        SHARED_MQTT_CLIENT
111            .entry_by_ref(&properties.common.url)
112            .and_try_compute_with::<_, _, ConnectorError>(|entry| async {
113                if let Some(cached) = entry.and_then(|e| e.into_value().upgrade()) {
114                    // return if the client is already built
115                    tracing::debug!("reuse existing mqtt client for {}", broker_url);
116                    connection_check = Some(cached);
117                    Ok(Op::Nop)
118                } else {
119                    tracing::debug!("build new mqtt client for {}", broker_url);
120                    let (new_client, event_loop) =
121                        properties.common.build_client(context.info.source_id, 0)?;
122                    let new_connection_check = Arc::new(MqttConnectionCheck::new(
123                        new_client,
124                        event_loop,
125                        properties.topic.clone(),
126                    ));
127                    connection_check = Some(new_connection_check.clone());
128                    Ok(Op::Put(Arc::downgrade(&new_connection_check)))
129                }
130            })
131            .await?;
132
133        // connection_check always gets created if it doesn't exist
134        let connection_check = connection_check.unwrap();
135
136        Ok(Self {
137            topic: properties.topic,
138            broker: broker_url,
139            connection_check,
140        })
141    }
142
143    async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
144        if !self.connection_check.is_connected() {
145            let start = std::time::Instant::now();
146            loop {
147                if self.connection_check.is_connected() {
148                    break;
149                };
150                if start.elapsed().as_secs() > 10 {
151                    bail!("Failed to connect to mqtt broker");
152                }
153
154                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
155            }
156        }
157        tracing::debug!("found new splits {} for broker {}", self.topic, self.broker);
158        Ok(vec![MqttSplit::new(self.topic.clone())])
159    }
160}