risingwave_connector/connector_common/
mqtt_common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use phf::{Set, phf_set};
16use risingwave_common::id::ActorId;
17use rumqttc::tokio_rustls::rustls;
18use rumqttc::v5::mqttbytes::QoS;
19use rumqttc::v5::mqttbytes::v5::ConnectProperties;
20use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions};
21use serde::Deserialize;
22use serde_with::{DisplayFromStr, serde_as};
23use strum_macros::{Display, EnumString};
24use with_options::WithOptions;
25
26use super::common::{load_certs, load_private_key};
27use crate::deserialize_bool_from_string;
28use crate::enforce_secret::EnforceSecret;
29use crate::error::ConnectorResult;
30
31#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
32#[strum(serialize_all = "snake_case")]
33#[allow(clippy::enum_variant_names)]
34pub enum QualityOfService {
35    AtLeastOnce,
36    AtMostOnce,
37    ExactlyOnce,
38}
39
40#[serde_as]
41#[derive(Deserialize, Debug, Clone, WithOptions)]
42pub struct MqttCommon {
43    /// The url of the broker to connect to. e.g. <tcp://localhost>.
44    /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`,
45    /// to denote the protocol for establishing a connection with the broker.
46    /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified
47    pub url: String,
48
49    /// The quality of service to use when publishing messages. Defaults to `at_most_once`.
50    /// Could be `at_most_once`, `at_least_once` or `exactly_once`
51    #[serde_as(as = "Option<DisplayFromStr>")]
52    pub qos: Option<QualityOfService>,
53
54    /// Username for the mqtt broker
55    #[serde(rename = "username")]
56    pub user: Option<String>,
57
58    /// Password for the mqtt broker
59    pub password: Option<String>,
60
61    /// Prefix for the mqtt client id.
62    /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave
63    pub client_prefix: Option<String>,
64
65    /// `clean_start = true` removes all the state from queues & instructs the broker
66    /// to clean all the client state when client disconnects.
67    ///
68    /// When set `false`, broker will hold the client state and performs pending
69    /// operations on the client when reconnection with same `client_id`
70    /// happens. Local queue state is also held to retransmit packets after reconnection.
71    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
72    pub clean_start: bool,
73
74    /// The maximum number of inflight messages. Defaults to 100
75    #[serde_as(as = "Option<DisplayFromStr>")]
76    pub inflight_messages: Option<usize>,
77
78    /// The max size of messages received by the MQTT client
79    #[serde_as(as = "Option<DisplayFromStr>")]
80    pub max_packet_size: Option<u32>,
81
82    /// Path to CA certificate file for verifying the broker's key.
83    #[serde(rename = "tls.ca")]
84    pub ca: Option<String>,
85    /// Path to client's certificate file (PEM). Required for client authentication.
86    /// Can be a file path under fs:// or a string with the certificate content.
87    #[serde(rename = "tls.client_cert")]
88    pub client_cert: Option<String>,
89
90    /// Path to client's private key file (PEM). Required for client authentication.
91    /// Can be a file path under fs:// or a string with the private key content.
92    #[serde(rename = "tls.client_key")]
93    pub client_key: Option<String>,
94}
95
96impl EnforceSecret for MqttCommon {
97    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
98        "tls.client_cert",
99        "tls.client_key",
100        "password",
101    };
102}
103
104impl MqttCommon {
105    pub(crate) fn build_client(
106        &self,
107        actor_id: ActorId,
108        id: u32,
109    ) -> ConnectorResult<(AsyncClient, EventLoop)> {
110        let client_id = format!(
111            "{}_{}_{}",
112            self.client_prefix.as_deref().unwrap_or("risingwave"),
113            actor_id,
114            id
115        );
116
117        let mut url = url::Url::parse(&self.url)?;
118
119        let ssl = matches!(url.scheme(), "mqtts" | "ssl");
120
121        url.query_pairs_mut().append_pair("client_id", &client_id);
122
123        tracing::debug!("connecting mqtt using url: {}", url.as_str());
124
125        let mut options = MqttOptions::try_from(url)?;
126        options.set_keep_alive(std::time::Duration::from_secs(10));
127
128        options.set_clean_start(self.clean_start);
129
130        let mut connect_properties = ConnectProperties::new();
131        connect_properties.max_packet_size = self.max_packet_size;
132        options.set_connect_properties(connect_properties);
133
134        if ssl {
135            let tls_config = self.get_tls_config()?;
136            options.set_transport(rumqttc::Transport::tls_with_config(
137                rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
138            ));
139        }
140
141        if let Some(user) = &self.user {
142            options.set_credentials(user, self.password.as_deref().unwrap_or_default());
143        }
144
145        Ok(rumqttc::v5::AsyncClient::new(
146            options,
147            self.inflight_messages.unwrap_or(100),
148        ))
149    }
150
151    pub(crate) fn qos(&self) -> QoS {
152        self.qos
153            .as_ref()
154            .map(|qos| match qos {
155                QualityOfService::AtMostOnce => QoS::AtMostOnce,
156                QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
157                QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
158            })
159            .unwrap_or(QoS::AtMostOnce)
160    }
161
162    fn get_tls_config(&self) -> ConnectorResult<rustls::ClientConfig> {
163        let mut root_cert_store = rustls::RootCertStore::empty();
164        if let Some(ca) = &self.ca {
165            let certificates = load_certs(ca)?;
166            for cert in certificates {
167                root_cert_store.add(cert).unwrap();
168            }
169        } else {
170            for cert in
171                rustls_native_certs::load_native_certs().expect("could not load platform certs")
172            {
173                root_cert_store.add(cert).unwrap();
174            }
175        }
176
177        let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
178
179        let tls_config = if let (Some(client_cert), Some(client_key)) =
180            (self.client_cert.as_ref(), self.client_key.as_ref())
181        {
182            let certs = load_certs(client_cert)?;
183            let key = load_private_key(client_key)?;
184
185            builder.with_client_auth_cert(certs, key)?
186        } else {
187            builder.with_no_client_auth()
188        };
189
190        Ok(tls_config)
191    }
192}