risingwave_connector/source/mqtt/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::sync::atomic::AtomicBool;
17
18use async_trait::async_trait;
19use risingwave_common::bail;
20use rumqttc::v5::{ConnectionError, Event, Incoming};
21use thiserror_ext::AsReport;
22
23use super::MqttProperties;
24use super::source::MqttSplit;
25use crate::error::ConnectorResult;
26use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
27
28pub struct MqttSplitEnumerator {
29    topic: String,
30    #[expect(dead_code)]
31    client: rumqttc::v5::AsyncClient,
32    connected: Arc<AtomicBool>,
33    stopped: Arc<AtomicBool>,
34}
35
36#[async_trait]
37impl SplitEnumerator for MqttSplitEnumerator {
38    type Properties = MqttProperties;
39    type Split = MqttSplit;
40
41    async fn new(
42        properties: Self::Properties,
43        context: SourceEnumeratorContextRef,
44    ) -> ConnectorResult<MqttSplitEnumerator> {
45        let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;
46        let topic = properties.topic.clone();
47
48        let connected = Arc::new(AtomicBool::new(false));
49        let connected_clone = connected.clone();
50
51        let stopped = Arc::new(AtomicBool::new(false));
52        let stopped_clone = stopped.clone();
53
54        tokio::spawn(async move {
55            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
56                match eventloop.poll().await {
57                    Ok(Event::Incoming(Incoming::ConnAck(_))) => {
58                        connected_clone.store(true, std::sync::atomic::Ordering::Relaxed);
59                    }
60                    Ok(_)
61                    | Err(ConnectionError::Timeout(_))
62                    | Err(ConnectionError::RequestsDone) => {}
63                    Err(err) => {
64                        tracing::error!(
65                            "Failed to fetch splits to topic {}: {}",
66                            topic,
67                            err.as_report(),
68                        );
69                        tokio::time::sleep(std::time::Duration::from_millis(500)).await
70                    }
71                }
72            }
73        });
74        Ok(Self {
75            topic: properties.topic,
76            client,
77            connected,
78            stopped,
79        })
80    }
81
82    async fn list_splits(&mut self) -> ConnectorResult<Vec<MqttSplit>> {
83        if !self.connected.load(std::sync::atomic::Ordering::Relaxed) {
84            let start = std::time::Instant::now();
85            loop {
86                if self.connected.load(std::sync::atomic::Ordering::Relaxed) {
87                    break;
88                }
89
90                if start.elapsed().as_secs() > 10 {
91                    bail!("Failed to connect to mqtt broker");
92                }
93
94                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
95            }
96        }
97        Ok(vec![MqttSplit::new(self.topic.clone())])
98    }
99}
100
101impl Drop for MqttSplitEnumerator {
102    fn drop(&mut self) {
103        self.stopped
104            .store(true, std::sync::atomic::Ordering::Relaxed);
105    }
106}