1use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::SinkWriterParam;
32use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
33use super::encoder::{
34 DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
35 TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
36};
37use super::writer::AsyncTruncateSinkWriterExt;
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::sink::log_store::DeliveryFutureManagerAddFuture;
42use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
43use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
44
45pub const MQTT_SINK: &str = "mqtt";
46
// Options accepted by the MQTT sink in its `WITH (...)` clause.
//
// NOTE(review): plain `//` comments are used on this struct and its fields on purpose, so that
// any tooling that derives options metadata from `///` doc comments of `WithOptions` structs
// is not affected — confirm before converting them to doc comments.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Connection settings shared with other MQTT connectors, flattened into the top-level
    // options; the writer calls `qos()` and `build_client()` on it.
    #[serde(flatten)]
    pub common: MqttCommon,

    // Fixed topic to publish to. Used for every row when `topic.field` is unset, and as the
    // fallback when a row's topic field does not resolve to a string. Also used as the
    // subject name when fetching a protobuf descriptor (the sink name is used otherwise).
    pub topic: Option<String>,

    // Retain flag set on every published message; deserialized with
    // `deserialize_bool_from_string` and `false` when the option is absent.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type; `MqttConfig::from_btreemap` rejects any value other than
    // `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    // Option key `topic.field`: dot-separated path to a varchar column (or a varchar field
    // nested inside struct columns) whose per-row value is used as the topic.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
