1use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::id::ActorId;
23use risingwave_common::row::Row;
24use risingwave_common::types::{DataType, ScalarRefImpl};
25use rumqttc::v5::ConnectionError;
26use rumqttc::v5::mqttbytes::QoS;
27use serde::Deserialize;
28use serde_with::serde_as;
29use thiserror_ext::AsReport;
30use with_options::WithOptions;
31
32use super::SinkWriterParam;
33use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc, SinkId};
34use super::encoder::{
35 DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
36 TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
37};
38use super::writer::AsyncTruncateSinkWriterExt;
39use crate::connector_common::MqttCommon;
40use crate::deserialize_bool_from_string;
41use crate::enforce_secret::EnforceSecret;
42use crate::sink::log_store::DeliveryFutureManagerAddFuture;
43use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
44use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
45
46pub const MQTT_SINK: &str = "mqtt";
47
/// User-facing options of the MQTT sink, deserialized from the sink's string properties.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    /// Shared MQTT options, flattened into the same property namespace. The sink uses them
    /// to build the client and to obtain the QoS level.
    #[serde(flatten)]
    pub common: MqttCommon,

    /// Static topic to publish to. When `topic.field` is also set, this value is the fallback
    /// for rows whose topic field is missing or not a string.
    pub topic: Option<String>,

    /// Retain flag passed to every publish call. Parsed from a string; defaults to `false`
    /// when the property is absent.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    /// Sink type. `from_btreemap` rejects anything other than `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    /// Dot-separated path (property key `topic.field`) to a varchar column, optionally nested
    /// inside struct columns, whose per-row value is used as the topic.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
