risingwave_connector/connector_common/
mqtt_common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rumqttc::tokio_rustls::rustls;
16use rumqttc::v5::mqttbytes::QoS;
17use rumqttc::v5::mqttbytes::v5::ConnectProperties;
18use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions};
19use serde_derive::Deserialize;
20use serde_with::{DisplayFromStr, serde_as};
21use strum_macros::{Display, EnumString};
22use with_options::WithOptions;
23
24use super::common::{load_certs, load_private_key};
25use crate::deserialize_bool_from_string;
26use crate::error::ConnectorResult;
27
28#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
29#[strum(serialize_all = "snake_case")]
30#[allow(clippy::enum_variant_names)]
31pub enum QualityOfService {
32    AtLeastOnce,
33    AtMostOnce,
34    ExactlyOnce,
35}
36
37#[serde_as]
38#[derive(Deserialize, Debug, Clone, WithOptions)]
39pub struct MqttCommon {
40    /// The url of the broker to connect to. e.g. tcp://localhost.
41    /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`,
42    /// to denote the protocol for establishing a connection with the broker.
43    /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified
44    pub url: String,
45
46    /// The quality of service to use when publishing messages. Defaults to at_most_once.
47    /// Could be at_most_once, at_least_once or exactly_once
48    #[serde_as(as = "Option<DisplayFromStr>")]
49    pub qos: Option<QualityOfService>,
50
51    /// Username for the mqtt broker
52    #[serde(rename = "username")]
53    pub user: Option<String>,
54
55    /// Password for the mqtt broker
56    pub password: Option<String>,
57
58    /// Prefix for the mqtt client id.
59    /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave
60    pub client_prefix: Option<String>,
61
62    /// `clean_start = true` removes all the state from queues & instructs the broker
63    /// to clean all the client state when client disconnects.
64    ///
65    /// When set `false`, broker will hold the client state and performs pending
66    /// operations on the client when reconnection with same `client_id`
67    /// happens. Local queue state is also held to retransmit packets after reconnection.
68    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
69    pub clean_start: bool,
70
71    /// The maximum number of inflight messages. Defaults to 100
72    #[serde_as(as = "Option<DisplayFromStr>")]
73    pub inflight_messages: Option<usize>,
74
75    /// The max size of messages received by the MQTT client
76    #[serde_as(as = "Option<DisplayFromStr>")]
77    pub max_packet_size: Option<u32>,
78
79    /// Path to CA certificate file for verifying the broker's key.
80    #[serde(rename = "tls.ca")]
81    pub ca: Option<String>,
82    /// Path to client's certificate file (PEM). Required for client authentication.
83    /// Can be a file path under fs:// or a string with the certificate content.
84    #[serde(rename = "tls.client_cert")]
85    pub client_cert: Option<String>,
86
87    /// Path to client's private key file (PEM). Required for client authentication.
88    /// Can be a file path under fs:// or a string with the private key content.
89    #[serde(rename = "tls.client_key")]
90    pub client_key: Option<String>,
91}
92
93impl MqttCommon {
94    pub(crate) fn build_client(
95        &self,
96        actor_id: u32,
97        id: u64,
98    ) -> ConnectorResult<(AsyncClient, EventLoop)> {
99        let client_id = format!(
100            "{}_{}_{}",
101            self.client_prefix.as_deref().unwrap_or("risingwave"),
102            actor_id,
103            id
104        );
105
106        let mut url = url::Url::parse(&self.url)?;
107
108        let ssl = matches!(url.scheme(), "mqtts" | "ssl");
109
110        url.query_pairs_mut().append_pair("client_id", &client_id);
111
112        tracing::debug!("connecting mqtt using url: {}", url.as_str());
113
114        let mut options = MqttOptions::try_from(url)?;
115        options.set_keep_alive(std::time::Duration::from_secs(10));
116
117        options.set_clean_start(self.clean_start);
118
119        let mut connect_properties = ConnectProperties::new();
120        connect_properties.max_packet_size = self.max_packet_size;
121        options.set_connect_properties(connect_properties);
122
123        if ssl {
124            let tls_config = self.get_tls_config()?;
125            options.set_transport(rumqttc::Transport::tls_with_config(
126                rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
127            ));
128        }
129
130        if let Some(user) = &self.user {
131            options.set_credentials(user, self.password.as_deref().unwrap_or_default());
132        }
133
134        Ok(rumqttc::v5::AsyncClient::new(
135            options,
136            self.inflight_messages.unwrap_or(100),
137        ))
138    }
139
140    pub(crate) fn qos(&self) -> QoS {
141        self.qos
142            .as_ref()
143            .map(|qos| match qos {
144                QualityOfService::AtMostOnce => QoS::AtMostOnce,
145                QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
146                QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
147            })
148            .unwrap_or(QoS::AtMostOnce)
149    }
150
151    fn get_tls_config(&self) -> ConnectorResult<rustls::ClientConfig> {
152        let mut root_cert_store = rustls::RootCertStore::empty();
153        if let Some(ca) = &self.ca {
154            let certificates = load_certs(ca)?;
155            for cert in certificates {
156                root_cert_store.add(cert).unwrap();
157            }
158        } else {
159            for cert in
160                rustls_native_certs::load_native_certs().expect("could not load platform certs")
161            {
162                root_cert_store.add(cert).unwrap();
163            }
164        }
165
166        let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
167
168        let tls_config = if let (Some(client_cert), Some(client_key)) =
169            (self.client_cert.as_ref(), self.client_key.as_ref())
170        {
171            let certs = load_certs(client_cert)?;
172            let key = load_private_key(client_key)?;
173
174            builder.with_client_auth_cert(certs, key)?
175        } else {
176            builder.with_no_client_auth()
177        };
178
179        Ok(tls_config)
180    }
181}