67
68impl EnforceSecret for MqttConfig {
69 fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
70 MqttCommon::enforce_one(prop)
71 }
72}
73
/// Row encoder chosen from the sink's `ENCODE` clause when the writer is created.
///
/// Both variants produce their encoder-specific output, which the [`RowEncoder`] impl
/// serializes to `Vec<u8>` for publishing.
pub enum RowEncoderWrapper {
    /// Selected for `SinkEncode::Json`.
    Json(JsonEncoder),
    /// Selected for `SinkEncode::Protobuf`.
    Proto(ProtoEncoder),
}
78
79impl RowEncoder for RowEncoderWrapper {
80 type Output = Vec<u8>;
81
82 fn encode_cols(
83 &self,
84 row: impl Row,
85 col_indices: impl Iterator<Item = usize>,
86 ) -> Result<Self::Output> {
87 match self {
88 RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
89 RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
90 }
91 }
92
93 fn schema(&self) -> &Schema {
94 match self {
95 RowEncoderWrapper::Json(json) => json.schema(),
96 RowEncoderWrapper::Proto(proto) => proto.schema(),
97 }
98 }
99
100 fn col_indices(&self) -> Option<&[usize]> {
101 match self {
102 RowEncoderWrapper::Json(json) => json.col_indices(),
103 RowEncoderWrapper::Proto(proto) => proto.col_indices(),
104 }
105 }
106
107 fn encode(&self, row: impl Row) -> Result<Self::Output> {
108 match self {
109 RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
110 RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
111 }
112 }
113}
114
/// MQTT sink built from catalog [`SinkParam`]s; only append-only sinks pass `validate`.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed `WITH` options.
    pub config: MqttConfig,
    /// Schema of the rows written to the sink.
    schema: Schema,
    /// `FORMAT ... ENCODE ...` clause used to build the row encoder.
    format_desc: SinkFormatDesc,
    /// Whether the sink was declared append-only; checked in `validate`.
    is_append_only: bool,
    /// Sink name; used as the protobuf schema subject when no fixed `topic` is configured.
    name: String,
}
123
124impl EnforceSecret for MqttSink {
125 fn enforce_secret<'a>(
126 prop_iter: impl Iterator<Item = &'a str>,
127 ) -> crate::error::ConnectorResult<()> {
128 for prop in prop_iter {
129 MqttConfig::enforce_one(prop)?;
130 }
131 Ok(())
132 }
133}
134
/// Sink writer that encodes rows and publishes them over MQTT.
///
/// Dropping the writer sets `stopped`, the flag shared with the background event-loop task
/// spawned in [`MqttSinkWriter::new`].
pub struct MqttSinkWriter {
    /// Options the writer was created with.
    pub config: MqttConfig,
    /// Resolves topics and performs the actual `publish` calls.
    payload_writer: MqttSinkPayloadWriter,
    /// Kept from construction; not read afterwards.
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder applied to every published row.
    encoder: RowEncoderWrapper,
    /// Shutdown flag polled by the event-loop task between polls.
    stopped: Arc<AtomicBool>,
}
144
145impl MqttConfig {
147 pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
148 let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
149 .map_err(|e| SinkError::Config(anyhow!(e)))?;
150 if config.r#type != SINK_TYPE_APPEND_ONLY {
151 Err(SinkError::Config(anyhow!(
152 "MQTT sink only supports append-only mode"
153 )))
154 } else {
155 Ok(config)
156 }
157 }
158}
159
160impl TryFrom<SinkParam> for MqttSink {
161 type Error = SinkError;
162
163 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
164 let schema = param.schema();
165 let config = MqttConfig::from_btreemap(param.properties)?;
166 Ok(Self {
167 config,
168 schema,
169 name: param.sink_name,
170 format_desc: param
171 .format_desc
172 .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
173 is_append_only: param.sink_type.is_append_only(),
174 })
175 }
176}
177
178impl Sink for MqttSink {
179 type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
180
181 const SINK_NAME: &'static str = MQTT_SINK;
182
183 async fn validate(&self) -> Result<()> {
184 if !self.is_append_only {
185 return Err(SinkError::Mqtt(anyhow!(
186 "MQTT sink only supports append-only mode"
187 )));
188 }
189
190 if let Some(field) = &self.config.topic_field {
191 let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
192 } else if self.config.topic.is_none() {
193 return Err(SinkError::Config(anyhow!(
194 "either topic or topic.field must be set"
195 )));
196 }
197
198 let _client = (self.config.common.build_client(0, 0))
199 .context("validate mqtt sink error")
200 .map_err(SinkError::Mqtt)?;
201
202 Ok(())
203 }
204
205 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
206 Ok(MqttSinkWriter::new(
207 self.config.clone(),
208 self.schema.clone(),
209 &self.format_desc,
210 &self.name,
211 writer_param.executor_id,
212 )
213 .await?
214 .into_log_sinker(usize::MAX))
215 }
216}
217
218impl MqttSinkWriter {
219 pub async fn new(
220 config: MqttConfig,
221 schema: Schema,
222 format_desc: &SinkFormatDesc,
223 name: &str,
224 id: u64,
225 ) -> Result<Self> {
226 let mut topic_index_path = vec![];
227 if let Some(field) = &config.topic_field {
228 topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
229 }
230
231 let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
232 let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
233 let encoder = match format_desc.format {
234 SinkFormat::AppendOnly => match format_desc.encode {
235 SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
236 schema.clone(),
237 None,
238 DateHandlingMode::FromCe,
239 TimestampHandlingMode::Milli,
240 timestamptz_mode,
241 TimeHandlingMode::Milli,
242 jsonb_handling_mode,
243 )),
244 SinkEncode::Protobuf => {
245 let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
246 &format_desc.options,
247 config.topic.as_deref().unwrap_or(name),
248 None,
249 )
250 .await
251 .map_err(|e| SinkError::Config(anyhow!(e)))?;
252 let header = match sid {
253 None => ProtoHeader::None,
254 Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
255 };
256 RowEncoderWrapper::Proto(ProtoEncoder::new(
257 schema.clone(),
258 None,
259 descriptor,
260 header,
261 )?)
262 }
263 _ => {
264 return Err(SinkError::Config(anyhow!(
265 "mqtt sink encode unsupported: {:?}",
266 format_desc.encode,
267 )));
268 }
269 },
270 _ => {
271 return Err(SinkError::Config(anyhow!(
272 "MQTT sink only supports append-only mode"
273 )));
274 }
275 };
276 let qos = config.common.qos();
277
278 let (client, mut eventloop) = config
279 .common
280 .build_client(0, id)
281 .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
282
283 let stopped = Arc::new(AtomicBool::new(false));
284 let stopped_clone = stopped.clone();
285 tokio::spawn(async move {
286 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
287 match eventloop.poll().await {
288 Ok(_) => (),
289 Err(err) => match err {
290 ConnectionError::Timeout(_) => (),
291 ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
292 | ConnectionError::Io(err)
293 if err.kind() == std::io::ErrorKind::ConnectionAborted
294 || err.kind() == std::io::ErrorKind::ConnectionReset =>
295 {
296 continue;
297 }
298 err => {
299 tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
300 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
301 }
302 },
303 }
304 }
305 });
306
307 let payload_writer = MqttSinkPayloadWriter {
308 topic: config.topic.clone(),
309 client,
310 qos,
311 retain: config.retain,
312 topic_index_path,
313 };
314
315 Ok::<_, SinkError>(Self {
316 config: config.clone(),
317 payload_writer,
318 schema: schema.clone(),
319 stopped,
320 encoder,
321 })
322 }
323}
324
325impl AsyncTruncateSinkWriter for MqttSinkWriter {
326 async fn write_chunk<'a>(
327 &'a mut self,
328 chunk: StreamChunk,
329 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
330 ) -> Result<()> {
331 self.payload_writer.write_chunk(chunk, &self.encoder).await
332 }
333}
334
335impl Drop for MqttSinkWriter {
336 fn drop(&mut self) {
337 self.stopped
338 .store(true, std::sync::atomic::Ordering::Relaxed);
339 }
340}
341
/// Publishes encoded rows with a fixed QoS and retain flag.
struct MqttSinkPayloadWriter {
    /// Client handle; its event loop is polled by the task spawned in `MqttSinkWriter::new`.
    client: rumqttc::v5::AsyncClient,
    /// Fixed topic, used when `topic_index_path` is empty or does not resolve for a row.
    topic: Option<String>,
    /// QoS used for every publish.
    qos: QoS,
    /// Retain flag used for every publish.
    retain: bool,
    /// Index path to the per-row topic field; empty when `topic.field` is not configured.
    topic_index_path: Vec<usize>,
}
350
351impl MqttSinkPayloadWriter {
352 async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
353 for (op, row) in chunk.rows() {
354 if op != Op::Insert {
355 continue;
356 }
357
358 let topic = match get_topic_from_index_path(
359 &self.topic_index_path,
360 self.topic.as_deref(),
361 &row,
362 ) {
363 Some(s) => s,
364 None => {
365 tracing::error!("topic field not found in row, skipping: {:?}", row);
366 return Ok(());
367 }
368 };
369
370 let v = encoder.encode(row)?;
371
372 self.client
373 .publish(topic, self.qos, self.retain, v)
374 .await
375 .context("mqtt sink error")
376 .map_err(SinkError::Mqtt)?;
377 }
378
379 Ok(())
380 }
381}
382
383fn get_topic_from_index_path<'s>(
384 path: &[usize],
385 default_topic: Option<&'s str>,
386 row: &'s RowRef<'s>,
387) -> Option<&'s str> {
388 if let Some(topic) = default_topic
389 && path.is_empty()
390 {
391 Some(topic)
392 } else {
393 let mut iter = path.iter();
394 let scalar = iter
395 .next()
396 .and_then(|pos| row.datum_at(*pos))
397 .and_then(|d| {
398 iter.try_fold(d, |d, pos| match d {
399 ScalarRefImpl::Struct(struct_ref) => {
400 struct_ref.iter_fields_ref().nth(*pos).flatten()
401 }
402 _ => None,
403 })
404 });
405 match scalar {
406 Some(ScalarRefImpl::Utf8(s)) => Some(s),
407 _ => {
408 if let Some(topic) = default_topic {
409 Some(topic)
410 } else {
411 None
412 }
413 }
414 }
415 }
416}
417
418fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
422 let mut iter = topic_field.split('.');
423 let mut path = vec![];
424 let dt =
425 iter.next()
426 .and_then(|field| {
427 schema
429 .fields()
430 .iter()
431 .enumerate()
432 .find(|(_, f)| f.name == field)
433 .map(|(pos, f)| {
434 path.push(pos);
435 &f.data_type
436 })
437 })
438 .and_then(|dt| {
439 iter.try_fold(dt, |dt, field| match dt {
441 DataType::Struct(st) => {
442 st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
443 |(pos, (_, dt))| {
444 path.push(pos);
445 dt
446 },
447 )
448 }
449 _ => None,
450 })
451 });
452
453 match dt {
454 Some(DataType::Varchar) => Ok(path),
455 Some(dt) => Err(SinkError::Config(anyhow!(
456 "topic field `{}` must be of type string but got {:?}",
457 topic_field,
458 dt
459 ))),
460 None => Err(SinkError::Config(anyhow!(
461 "topic field `{}` not found",
462 topic_field
463 ))),
464 }
465}
466
#[cfg(test)]
mod test {
    // Tests for resolving `topic.field` paths against a schema, and for extracting the topic
    // string (or the default topic) from rows.

    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a one-element path and yields its value.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
             test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // Unknown column names are rejected.
        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field inside a struct column resolves to `[column, field]`.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
             (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // Unknown nested field names are rejected.
        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // Null topic values fall back to the default topic, or to `None` without one.
    #[test]
    fn test_null_values_extraction() {
        // Null top-level column.
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
             .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        // Null field inside a struct column.
        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
             (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Deeply nested paths: only varchar leaves are accepted, and resolution follows indices.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        // Non-varchar leaves are rejected.
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
             (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // This chunk's inner struct has a single varchar field (unlike `schema` above), so
        // the path is written by hand.
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // Field 0 of the outer struct is an int, not a string.
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // Top-level varchar column.
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}