risingwave_connector/connector_common/
mqtt_common.rs1use rumqttc::tokio_rustls::rustls;
16use rumqttc::v5::mqttbytes::QoS;
17use rumqttc::v5::mqttbytes::v5::ConnectProperties;
18use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions};
19use serde_derive::Deserialize;
20use serde_with::{DisplayFromStr, serde_as};
21use strum_macros::{Display, EnumString};
22use with_options::WithOptions;
23
24use super::common::{load_certs, load_private_key};
25use crate::deserialize_bool_from_string;
26use crate::error::ConnectorResult;
27
28#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
29#[strum(serialize_all = "snake_case")]
30#[allow(clippy::enum_variant_names)]
31pub enum QualityOfService {
32 AtLeastOnce,
33 AtMostOnce,
34 ExactlyOnce,
35}
36
37#[serde_as]
38#[derive(Deserialize, Debug, Clone, WithOptions)]
39pub struct MqttCommon {
40 pub url: String,
45
46 #[serde_as(as = "Option<DisplayFromStr>")]
49 pub qos: Option<QualityOfService>,
50
51 #[serde(rename = "username")]
53 pub user: Option<String>,
54
55 pub password: Option<String>,
57
58 pub client_prefix: Option<String>,
61
62 #[serde(default, deserialize_with = "deserialize_bool_from_string")]
69 pub clean_start: bool,
70
71 #[serde_as(as = "Option<DisplayFromStr>")]
73 pub inflight_messages: Option<usize>,
74
75 #[serde_as(as = "Option<DisplayFromStr>")]
77 pub max_packet_size: Option<u32>,
78
79 #[serde(rename = "tls.ca")]
81 pub ca: Option<String>,
82 #[serde(rename = "tls.client_cert")]
85 pub client_cert: Option<String>,
86
87 #[serde(rename = "tls.client_key")]
90 pub client_key: Option<String>,
91}
92
93impl MqttCommon {
94 pub(crate) fn build_client(
95 &self,
96 actor_id: u32,
97 id: u64,
98 ) -> ConnectorResult<(AsyncClient, EventLoop)> {
99 let client_id = format!(
100 "{}_{}_{}",
101 self.client_prefix.as_deref().unwrap_or("risingwave"),
102 actor_id,
103 id
104 );
105
106 let mut url = url::Url::parse(&self.url)?;
107
108 let ssl = matches!(url.scheme(), "mqtts" | "ssl");
109
110 url.query_pairs_mut().append_pair("client_id", &client_id);
111
112 tracing::debug!("connecting mqtt using url: {}", url.as_str());
113
114 let mut options = MqttOptions::try_from(url)?;
115 options.set_keep_alive(std::time::Duration::from_secs(10));
116
117 options.set_clean_start(self.clean_start);
118
119 let mut connect_properties = ConnectProperties::new();
120 connect_properties.max_packet_size = self.max_packet_size;
121 options.set_connect_properties(connect_properties);
122
123 if ssl {
124 let tls_config = self.get_tls_config()?;
125 options.set_transport(rumqttc::Transport::tls_with_config(
126 rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
127 ));
128 }
129
130 if let Some(user) = &self.user {
131 options.set_credentials(user, self.password.as_deref().unwrap_or_default());
132 }
133
134 Ok(rumqttc::v5::AsyncClient::new(
135 options,
136 self.inflight_messages.unwrap_or(100),
137 ))
138 }
139
140 pub(crate) fn qos(&self) -> QoS {
141 self.qos
142 .as_ref()
143 .map(|qos| match qos {
144 QualityOfService::AtMostOnce => QoS::AtMostOnce,
145 QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
146 QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
147 })
148 .unwrap_or(QoS::AtMostOnce)
149 }
150
151 fn get_tls_config(&self) -> ConnectorResult<rustls::ClientConfig> {
152 let mut root_cert_store = rustls::RootCertStore::empty();
153 if let Some(ca) = &self.ca {
154 let certificates = load_certs(ca)?;
155 for cert in certificates {
156 root_cert_store.add(cert).unwrap();
157 }
158 } else {
159 for cert in
160 rustls_native_certs::load_native_certs().expect("could not load platform certs")
161 {
162 root_cert_store.add(cert).unwrap();
163 }
164 }
165
166 let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
167
168 let tls_config = if let (Some(client_cert), Some(client_key)) =
169 (self.client_cert.as_ref(), self.client_key.as_ref())
170 {
171 let certs = load_certs(client_cert)?;
172 let key = load_private_key(client_key)?;
173
174 builder.with_client_auth_cert(certs, key)?
175 } else {
176 builder.with_no_client_auth()
177 };
178
179 Ok(tls_config)
180 }
181}