risingwave_connector/connector_common/
mqtt_common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use phf::{Set, phf_set};
16use rumqttc::tokio_rustls::rustls;
17use rumqttc::v5::mqttbytes::QoS;
18use rumqttc::v5::mqttbytes::v5::ConnectProperties;
19use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions};
20use serde_derive::Deserialize;
21use serde_with::{DisplayFromStr, serde_as};
22use strum_macros::{Display, EnumString};
23use with_options::WithOptions;
24
25use super::common::{load_certs, load_private_key};
26use crate::deserialize_bool_from_string;
27use crate::enforce_secret::EnforceSecret;
28use crate::error::ConnectorResult;
29
/// MQTT delivery guarantee selectable through the `qos` option of [`MqttCommon`].
///
/// Because of `serialize_all = "snake_case"`, the string forms handled by the
/// derived `Display` / `EnumString` implementations are `at_least_once`,
/// `at_most_once` and `exactly_once`.
#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
#[strum(serialize_all = "snake_case")]
// All variants share the `Once` suffix, which this lint would otherwise report.
#[allow(clippy::enum_variant_names)]
pub enum QualityOfService {
    /// Mapped to `QoS::AtLeastOnce` by `MqttCommon::qos`.
    AtLeastOnce,
    /// Mapped to `QoS::AtMostOnce` by `MqttCommon::qos`; also the value used
    /// when no `qos` option is given.
    AtMostOnce,
    /// Mapped to `QoS::ExactlyOnce` by `MqttCommon::qos`.
    ExactlyOnce,
}
38
/// Connection options for an MQTT broker.
///
/// Numeric options and `qos` are deserialized from their string form
/// (`DisplayFromStr`), and `clean_start` is parsed from a string via
/// `deserialize_bool_from_string`. `MqttCommon::build_client` turns these
/// options into a `rumqttc` v5 client and event loop.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MqttCommon {
    /// The url of the broker to connect to. e.g. tcp://localhost.
    /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`,
    /// to denote the protocol for establishing a connection with the broker.
    /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified
    pub url: String,

    /// The quality of service to use when publishing messages. Defaults to at_most_once.
    /// Could be at_most_once, at_least_once or exactly_once
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub qos: Option<QualityOfService>,

    /// Username for the mqtt broker
    #[serde(rename = "username")]
    pub user: Option<String>,

    /// Password for the mqtt broker
    pub password: Option<String>,

    /// Prefix for the mqtt client id.
    /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave
    pub client_prefix: Option<String>,

    /// `clean_start = true` removes all the state from queues & instructs the broker
    /// to clean all the client state when client disconnects.
    ///
    /// When set `false`, broker will hold the client state and performs pending
    /// operations on the client when reconnection with same `client_id`
    /// happens. Local queue state is also held to retransmit packets after reconnection.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub clean_start: bool,

    /// The maximum number of inflight messages. Defaults to 100
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub inflight_messages: Option<usize>,

    /// The max size of messages received by the MQTT client
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_packet_size: Option<u32>,

    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "tls.ca")]
    pub ca: Option<String>,
    /// Path to client's certificate file (PEM). Required for client authentication.
    /// Can be a file path under fs:// or a string with the certificate content.
    #[serde(rename = "tls.client_cert")]
    pub client_cert: Option<String>,

    /// Path to client's private key file (PEM). Required for client authentication.
    /// Can be a file path under fs:// or a string with the private key content.
    #[serde(rename = "tls.client_key")]
    pub client_key: Option<String>,
}
94
impl EnforceSecret for MqttCommon {
    // Option keys of `MqttCommon` listed for secret enforcement. The names are
    // the external option keys (i.e. the serde-renamed `tls.*` names), not the
    // Rust field names.
    // NOTE(review): presumably these must be supplied as secrets rather than
    // plaintext — confirm against the `EnforceSecret` trait definition.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "tls.client_cert",
        "tls.client_key",
        "password",
    };
}
102
103impl MqttCommon {
104    pub(crate) fn build_client(
105        &self,
106        actor_id: u32,
107        id: u64,
108    ) -> ConnectorResult<(AsyncClient, EventLoop)> {
109        let client_id = format!(
110            "{}_{}_{}",
111            self.client_prefix.as_deref().unwrap_or("risingwave"),
112            actor_id,
113            id
114        );
115
116        let mut url = url::Url::parse(&self.url)?;
117
118        let ssl = matches!(url.scheme(), "mqtts" | "ssl");
119
120        url.query_pairs_mut().append_pair("client_id", &client_id);
121
122        tracing::debug!("connecting mqtt using url: {}", url.as_str());
123
124        let mut options = MqttOptions::try_from(url)?;
125        options.set_keep_alive(std::time::Duration::from_secs(10));
126
127        options.set_clean_start(self.clean_start);
128
129        let mut connect_properties = ConnectProperties::new();
130        connect_properties.max_packet_size = self.max_packet_size;
131        options.set_connect_properties(connect_properties);
132
133        if ssl {
134            let tls_config = self.get_tls_config()?;
135            options.set_transport(rumqttc::Transport::tls_with_config(
136                rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
137            ));
138        }
139
140        if let Some(user) = &self.user {
141            options.set_credentials(user, self.password.as_deref().unwrap_or_default());
142        }
143
144        Ok(rumqttc::v5::AsyncClient::new(
145            options,
146            self.inflight_messages.unwrap_or(100),
147        ))
148    }
149
150    pub(crate) fn qos(&self) -> QoS {
151        self.qos
152            .as_ref()
153            .map(|qos| match qos {
154                QualityOfService::AtMostOnce => QoS::AtMostOnce,
155                QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
156                QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
157            })
158            .unwrap_or(QoS::AtMostOnce)
159    }
160
161    fn get_tls_config(&self) -> ConnectorResult<rustls::ClientConfig> {
162        let mut root_cert_store = rustls::RootCertStore::empty();
163        if let Some(ca) = &self.ca {
164            let certificates = load_certs(ca)?;
165            for cert in certificates {
166                root_cert_store.add(cert).unwrap();
167            }
168        } else {
169            for cert in
170                rustls_native_certs::load_native_certs().expect("could not load platform certs")
171            {
172                root_cert_store.add(cert).unwrap();
173            }
174        }
175
176        let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
177
178        let tls_config = if let (Some(client_cert), Some(client_key)) =
179            (self.client_cert.as_ref(), self.client_key.as_ref())
180        {
181            let certs = load_certs(client_cert)?;
182            let key = load_private_key(client_key)?;
183
184            builder.with_client_auth_cert(certs, key)?
185        } else {
186            builder.with_no_client_auth()
187        };
188
189        Ok(tls_config)
190    }
191}