1use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
32use super::encoder::{
33 DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
34 TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
35};
36use super::writer::AsyncTruncateSinkWriterExt;
37use super::{DummySinkCommitCoordinator, SinkWriterParam};
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::sink::log_store::DeliveryFutureManagerAddFuture;
42use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
43use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
44
/// Connector name of this sink; used as `Sink::SINK_NAME` for `MqttSink`.
pub const MQTT_SINK: &str = "mqtt";
46
// Options of the MQTT sink, deserialized from the sink's string properties
// (see `MqttConfig::from_btreemap`).
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Shared MQTT connection options; used to build the client and to pick the QoS.
    #[serde(flatten)]
    pub common: MqttCommon,

    // Static topic to publish to. Also the fallback when `topic.field` is unset or a row's
    // topic value is not a non-null string.
    pub topic: Option<String>,

    // MQTT retain flag passed with every publish; parsed from a string, `false` when absent.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type; `from_btreemap` rejects anything other than `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    // Property `topic.field`: dot-separated path to a varchar column (possibly nested in
    // struct columns) whose value is used as the per-row topic.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
67
68impl EnforceSecret for MqttConfig {
69 fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
70 MqttCommon::enforce_one(prop)
71 }
72}
73
/// Row encoder chosen from the sink's encode option; both variants serialize a row into a
/// byte payload (`Vec<u8>`) via the `RowEncoder` impl.
pub enum RowEncoderWrapper {
    /// Built for `SinkEncode::Json`.
    Json(JsonEncoder),
    /// Built for `SinkEncode::Protobuf`.
    Proto(ProtoEncoder),
}
78
79impl RowEncoder for RowEncoderWrapper {
80 type Output = Vec<u8>;
81
82 fn encode_cols(
83 &self,
84 row: impl Row,
85 col_indices: impl Iterator<Item = usize>,
86 ) -> Result<Self::Output> {
87 match self {
88 RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
89 RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
90 }
91 }
92
93 fn schema(&self) -> &Schema {
94 match self {
95 RowEncoderWrapper::Json(json) => json.schema(),
96 RowEncoderWrapper::Proto(proto) => proto.schema(),
97 }
98 }
99
100 fn col_indices(&self) -> Option<&[usize]> {
101 match self {
102 RowEncoderWrapper::Json(json) => json.col_indices(),
103 RowEncoderWrapper::Proto(proto) => proto.col_indices(),
104 }
105 }
106
107 fn encode(&self, row: impl Row) -> Result<Self::Output> {
108 match self {
109 RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
110 RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
111 }
112 }
113}
114
/// MQTT sink built from a `SinkParam`; validates its configuration and creates
/// `MqttSinkWriter`s.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink properties.
    pub config: MqttConfig,
    /// Schema of the rows written to this sink.
    schema: Schema,
    /// `FORMAT ... ENCODE ...` description; required when constructing from `SinkParam`.
    format_desc: SinkFormatDesc,
    /// Whether the sink type is append-only; `validate` rejects the sink otherwise.
    is_append_only: bool,
    /// Sink name; passed to protobuf descriptor lookup when no `topic` is configured.
    name: String,
}
123
124impl EnforceSecret for MqttSink {
125 fn enforce_secret<'a>(
126 prop_iter: impl Iterator<Item = &'a str>,
127 ) -> crate::error::ConnectorResult<()> {
128 for prop in prop_iter {
129 MqttConfig::enforce_one(prop)?;
130 }
131 Ok(())
132 }
133}
134
/// Writer that encodes insert rows and publishes them over MQTT.
pub struct MqttSinkWriter {
    /// Sink config this writer was built from.
    pub config: MqttConfig,
    /// Owns the MQTT client and the per-row topic resolution state.
    payload_writer: MqttSinkPayloadWriter,
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder producing each row's payload bytes.
    encoder: RowEncoderWrapper,
    /// Set on drop to stop the background task that polls the MQTT event loop.
    stopped: Arc<AtomicBool>,
}
144
145impl MqttConfig {
147 pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
148 let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
149 .map_err(|e| SinkError::Config(anyhow!(e)))?;
150 if config.r#type != SINK_TYPE_APPEND_ONLY {
151 Err(SinkError::Config(anyhow!(
152 "MQTT sink only supports append-only mode"
153 )))
154 } else {
155 Ok(config)
156 }
157 }
158}
159
160impl TryFrom<SinkParam> for MqttSink {
161 type Error = SinkError;
162
163 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
164 let schema = param.schema();
165 let config = MqttConfig::from_btreemap(param.properties)?;
166 Ok(Self {
167 config,
168 schema,
169 name: param.sink_name,
170 format_desc: param
171 .format_desc
172 .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
173 is_append_only: param.sink_type.is_append_only(),
174 })
175 }
176}
177
178impl Sink for MqttSink {
179 type Coordinator = DummySinkCommitCoordinator;
180 type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
181
182 const SINK_NAME: &'static str = MQTT_SINK;
183
184 async fn validate(&self) -> Result<()> {
185 if !self.is_append_only {
186 return Err(SinkError::Mqtt(anyhow!(
187 "MQTT sink only supports append-only mode"
188 )));
189 }
190
191 if let Some(field) = &self.config.topic_field {
192 let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
193 } else if self.config.topic.is_none() {
194 return Err(SinkError::Config(anyhow!(
195 "either topic or topic.field must be set"
196 )));
197 }
198
199 let _client = (self.config.common.build_client(0, 0))
200 .context("validate mqtt sink error")
201 .map_err(SinkError::Mqtt)?;
202
203 Ok(())
204 }
205
206 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
207 Ok(MqttSinkWriter::new(
208 self.config.clone(),
209 self.schema.clone(),
210 &self.format_desc,
211 &self.name,
212 writer_param.executor_id,
213 )
214 .await?
215 .into_log_sinker(usize::MAX))
216 }
217}
218
219impl MqttSinkWriter {
220 pub async fn new(
221 config: MqttConfig,
222 schema: Schema,
223 format_desc: &SinkFormatDesc,
224 name: &str,
225 id: u64,
226 ) -> Result<Self> {
227 let mut topic_index_path = vec![];
228 if let Some(field) = &config.topic_field {
229 topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
230 }
231
232 let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
233 let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
234 let encoder = match format_desc.format {
235 SinkFormat::AppendOnly => match format_desc.encode {
236 SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
237 schema.clone(),
238 None,
239 DateHandlingMode::FromCe,
240 TimestampHandlingMode::Milli,
241 timestamptz_mode,
242 TimeHandlingMode::Milli,
243 jsonb_handling_mode,
244 )),
245 SinkEncode::Protobuf => {
246 let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
247 &format_desc.options,
248 config.topic.as_deref().unwrap_or(name),
249 None,
250 )
251 .await
252 .map_err(|e| SinkError::Config(anyhow!(e)))?;
253 let header = match sid {
254 None => ProtoHeader::None,
255 Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
256 };
257 RowEncoderWrapper::Proto(ProtoEncoder::new(
258 schema.clone(),
259 None,
260 descriptor,
261 header,
262 )?)
263 }
264 _ => {
265 return Err(SinkError::Config(anyhow!(
266 "mqtt sink encode unsupported: {:?}",
267 format_desc.encode,
268 )));
269 }
270 },
271 _ => {
272 return Err(SinkError::Config(anyhow!(
273 "MQTT sink only supports append-only mode"
274 )));
275 }
276 };
277 let qos = config.common.qos();
278
279 let (client, mut eventloop) = config
280 .common
281 .build_client(0, id)
282 .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
283
284 let stopped = Arc::new(AtomicBool::new(false));
285 let stopped_clone = stopped.clone();
286 tokio::spawn(async move {
287 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
288 match eventloop.poll().await {
289 Ok(_) => (),
290 Err(err) => match err {
291 ConnectionError::Timeout(_) => (),
292 ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
293 | ConnectionError::Io(err)
294 if err.kind() == std::io::ErrorKind::ConnectionAborted
295 || err.kind() == std::io::ErrorKind::ConnectionReset =>
296 {
297 continue;
298 }
299 err => {
300 tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
301 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
302 }
303 },
304 }
305 }
306 });
307
308 let payload_writer = MqttSinkPayloadWriter {
309 topic: config.topic.clone(),
310 client,
311 qos,
312 retain: config.retain,
313 topic_index_path,
314 };
315
316 Ok::<_, SinkError>(Self {
317 config: config.clone(),
318 payload_writer,
319 schema: schema.clone(),
320 stopped,
321 encoder,
322 })
323 }
324}
325
326impl AsyncTruncateSinkWriter for MqttSinkWriter {
327 async fn write_chunk<'a>(
328 &'a mut self,
329 chunk: StreamChunk,
330 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
331 ) -> Result<()> {
332 self.payload_writer.write_chunk(chunk, &self.encoder).await
333 }
334}
335
336impl Drop for MqttSinkWriter {
337 fn drop(&mut self) {
338 self.stopped
339 .store(true, std::sync::atomic::Ordering::Relaxed);
340 }
341}
342
/// Publishes encoded rows with a fixed QoS and retain flag, choosing the topic per row.
struct MqttSinkPayloadWriter {
    /// Client used to publish; its event loop is polled by a separate task.
    client: rumqttc::v5::AsyncClient,
    /// Static/fallback topic from the `topic` option.
    topic: Option<String>,
    qos: QoS,
    retain: bool,
    /// Index path from `get_topic_field_index_path`; empty when `topic.field` is unset.
    topic_index_path: Vec<usize>,
}
351
352impl MqttSinkPayloadWriter {
353 async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
354 for (op, row) in chunk.rows() {
355 if op != Op::Insert {
356 continue;
357 }
358
359 let topic = match get_topic_from_index_path(
360 &self.topic_index_path,
361 self.topic.as_deref(),
362 &row,
363 ) {
364 Some(s) => s,
365 None => {
366 tracing::error!("topic field not found in row, skipping: {:?}", row);
367 return Ok(());
368 }
369 };
370
371 let v = encoder.encode(row)?;
372
373 self.client
374 .publish(topic, self.qos, self.retain, v)
375 .await
376 .context("mqtt sink error")
377 .map_err(SinkError::Mqtt)?;
378 }
379
380 Ok(())
381 }
382}
383
384fn get_topic_from_index_path<'s>(
385 path: &[usize],
386 default_topic: Option<&'s str>,
387 row: &'s RowRef<'s>,
388) -> Option<&'s str> {
389 if let Some(topic) = default_topic
390 && path.is_empty()
391 {
392 Some(topic)
393 } else {
394 let mut iter = path.iter();
395 let scalar = iter
396 .next()
397 .and_then(|pos| row.datum_at(*pos))
398 .and_then(|d| {
399 iter.try_fold(d, |d, pos| match d {
400 ScalarRefImpl::Struct(struct_ref) => {
401 struct_ref.iter_fields_ref().nth(*pos).flatten()
402 }
403 _ => None,
404 })
405 });
406 match scalar {
407 Some(ScalarRefImpl::Utf8(s)) => Some(s),
408 _ => {
409 if let Some(topic) = default_topic {
410 Some(topic)
411 } else {
412 None
413 }
414 }
415 }
416 }
417}
418
419fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
423 let mut iter = topic_field.split('.');
424 let mut path = vec![];
425 let dt =
426 iter.next()
427 .and_then(|field| {
428 schema
430 .fields()
431 .iter()
432 .enumerate()
433 .find(|(_, f)| f.name == field)
434 .map(|(pos, f)| {
435 path.push(pos);
436 &f.data_type
437 })
438 })
439 .and_then(|dt| {
440 iter.try_fold(dt, |dt, field| match dt {
442 DataType::Struct(st) => {
443 st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
444 |(pos, (_, dt))| {
445 path.push(pos);
446 dt
447 },
448 )
449 }
450 _ => None,
451 })
452 });
453
454 match dt {
455 Some(DataType::Varchar) => Ok(path),
456 Some(dt) => Err(SinkError::Config(anyhow!(
457 "topic field `{}` must be of type string but got {:?}",
458 topic_field,
459 dt
460 ))),
461 None => Err(SinkError::Config(anyhow!(
462 "topic field `{}` not found",
463 topic_field
464 ))),
465 }
466}
467
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    /// A top-level varchar column resolves to a one-element path.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    /// A varchar field inside a struct column resolves to `[column, field]`.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    /// Null topic values fall back to the default topic (or `None` without one).
    #[test]
    fn test_null_values_extraction() {
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    /// Deeply nested paths resolve, and only varchar leaves are accepted.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }

    /// The default topic is used for an empty path and for non-string values, but a
    /// resolved string always takes precedence over it.
    #[test]
    fn test_default_topic_fallback() {
        let chunk = DataChunk::from_pretty(
            "i T
            1 other",
        );
        let row = RowRef::new(&chunk, 0);

        assert_eq!(
            get_topic_from_index_path(&[], Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&[], None, &row), None);

        assert_eq!(
            get_topic_from_index_path(&[0], Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&[0], None, &row), None);

        assert_eq!(
            get_topic_from_index_path(&[1], Some("default"), &row),
            Some("other")
        );
    }
}