1use core::fmt::Debug;
16use std::collections::BTreeMap;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19
20use anyhow::{Context as _, anyhow};
21use risingwave_common::array::{Op, RowRef, StreamChunk};
22use risingwave_common::catalog::Schema;
23use risingwave_common::id::ActorId;
24use risingwave_common::row::Row;
25use risingwave_common::types::{DataType, ScalarRefImpl};
26use rumqttc::v5::ConnectionError;
27use rumqttc::v5::mqttbytes::QoS;
28use serde::Deserialize;
29use serde_with::serde_as;
30use thiserror_ext::AsReport;
31use with_options::WithOptions;
32
33use super::SinkWriterParam;
34use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc, SinkId};
35use super::encoder::{
36 DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
37 TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
38};
39use super::writer::AsyncTruncateSinkWriterExt;
40use crate::connector_common::MqttCommon;
41use crate::deserialize_bool_from_string;
42use crate::enforce_secret::EnforceSecret;
43use crate::sink::log_store::DeliveryFutureManagerAddFuture;
44use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
45use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
46
/// Connector name of the MQTT sink; also used as `Sink::SINK_NAME` for `MqttSink`.
pub const MQTT_SINK: &str = "mqtt";
48
// User-facing options of the MQTT sink, deserialized from the sink's property map
// (see `MqttConfig::from_btreemap`).
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Connection-level options shared through `MqttCommon`; flattened so its keys live at the
    // same level as the options below.
    #[serde(flatten)]
    pub common: MqttCommon,

    // Static topic to publish to. When `topic.field` is also configured, this topic is the
    // fallback for rows whose topic field does not yield a string.
    pub topic: Option<String>,

    // Retain flag handed to the client on every publish. Parsed from a string option and
    // defaults to `false` when absent.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type; `from_btreemap` rejects anything other than `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    // Dot-separated path (`column` or `column.nested.field`) to a varchar value of the row that
    // is used as the publish topic. Set through the `topic.field` option.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