68
// Secret enforcement for the sink options is delegated entirely to `MqttCommon`.
impl EnforceSecret for MqttConfig {
    /// Checks a single property name by forwarding it to `MqttCommon::enforce_one`.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MqttCommon::enforce_one(prop)
    }
}
74
/// The row encoders the MQTT sink can use, unified behind one type whose `RowEncoder`
/// output is always raw bytes.
pub enum RowEncoderWrapper {
    /// Encode rows with the JSON encoder.
    Json(JsonEncoder),
    /// Encode rows with the Protobuf encoder.
    Proto(ProtoEncoder),
}
79
80impl RowEncoder for RowEncoderWrapper {
81 type Output = Vec<u8>;
82
83 fn encode_cols(
84 &self,
85 row: impl Row,
86 col_indices: impl Iterator<Item = usize>,
87 ) -> Result<Self::Output> {
88 match self {
89 RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
90 RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
91 }
92 }
93
94 fn schema(&self) -> &Schema {
95 match self {
96 RowEncoderWrapper::Json(json) => json.schema(),
97 RowEncoderWrapper::Proto(proto) => proto.schema(),
98 }
99 }
100
101 fn col_indices(&self) -> Option<&[usize]> {
102 match self {
103 RowEncoderWrapper::Json(json) => json.col_indices(),
104 RowEncoderWrapper::Proto(proto) => proto.col_indices(),
105 }
106 }
107
108 fn encode(&self, row: impl Row) -> Result<Self::Output> {
109 match self {
110 RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
111 RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
112 }
113 }
114}
115
/// The MQTT sink as described by a `SinkParam`; creates `MqttSinkWriter`s on demand.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink options.
    pub config: MqttConfig,
    /// Schema of the rows to sink.
    schema: Schema,
    /// The `FORMAT ... ENCODE ...` description selecting the row encoder.
    format_desc: SinkFormatDesc,
    /// Whether the sink type is append-only; `validate` rejects the sink otherwise.
    is_append_only: bool,
    /// Sink name, handed to the writer on construction.
    name: String,
}
124
125impl EnforceSecret for MqttSink {
126 fn enforce_secret<'a>(
127 prop_iter: impl Iterator<Item = &'a str>,
128 ) -> crate::error::ConnectorResult<()> {
129 for prop in prop_iter {
130 MqttConfig::enforce_one(prop)?;
131 }
132 Ok(())
133 }
134}
135
/// Writer that encodes stream chunks and publishes them through an MQTT client.
pub struct MqttSinkWriter {
    /// Sink options this writer was created with.
    pub config: MqttConfig,
    /// Owns the MQTT client and performs the actual publishing.
    payload_writer: MqttSinkPayloadWriter,
    /// Schema of the sinked rows; stored but not read after construction.
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder turning each row into the message payload bytes.
    encoder: RowEncoderWrapper,
    /// Set to `true` on drop to ask the background event-loop polling task to exit.
    stopped: Arc<AtomicBool>,
}
145
146impl MqttConfig {
148 pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
149 let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
150 .map_err(|e| SinkError::Config(anyhow!(e)))?;
151 if config.r#type != SINK_TYPE_APPEND_ONLY {
152 Err(SinkError::Config(anyhow!(
153 "MQTT sink only supports append-only mode"
154 )))
155 } else {
156 Ok(config)
157 }
158 }
159}
160
161impl TryFrom<SinkParam> for MqttSink {
162 type Error = SinkError;
163
164 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
165 let schema = param.schema();
166 let config = MqttConfig::from_btreemap(param.properties)?;
167 Ok(Self {
168 config,
169 schema,
170 name: param.sink_name,
171 format_desc: param
172 .format_desc
173 .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
174 is_append_only: param.sink_type.is_append_only(),
175 })
176 }
177}
178
179impl Sink for MqttSink {
180 type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
181
182 const SINK_NAME: &'static str = MQTT_SINK;
183
184 async fn validate(&self) -> Result<()> {
185 if !self.is_append_only {
186 return Err(SinkError::Mqtt(anyhow!(
187 "MQTT sink only supports append-only mode"
188 )));
189 }
190
191 if let Some(field) = &self.config.topic_field {
192 let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
193 } else if self.config.topic.is_none() {
194 return Err(SinkError::Config(anyhow!(
195 "either topic or topic.field must be set"
196 )));
197 }
198
199 let _client = (self.config.common.build_client(0.into(), 0))
200 .context("validate mqtt sink error")
201 .map_err(SinkError::Mqtt)?;
202
203 Ok(())
204 }
205
206 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
207 Ok(MqttSinkWriter::new(
208 self.config.clone(),
209 self.schema.clone(),
210 &self.format_desc,
211 &self.name,
212 writer_param.sink_id,
213 writer_param.actor_id,
214 )
215 .await?
216 .into_log_sinker(usize::MAX))
217 }
218}
219
220impl MqttSinkWriter {
221 pub async fn new(
222 config: MqttConfig,
223 schema: Schema,
224 format_desc: &SinkFormatDesc,
225 name: &str,
226 sink_id: SinkId,
227 actor_id: ActorId,
228 ) -> Result<Self> {
229 let mut topic_index_path = vec![];
230 if let Some(field) = &config.topic_field {
231 topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
232 }
233
234 let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
235 let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
236 let encoder = match format_desc.format {
237 SinkFormat::AppendOnly => match format_desc.encode {
238 SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
239 schema.clone(),
240 None,
241 DateHandlingMode::FromCe,
242 TimestampHandlingMode::Milli,
243 timestamptz_mode,
244 TimeHandlingMode::Milli,
245 jsonb_handling_mode,
246 )),
247 SinkEncode::Protobuf => {
248 let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
249 &format_desc.options,
250 config.topic.as_deref().unwrap_or(name),
251 None,
252 )
253 .await
254 .map_err(|e| SinkError::Config(anyhow!(e)))?;
255 let header = match sid {
256 None => ProtoHeader::None,
257 Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
258 };
259 RowEncoderWrapper::Proto(ProtoEncoder::new(
260 schema.clone(),
261 None,
262 descriptor,
263 header,
264 )?)
265 }
266 _ => {
267 return Err(SinkError::Config(anyhow!(
268 "mqtt sink encode unsupported: {:?}",
269 format_desc.encode,
270 )));
271 }
272 },
273 _ => {
274 return Err(SinkError::Config(anyhow!(
275 "MQTT sink only supports append-only mode"
276 )));
277 }
278 };
279 let qos = config.common.qos();
280
281 let (client, mut eventloop) = config
282 .common
283 .build_client(actor_id, sink_id.as_raw_id())
284 .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
285
286 let stopped = Arc::new(AtomicBool::new(false));
287 let stopped_clone = stopped.clone();
288 tokio::spawn(async move {
289 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
290 match eventloop.poll().await {
291 Ok(_) => (),
292 Err(err) => match err {
293 ConnectionError::Timeout(_) => (),
294 ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
295 | ConnectionError::Io(err)
296 if err.kind() == std::io::ErrorKind::ConnectionAborted
297 || err.kind() == std::io::ErrorKind::ConnectionReset =>
298 {
299 continue;
300 }
301 err => {
302 tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
303 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
304 }
305 },
306 }
307 }
308 });
309
310 let payload_writer = MqttSinkPayloadWriter {
311 topic: config.topic.clone(),
312 client,
313 qos,
314 retain: config.retain,
315 topic_index_path,
316 };
317
318 Ok::<_, SinkError>(Self {
319 config: config.clone(),
320 payload_writer,
321 schema: schema.clone(),
322 stopped,
323 encoder,
324 })
325 }
326}
327
328impl AsyncTruncateSinkWriter for MqttSinkWriter {
329 async fn write_chunk<'a>(
330 &'a mut self,
331 chunk: StreamChunk,
332 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
333 ) -> Result<()> {
334 self.payload_writer.write_chunk(chunk, &self.encoder).await
335 }
336}
337
338impl Drop for MqttSinkWriter {
339 fn drop(&mut self) {
340 self.stopped
341 .store(true, std::sync::atomic::Ordering::Relaxed);
342 }
343}
344
/// Holds everything needed to publish one encoded row to the broker.
struct MqttSinkPayloadWriter {
    /// Client used to publish messages.
    client: rumqttc::v5::AsyncClient,
    /// Static topic; used when there is no topic field path, or as the fallback when a row's
    /// topic field does not yield a string.
    topic: Option<String>,
    /// QoS level used for every publish.
    qos: QoS,
    /// Retain flag used for every publish.
    retain: bool,
    /// Index path of the topic field within a row; empty when no topic field is configured.
    topic_index_path: Vec<usize>,
}
353
354impl MqttSinkPayloadWriter {
355 async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
356 for (op, row) in chunk.rows() {
357 if op != Op::Insert {
358 continue;
359 }
360
361 let topic = match get_topic_from_index_path(
362 &self.topic_index_path,
363 self.topic.as_deref(),
364 &row,
365 ) {
366 Some(s) => s,
367 None => {
368 tracing::error!("topic field not found in row, skipping: {:?}", row);
369 return Ok(());
370 }
371 };
372
373 let v = encoder.encode(row)?;
374
375 self.client
376 .publish(topic, self.qos, self.retain, v)
377 .await
378 .context("mqtt sink error")
379 .map_err(SinkError::Mqtt)?;
380 }
381
382 Ok(())
383 }
384}
385
386fn get_topic_from_index_path<'s>(
387 path: &[usize],
388 default_topic: Option<&'s str>,
389 row: &'s RowRef<'s>,
390) -> Option<&'s str> {
391 if let Some(topic) = default_topic
392 && path.is_empty()
393 {
394 Some(topic)
395 } else {
396 let mut iter = path.iter();
397 let scalar = iter
398 .next()
399 .and_then(|pos| row.datum_at(*pos))
400 .and_then(|d| {
401 iter.try_fold(d, |d, pos| match d {
402 ScalarRefImpl::Struct(struct_ref) => {
403 struct_ref.iter_fields_ref().nth(*pos).flatten()
404 }
405 _ => None,
406 })
407 });
408 match scalar {
409 Some(ScalarRefImpl::Utf8(s)) => Some(s),
410 _ => {
411 if let Some(topic) = default_topic {
412 Some(topic)
413 } else {
414 None
415 }
416 }
417 }
418 }
419}
420
421fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
425 let mut iter = topic_field.split('.');
426 let mut path = vec![];
427 let dt =
428 iter.next()
429 .and_then(|field| {
430 schema
432 .fields()
433 .iter()
434 .enumerate()
435 .find(|(_, f)| f.name == field)
436 .map(|(pos, f)| {
437 path.push(pos);
438 &f.data_type
439 })
440 })
441 .and_then(|dt| {
442 iter.try_fold(dt, |dt, field| match dt {
444 DataType::Struct(st) => {
445 st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
446 |(pos, (_, dt))| {
447 path.push(pos);
448 dt
449 },
450 )
451 }
452 _ => None,
453 })
454 });
455
456 match dt {
457 Some(DataType::Varchar) => Ok(path),
458 Some(dt) => Err(SinkError::Config(anyhow!(
459 "topic field `{}` must be of type string but got {:?}",
460 topic_field,
461 dt
462 ))),
463 None => Err(SinkError::Config(anyhow!(
464 "topic field `{}` not found",
465 topic_field
466 ))),
467 }
468}
469
// Unit tests for resolving a `topic.field` path against a schema and for extracting the
// topic value from a row.
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a single-element path and its value is
    // returned as the topic; an unknown column name is rejected.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field nested one level inside a struct column resolves to a two-element
    // path; an unknown nested field name is rejected.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // When the topic value is missing from the row, the default topic is returned if there
    // is one, otherwise `None`.
    #[test]
    fn test_null_values_extraction() {
        // Top-level column without a value.
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        // Struct column whose second field has no value.
        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Paths through two levels of struct nesting, plus rejection of non-varchar targets.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        // Targets that exist but are not varchar are rejected.
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // The chunk below has its own, simpler layout than the schema above: the inner
        // struct holds only one string field, so the extraction paths are written by hand.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(
            // column 0 -> struct field 1 -> struct field 0
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // column 0 -> struct field 0 is an integer, not a string
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // column 1 is the top-level string column
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}