1use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
32use super::encoder::{
33 DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
34 TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
35};
36use super::writer::AsyncTruncateSinkWriterExt;
37use super::{DummySinkCommitCoordinator, SinkWriterParam};
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::sink::log_store::DeliveryFutureManagerAddFuture;
41use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
42use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
43
/// Connector name of the MQTT sink; exposed as `Sink::SINK_NAME` by [`MqttSink`].
pub const MQTT_SINK: &str = "mqtt";
45
/// User-facing options of the MQTT sink.
///
/// Built from the sink's string properties by [`MqttConfig::from_btreemap`].
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    /// Connection-level settings; flattened, so its keys live next to the options below.
    #[serde(flatten)]
    pub common: MqttCommon,

    /// Fixed topic to publish to. It is used when no `topic.field` is configured, and as the
    /// fallback when the configured topic field of a row does not hold a string.
    /// Validation requires that either this option or `topic.field` is set.
    pub topic: Option<String>,

    /// Value handed to the client as the `retain` argument of every publish.
    /// Parsed from a string; defaults to `false` when the option is absent.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    /// Sink type. Only `SINK_TYPE_APPEND_ONLY` is accepted by [`MqttConfig::from_btreemap`].
    pub r#type: String,

    /// Dot-separated path (e.g. `a.b.c`) to the varchar column, or nested struct field, that
    /// holds the topic of each row.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