69
/// Secret enforcement for the sink options is delegated to the shared [`MqttCommon`] options.
impl EnforceSecret for MqttConfig {
    /// Checks a single property name by forwarding it to `MqttCommon::enforce_one`.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MqttCommon::enforce_one(prop)
    }
}
75
/// Row encoder used by the MQTT sink: one variant per supported `ENCODE` option.
///
/// Both variants are serialized to raw bytes through the [`RowEncoder`] implementation below.
pub enum RowEncoderWrapper {
    /// Rows are encoded as JSON.
    Json(JsonEncoder),
    /// Rows are encoded as protobuf messages.
    Proto(ProtoEncoder),
}
80
81impl RowEncoder for RowEncoderWrapper {
82 type Output = Vec<u8>;
83
84 fn encode_cols(
85 &self,
86 row: impl Row,
87 col_indices: impl Iterator<Item = usize>,
88 ) -> Result<Self::Output> {
89 match self {
90 RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
91 RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
92 }
93 }
94
95 fn schema(&self) -> &Schema {
96 match self {
97 RowEncoderWrapper::Json(json) => json.schema(),
98 RowEncoderWrapper::Proto(proto) => proto.schema(),
99 }
100 }
101
102 fn col_indices(&self) -> Option<&[usize]> {
103 match self {
104 RowEncoderWrapper::Json(json) => json.col_indices(),
105 RowEncoderWrapper::Proto(proto) => proto.col_indices(),
106 }
107 }
108
109 fn encode(&self, row: impl Row) -> Result<Self::Output> {
110 match self {
111 RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
112 RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
113 }
114 }
115}
116
/// MQTT sink built from a [`SinkParam`]; validates its options and creates [`MqttSinkWriter`]s.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink options.
    pub config: MqttConfig,
    /// Schema of the rows delivered to the sink; the topic field path is resolved against it.
    schema: Schema,
    /// `FORMAT ... ENCODE ...` description that selects the row encoder.
    format_desc: SinkFormatDesc,
    /// Whether the sink type is append-only; `validate` rejects the sink otherwise.
    is_append_only: bool,
    /// Sink name; handed to the writer, which passes it to the protobuf descriptor lookup when
    /// no static topic is configured.
    name: String,
}
125
126impl EnforceSecret for MqttSink {
127 fn enforce_secret<'a>(
128 prop_iter: impl Iterator<Item = &'a str>,
129 ) -> crate::error::ConnectorResult<()> {
130 for prop in prop_iter {
131 MqttConfig::enforce_one(prop)?;
132 }
133 Ok(())
134 }
135}
136
/// Writer that encodes rows and publishes them through an MQTT client.
///
/// Creating a writer spawns a background task that polls the client's event loop; dropping the
/// writer raises `stopped` so that task leaves its loop.
pub struct MqttSinkWriter {
    /// Sink options the writer was created with.
    pub config: MqttConfig,
    /// Holds the client and the publish settings (topic, QoS, retain flag).
    payload_writer: MqttSinkPayloadWriter,
    /// Schema of the incoming rows; stored but not read anywhere in this file.
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder turning a row into the message payload.
    encoder: RowEncoderWrapper,
    /// Stop flag shared with the event-loop polling task.
    stopped: Arc<AtomicBool>,
}
146
147impl MqttConfig {
149 pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
150 let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
151 .map_err(|e| SinkError::Config(anyhow!(e)))?;
152 if config.r#type != SINK_TYPE_APPEND_ONLY {
153 Err(SinkError::Config(anyhow!(
154 "MQTT sink only supports append-only mode"
155 )))
156 } else {
157 Ok(config)
158 }
159 }
160}
161
162impl TryFrom<SinkParam> for MqttSink {
163 type Error = SinkError;
164
165 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
166 let schema = param.schema();
167 let config = MqttConfig::from_btreemap(param.properties)?;
168 Ok(Self {
169 config,
170 schema,
171 name: param.sink_name,
172 format_desc: param
173 .format_desc
174 .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
175 is_append_only: param.sink_type.is_append_only(),
176 })
177 }
178}
179
180impl Sink for MqttSink {
181 type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
182
183 const SINK_NAME: &'static str = MQTT_SINK;
184
185 async fn validate(&self) -> Result<()> {
186 if !self.is_append_only {
187 return Err(SinkError::Mqtt(anyhow!(
188 "MQTT sink only supports append-only mode"
189 )));
190 }
191
192 if let Some(field) = &self.config.topic_field {
193 let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
194 } else if self.config.topic.is_none() {
195 return Err(SinkError::Config(anyhow!(
196 "either topic or topic.field must be set"
197 )));
198 }
199
200 let _client = (self.config.common.build_client(0.into(), 0))
201 .context("validate mqtt sink error")
202 .map_err(SinkError::Mqtt)?;
203
204 Ok(())
205 }
206
207 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
208 Ok(MqttSinkWriter::new(
209 self.config.clone(),
210 self.schema.clone(),
211 &self.format_desc,
212 &self.name,
213 writer_param.sink_id,
214 writer_param.actor_id,
215 )
216 .await?
217 .into_log_sinker(usize::MAX))
218 }
219}
220
221impl MqttSinkWriter {
222 pub async fn new(
223 config: MqttConfig,
224 schema: Schema,
225 format_desc: &SinkFormatDesc,
226 name: &str,
227 sink_id: SinkId,
228 actor_id: ActorId,
229 ) -> Result<Self> {
230 let mut topic_index_path = vec![];
231 if let Some(field) = &config.topic_field {
232 topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
233 }
234
235 let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
236 let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
237 let encoder = match format_desc.format {
238 SinkFormat::AppendOnly => match format_desc.encode {
239 SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
240 schema.clone(),
241 None,
242 DateHandlingMode::FromCe,
243 TimestampHandlingMode::Milli,
244 timestamptz_mode,
245 TimeHandlingMode::Milli,
246 jsonb_handling_mode,
247 )),
248 SinkEncode::Protobuf => {
249 let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
250 &format_desc.options,
251 config.topic.as_deref().unwrap_or(name),
252 None,
253 )
254 .await
255 .map_err(|e| SinkError::Config(anyhow!(e)))?;
256 let header = match sid {
257 None => ProtoHeader::None,
258 Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
259 };
260 RowEncoderWrapper::Proto(ProtoEncoder::new(
261 schema.clone(),
262 None,
263 descriptor,
264 header,
265 )?)
266 }
267 _ => {
268 return Err(SinkError::Config(anyhow!(
269 "mqtt sink encode unsupported: {:?}",
270 format_desc.encode,
271 )));
272 }
273 },
274 _ => {
275 return Err(SinkError::Config(anyhow!(
276 "MQTT sink only supports append-only mode"
277 )));
278 }
279 };
280 let qos = config.common.qos();
281
282 let (client, mut eventloop) = config
283 .common
284 .build_client(actor_id, sink_id.as_raw_id())
285 .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
286
287 let stopped = Arc::new(AtomicBool::new(false));
288 let stopped_clone = stopped.clone();
289 tokio::spawn(async move {
290 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
291 match eventloop.poll().await {
292 Ok(_) => (),
293 Err(err) => match err {
294 ConnectionError::Timeout(_) => (),
295 ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
296 | ConnectionError::Io(err)
297 if err.kind() == std::io::ErrorKind::ConnectionAborted
298 || err.kind() == std::io::ErrorKind::ConnectionReset =>
299 {
300 continue;
301 }
302 err => {
303 tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
304 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
305 }
306 },
307 }
308 }
309 });
310
311 let payload_writer = MqttSinkPayloadWriter {
312 topic: config.topic.clone(),
313 client,
314 qos,
315 retain: config.retain,
316 topic_index_path,
317 };
318
319 Ok::<_, SinkError>(Self {
320 config: config.clone(),
321 payload_writer,
322 schema: schema.clone(),
323 stopped,
324 encoder,
325 })
326 }
327}
328
329impl AsyncTruncateSinkWriter for MqttSinkWriter {
330 async fn write_chunk<'a>(
331 &'a mut self,
332 chunk: StreamChunk,
333 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
334 ) -> Result<()> {
335 self.payload_writer.write_chunk(chunk, &self.encoder).await
336 }
337}
338
339impl Drop for MqttSinkWriter {
340 fn drop(&mut self) {
341 self.stopped
342 .store(true, std::sync::atomic::Ordering::Relaxed);
343 }
344}
345
/// Publishing half of the writer: the MQTT client plus the per-message publish settings.
struct MqttSinkPayloadWriter {
    /// Client used to publish messages.
    client: rumqttc::v5::AsyncClient,
    /// Static topic; used when no topic field path is set, or as the fallback when the row's
    /// topic field does not yield a string.
    topic: Option<String>,
    /// Quality-of-service level passed to every publish.
    qos: QoS,
    /// Retain flag passed to every publish.
    retain: bool,
    /// Index path to the row's topic field (column index first, then nested struct field
    /// indices); empty when no topic field is configured.
    topic_index_path: Vec<usize>,
}
354
355impl MqttSinkPayloadWriter {
356 async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
357 for (op, row) in chunk.rows() {
358 if op != Op::Insert {
359 continue;
360 }
361
362 let topic = match get_topic_from_index_path(
363 &self.topic_index_path,
364 self.topic.as_deref(),
365 &row,
366 ) {
367 Some(s) => s,
368 None => {
369 tracing::error!("topic field not found in row, skipping: {:?}", row);
370 return Ok(());
371 }
372 };
373
374 let v = encoder.encode(row)?;
375
376 self.client
377 .publish(topic, self.qos, self.retain, v)
378 .await
379 .context("mqtt sink error")
380 .map_err(SinkError::Mqtt)?;
381 }
382
383 Ok(())
384 }
385}
386
387fn get_topic_from_index_path<'s>(
388 path: &[usize],
389 default_topic: Option<&'s str>,
390 row: &'s RowRef<'s>,
391) -> Option<&'s str> {
392 if let Some(topic) = default_topic
393 && path.is_empty()
394 {
395 Some(topic)
396 } else {
397 let mut iter = path.iter();
398 let scalar = iter
399 .next()
400 .and_then(|pos| row.datum_at(*pos))
401 .and_then(|d| {
402 iter.try_fold(d, |d, pos| match d {
403 ScalarRefImpl::Struct(struct_ref) => {
404 struct_ref.iter_fields_ref().nth(*pos).flatten()
405 }
406 _ => None,
407 })
408 });
409 match scalar {
410 Some(ScalarRefImpl::Utf8(s)) => Some(s),
411 _ => {
412 if let Some(topic) = default_topic {
413 Some(topic)
414 } else {
415 None
416 }
417 }
418 }
419 }
420}
421
422fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
426 let mut iter = topic_field.split('.');
427 let mut path = vec![];
428 let dt =
429 iter.next()
430 .and_then(|field| {
431 schema
433 .fields()
434 .iter()
435 .enumerate()
436 .find(|(_, f)| f.name == field)
437 .map(|(pos, f)| {
438 path.push(pos);
439 &f.data_type
440 })
441 })
442 .and_then(|dt| {
443 iter.try_fold(dt, |dt, field| match dt {
445 DataType::Struct(st) => {
446 st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
447 |(pos, (_, dt))| {
448 path.push(pos);
449 dt
450 },
451 )
452 }
453 _ => None,
454 })
455 });
456
457 match dt {
458 Some(DataType::Varchar) => Ok(path),
459 Some(dt) => Err(SinkError::Config(anyhow!(
460 "topic field `{}` must be of type string but got {:?}",
461 topic_field,
462 dt
463 ))),
464 None => Err(SinkError::Config(anyhow!(
465 "topic field `{}` not found",
466 topic_field
467 ))),
468 }
469}
470
/// Unit tests for resolving the topic field path against a schema and for reading the topic
/// out of a row along such a path.
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    /// A top-level varchar column resolves to a one-element path and its value is the topic.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // A name that is not in the schema is rejected.
        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    /// A varchar field nested one level inside a struct column resolves to a two-element path.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // An unknown nested field is rejected.
        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    /// Null values fall back to the default topic, or to `None` when there is no default.
    #[test]
    fn test_null_values_extraction() {
        // Null top-level column.
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        // Null field inside a struct column.
        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    /// Paths through two levels of nesting, type checks on the final field, and extraction
    /// from a row with both a nested struct column and a plain varchar column.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        // Fields that exist but are not varchar are rejected.
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // The chunk's struct layout differs from the schema above, so the paths below are
        // written by hand for this layout.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // Column 0 -> struct field 1 -> inner struct field 0 is the string "test".
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // Column 0 -> struct field 0 is an integer, so no topic is produced.
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // Column 1 is a plain string column.
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}