risingwave_connector/connector_common/
mqtt_common.rs1use phf::{Set, phf_set};
16use rumqttc::tokio_rustls::rustls;
17use rumqttc::v5::mqttbytes::QoS;
18use rumqttc::v5::mqttbytes::v5::ConnectProperties;
19use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions};
20use serde_derive::Deserialize;
21use serde_with::{DisplayFromStr, serde_as};
22use strum_macros::{Display, EnumString};
23use with_options::WithOptions;
24
25use super::common::{load_certs, load_private_key};
26use crate::deserialize_bool_from_string;
27use crate::enforce_secret::EnforceSecret;
28use crate::error::ConnectorResult;
29
30#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
31#[strum(serialize_all = "snake_case")]
32#[allow(clippy::enum_variant_names)]
33pub enum QualityOfService {
34 AtLeastOnce,
35 AtMostOnce,
36 ExactlyOnce,
37}
38
39#[serde_as]
40#[derive(Deserialize, Debug, Clone, WithOptions)]
41pub struct MqttCommon {
42 pub url: String,
47
48 #[serde_as(as = "Option<DisplayFromStr>")]
51 pub qos: Option<QualityOfService>,
52
53 #[serde(rename = "username")]
55 pub user: Option<String>,
56
57 pub password: Option<String>,
59
60 pub client_prefix: Option<String>,
63
64 #[serde(default, deserialize_with = "deserialize_bool_from_string")]
71 pub clean_start: bool,
72
73 #[serde_as(as = "Option<DisplayFromStr>")]
75 pub inflight_messages: Option<usize>,
76
77 #[serde_as(as = "Option<DisplayFromStr>")]
79 pub max_packet_size: Option<u32>,
80
81 #[serde(rename = "tls.ca")]
83 pub ca: Option<String>,
84 #[serde(rename = "tls.client_cert")]
87 pub client_cert: Option<String>,
88
89 #[serde(rename = "tls.client_key")]
92 pub client_key: Option<String>,
93}
94
95impl EnforceSecret for MqttCommon {
96 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
97 "tls.client_cert",
98 "tls.client_key",
99 "password",
100 };
101}
102
103impl MqttCommon {
104 pub(crate) fn build_client(
105 &self,
106 actor_id: u32,
107 id: u64,
108 ) -> ConnectorResult<(AsyncClient, EventLoop)> {
109 let client_id = format!(
110 "{}_{}_{}",
111 self.client_prefix.as_deref().unwrap_or("risingwave"),
112 actor_id,
113 id
114 );
115
116 let mut url = url::Url::parse(&self.url)?;
117
118 let ssl = matches!(url.scheme(), "mqtts" | "ssl");
119
120 url.query_pairs_mut().append_pair("client_id", &client_id);
121
122 tracing::debug!("connecting mqtt using url: {}", url.as_str());
123
124 let mut options = MqttOptions::try_from(url)?;
125 options.set_keep_alive(std::time::Duration::from_secs(10));
126
127 options.set_clean_start(self.clean_start);
128
129 let mut connect_properties = ConnectProperties::new();
130 connect_properties.max_packet_size = self.max_packet_size;
131 options.set_connect_properties(connect_properties);
132
133 if ssl {
134 let tls_config = self.get_tls_config()?;
135 options.set_transport(rumqttc::Transport::tls_with_config(
136 rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
137 ));
138 }
139
140 if let Some(user) = &self.user {
141 options.set_credentials(user, self.password.as_deref().unwrap_or_default());
142 }
143
144 Ok(rumqttc::v5::AsyncClient::new(
145 options,
146 self.inflight_messages.unwrap_or(100),
147 ))
148 }
149
150 pub(crate) fn qos(&self) -> QoS {
151 self.qos
152 .as_ref()
153 .map(|qos| match qos {
154 QualityOfService::AtMostOnce => QoS::AtMostOnce,
155 QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
156 QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
157 })
158 .unwrap_or(QoS::AtMostOnce)
159 }
160
161 fn get_tls_config(&self) -> ConnectorResult<rustls::ClientConfig> {
162 let mut root_cert_store = rustls::RootCertStore::empty();
163 if let Some(ca) = &self.ca {
164 let certificates = load_certs(ca)?;
165 for cert in certificates {
166 root_cert_store.add(cert).unwrap();
167 }
168 } else {
169 for cert in
170 rustls_native_certs::load_native_certs().expect("could not load platform certs")
171 {
172 root_cert_store.add(cert).unwrap();
173 }
174 }
175
176 let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
177
178 let tls_config = if let (Some(client_cert), Some(client_key)) =
179 (self.client_cert.as_ref(), self.client_key.as_ref())
180 {
181 let certs = load_certs(client_cert)?;
182 let key = load_private_key(client_key)?;
183
184 builder.with_client_auth_cert(certs, key)?
185 } else {
186 builder.with_no_client_auth()
187 };
188
189 Ok(tls_config)
190 }
191}