66
/// The row encoders the MQTT sink can use, unified behind a single [`RowEncoder`]
/// implementation whose output is raw bytes.
pub enum RowEncoderWrapper {
    /// Rows are encoded with a [`JsonEncoder`].
    Json(JsonEncoder),
    /// Rows are encoded with a [`ProtoEncoder`].
    Proto(ProtoEncoder),
}
71
72impl RowEncoder for RowEncoderWrapper {
73 type Output = Vec<u8>;
74
75 fn encode_cols(
76 &self,
77 row: impl Row,
78 col_indices: impl Iterator<Item = usize>,
79 ) -> Result<Self::Output> {
80 match self {
81 RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
82 RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
83 }
84 }
85
86 fn schema(&self) -> &Schema {
87 match self {
88 RowEncoderWrapper::Json(json) => json.schema(),
89 RowEncoderWrapper::Proto(proto) => proto.schema(),
90 }
91 }
92
93 fn col_indices(&self) -> Option<&[usize]> {
94 match self {
95 RowEncoderWrapper::Json(json) => json.col_indices(),
96 RowEncoderWrapper::Proto(proto) => proto.col_indices(),
97 }
98 }
99
100 fn encode(&self, row: impl Row) -> Result<Self::Output> {
101 match self {
102 RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
103 RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
104 }
105 }
106}
107
/// MQTT sink definition, built from a [`SinkParam`] via `TryFrom`.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink options.
    pub config: MqttConfig,
    /// Schema of the rows delivered to the sink.
    schema: Schema,
    /// The `FORMAT ... ENCODE ...` description of the sink.
    format_desc: SinkFormatDesc,
    /// Whether the sink type is append-only; `validate` rejects the sink otherwise.
    is_append_only: bool,
    /// Name of the sink.
    name: String,
}
116
/// Writer that encodes stream chunks and publishes them through an MQTT client.
pub struct MqttSinkWriter {
    /// Sink options this writer was created with.
    pub config: MqttConfig,
    /// Owns the MQTT client and performs the actual publishing.
    payload_writer: MqttSinkPayloadWriter,
    /// Schema of the incoming rows; stored but currently not read.
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder that turns a row into the published payload.
    encoder: RowEncoderWrapper,
    /// Stop flag shared with the background task that polls the MQTT event loop;
    /// set to `true` when the writer is dropped.
    stopped: Arc<AtomicBool>,
}
126
127impl MqttConfig {
129 pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
130 let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
131 .map_err(|e| SinkError::Config(anyhow!(e)))?;
132 if config.r#type != SINK_TYPE_APPEND_ONLY {
133 Err(SinkError::Config(anyhow!(
134 "MQTT sink only supports append-only mode"
135 )))
136 } else {
137 Ok(config)
138 }
139 }
140}
141
142impl TryFrom<SinkParam> for MqttSink {
143 type Error = SinkError;
144
145 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
146 let schema = param.schema();
147 let config = MqttConfig::from_btreemap(param.properties)?;
148 Ok(Self {
149 config,
150 schema,
151 name: param.sink_name,
152 format_desc: param
153 .format_desc
154 .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
155 is_append_only: param.sink_type.is_append_only(),
156 })
157 }
158}
159
160impl Sink for MqttSink {
161 type Coordinator = DummySinkCommitCoordinator;
162 type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
163
164 const SINK_NAME: &'static str = MQTT_SINK;
165
166 async fn validate(&self) -> Result<()> {
167 if !self.is_append_only {
168 return Err(SinkError::Mqtt(anyhow!(
169 "MQTT sink only supports append-only mode"
170 )));
171 }
172
173 if let Some(field) = &self.config.topic_field {
174 let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
175 } else if self.config.topic.is_none() {
176 return Err(SinkError::Config(anyhow!(
177 "either topic or topic.field must be set"
178 )));
179 }
180
181 let _client = (self.config.common.build_client(0, 0))
182 .context("validate mqtt sink error")
183 .map_err(SinkError::Mqtt)?;
184
185 Ok(())
186 }
187
188 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
189 Ok(MqttSinkWriter::new(
190 self.config.clone(),
191 self.schema.clone(),
192 &self.format_desc,
193 &self.name,
194 writer_param.executor_id,
195 )
196 .await?
197 .into_log_sinker(usize::MAX))
198 }
199}
200
201impl MqttSinkWriter {
202 pub async fn new(
203 config: MqttConfig,
204 schema: Schema,
205 format_desc: &SinkFormatDesc,
206 name: &str,
207 id: u64,
208 ) -> Result<Self> {
209 let mut topic_index_path = vec![];
210 if let Some(field) = &config.topic_field {
211 topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
212 }
213
214 let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
215 let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
216 let encoder = match format_desc.format {
217 SinkFormat::AppendOnly => match format_desc.encode {
218 SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
219 schema.clone(),
220 None,
221 DateHandlingMode::FromCe,
222 TimestampHandlingMode::Milli,
223 timestamptz_mode,
224 TimeHandlingMode::Milli,
225 jsonb_handling_mode,
226 )),
227 SinkEncode::Protobuf => {
228 let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
229 &format_desc.options,
230 config.topic.as_deref().unwrap_or(name),
231 None,
232 )
233 .await
234 .map_err(|e| SinkError::Config(anyhow!(e)))?;
235 let header = match sid {
236 None => ProtoHeader::None,
237 Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
238 };
239 RowEncoderWrapper::Proto(ProtoEncoder::new(
240 schema.clone(),
241 None,
242 descriptor,
243 header,
244 )?)
245 }
246 _ => {
247 return Err(SinkError::Config(anyhow!(
248 "mqtt sink encode unsupported: {:?}",
249 format_desc.encode,
250 )));
251 }
252 },
253 _ => {
254 return Err(SinkError::Config(anyhow!(
255 "MQTT sink only supports append-only mode"
256 )));
257 }
258 };
259 let qos = config.common.qos();
260
261 let (client, mut eventloop) = config
262 .common
263 .build_client(0, id)
264 .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
265
266 let stopped = Arc::new(AtomicBool::new(false));
267 let stopped_clone = stopped.clone();
268 tokio::spawn(async move {
269 while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
270 match eventloop.poll().await {
271 Ok(_) => (),
272 Err(err) => match err {
273 ConnectionError::Timeout(_) => (),
274 ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
275 | ConnectionError::Io(err)
276 if err.kind() == std::io::ErrorKind::ConnectionAborted
277 || err.kind() == std::io::ErrorKind::ConnectionReset =>
278 {
279 continue;
280 }
281 err => {
282 tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
283 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
284 }
285 },
286 }
287 }
288 });
289
290 let payload_writer = MqttSinkPayloadWriter {
291 topic: config.topic.clone(),
292 client,
293 qos,
294 retain: config.retain,
295 topic_index_path,
296 };
297
298 Ok::<_, SinkError>(Self {
299 config: config.clone(),
300 payload_writer,
301 schema: schema.clone(),
302 stopped,
303 encoder,
304 })
305 }
306}
307
308impl AsyncTruncateSinkWriter for MqttSinkWriter {
309 async fn write_chunk<'a>(
310 &'a mut self,
311 chunk: StreamChunk,
312 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
313 ) -> Result<()> {
314 self.payload_writer.write_chunk(chunk, &self.encoder).await
315 }
316}
317
318impl Drop for MqttSinkWriter {
319 fn drop(&mut self) {
320 self.stopped
321 .store(true, std::sync::atomic::Ordering::Relaxed);
322 }
323}
324
/// Publishes encoded rows through an MQTT client.
struct MqttSinkPayloadWriter {
    /// Client used to publish messages.
    client: rumqttc::v5::AsyncClient,
    /// Fixed or fallback topic, if one is configured.
    topic: Option<String>,
    /// Quality-of-service level passed to every publish.
    qos: QoS,
    /// Retain flag passed to every publish.
    retain: bool,
    /// Index path of the topic field inside a row; empty when no topic field is configured.
    topic_index_path: Vec<usize>,
}
333
334impl MqttSinkPayloadWriter {
335 async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
336 for (op, row) in chunk.rows() {
337 if op != Op::Insert {
338 continue;
339 }
340
341 let topic = match get_topic_from_index_path(
342 &self.topic_index_path,
343 self.topic.as_deref(),
344 &row,
345 ) {
346 Some(s) => s,
347 None => {
348 tracing::error!("topic field not found in row, skipping: {:?}", row);
349 return Ok(());
350 }
351 };
352
353 let v = encoder.encode(row)?;
354
355 self.client
356 .publish(topic, self.qos, self.retain, v)
357 .await
358 .context("mqtt sink error")
359 .map_err(SinkError::Mqtt)?;
360 }
361
362 Ok(())
363 }
364}
365
366fn get_topic_from_index_path<'s>(
367 path: &[usize],
368 default_topic: Option<&'s str>,
369 row: &'s RowRef<'s>,
370) -> Option<&'s str> {
371 if let Some(topic) = default_topic
372 && path.is_empty()
373 {
374 Some(topic)
375 } else {
376 let mut iter = path.iter();
377 let scalar = iter
378 .next()
379 .and_then(|pos| row.datum_at(*pos))
380 .and_then(|d| {
381 iter.try_fold(d, |d, pos| match d {
382 ScalarRefImpl::Struct(struct_ref) => {
383 struct_ref.iter_fields_ref().nth(*pos).flatten()
384 }
385 _ => None,
386 })
387 });
388 match scalar {
389 Some(ScalarRefImpl::Utf8(s)) => Some(s),
390 _ => {
391 if let Some(topic) = default_topic {
392 Some(topic)
393 } else {
394 None
395 }
396 }
397 }
398 }
399}
400
401fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
405 let mut iter = topic_field.split('.');
406 let mut path = vec![];
407 let dt =
408 iter.next()
409 .and_then(|field| {
410 schema
412 .fields()
413 .iter()
414 .enumerate()
415 .find(|(_, f)| f.name == field)
416 .map(|(pos, f)| {
417 path.push(pos);
418 &f.data_type
419 })
420 })
421 .and_then(|dt| {
422 iter.try_fold(dt, |dt, field| match dt {
424 DataType::Struct(st) => {
425 st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
426 |(pos, (_, dt))| {
427 path.push(pos);
428 dt
429 },
430 )
431 }
432 _ => None,
433 })
434 });
435
436 match dt {
437 Some(DataType::Varchar) => Ok(path),
438 Some(dt) => Err(SinkError::Config(anyhow!(
439 "topic field `{}` must be of type string but got {:?}",
440 topic_field,
441 dt
442 ))),
443 None => Err(SinkError::Config(anyhow!(
444 "topic field `{}` not found",
445 topic_field
446 ))),
447 }
448}
449
// Unit tests for the topic-field helpers: resolving a dotted field path into an index path,
// and reading the topic of a row through such a path.
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a single-element path and yields the row's value.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
 test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // A field name that is not in the schema is rejected.
        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field nested one level inside a struct column resolves to a two-element path.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
 (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        // A nested field name that does not exist in the struct is rejected.
        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // When the value at the path is not a string, the default topic is used if there is one,
    // otherwise no topic is returned.
    #[test]
    fn test_null_values_extraction() {
        // Top-level column (the `.` presumably denotes a null value in the pretty format).
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
 .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        // Nested field (the empty second struct field is presumably a null value).
        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
 (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Paths through several struct levels, plus type checking of the final field.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        // Fields that exist but are not varchar are rejected.
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // Note: this chunk's inner struct has a single string field, so its layout differs
        // from `schema` above; the index paths below are written by hand to match the chunk.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
 (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // Column 0 -> struct field 1 -> nested field 0 is the string "test".
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // Column 0 -> struct field 0 is an integer, so no topic is produced without a default.
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // Column 1 is a top-level string column.
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}