risingwave_connector/sink/
mqtt.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
32use super::encoder::{
33    DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
34    TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
35};
36use super::writer::AsyncTruncateSinkWriterExt;
37use super::{DummySinkCommitCoordinator, SinkWriterParam};
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::sink::log_store::DeliveryFutureManagerAddFuture;
41use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
42use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
43
/// Connector name of this sink; exposed as `MqttSink::SINK_NAME`.
pub const MQTT_SINK: &str = "mqtt";
45
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Connection settings shared with other MQTT connectors, flattened into the same option
    // namespace. This sink reads `qos()` and `build_client()` from it.
    #[serde(flatten)]
    pub common: MqttCommon,

    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    // For this sink: the static publish topic. When `topic.field` is also set, this is the
    // fallback for rows whose field value is NULL or not a string. It is also the name passed
    // to the protobuf descriptor lookup (the sink name is used when this is unset).
    pub topic: Option<String>,

    /// Whether the message should be retained by the broker
    // Defaults to `false` when the option is absent (`#[serde(default)]`).
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type. Only "append-only" is accepted: `MqttConfig::from_btreemap` rejects any value
    // other than `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    // If set, a field value of each row is used as the topic name; `topic`, if also set, is the
    // fallback. The value is a dot-separated path (e.g. `a.b.c`) that may descend into struct
    // columns, and the field it names must be `varchar`. Read from the `topic.field` option key.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
66
/// Row encoder used by the MQTT sink, chosen once from the sink's `ENCODE` option.
///
/// Both variants produce raw bytes through this type's `RowEncoder` impl, so the writer can
/// publish payloads without knowing the concrete encoding.
pub enum RowEncoderWrapper {
    /// `ENCODE JSON`.
    Json(JsonEncoder),
    /// `ENCODE PROTOBUF`.
    Proto(ProtoEncoder),
}
71
72impl RowEncoder for RowEncoderWrapper {
73    type Output = Vec<u8>;
74
75    fn encode_cols(
76        &self,
77        row: impl Row,
78        col_indices: impl Iterator<Item = usize>,
79    ) -> Result<Self::Output> {
80        match self {
81            RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
82            RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
83        }
84    }
85
86    fn schema(&self) -> &Schema {
87        match self {
88            RowEncoderWrapper::Json(json) => json.schema(),
89            RowEncoderWrapper::Proto(proto) => proto.schema(),
90        }
91    }
92
93    fn col_indices(&self) -> Option<&[usize]> {
94        match self {
95            RowEncoderWrapper::Json(json) => json.col_indices(),
96            RowEncoderWrapper::Proto(proto) => proto.col_indices(),
97        }
98    }
99
100    fn encode(&self, row: impl Row) -> Result<Self::Output> {
101        match self {
102            RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
103            RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
104        }
105    }
106}
107
/// MQTT sink built from a sink's `WITH` options and its `FORMAT ... ENCODE ...` clause.
///
/// Only append-only sinks are accepted. This is checked both when the options are parsed
/// (`MqttConfig::from_btreemap`) and in `validate`.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed `WITH` options.
    pub config: MqttConfig,
    /// Schema of the rows written to this sink.
    schema: Schema,
    /// `FORMAT`/`ENCODE` description; selects the row encoder.
    format_desc: SinkFormatDesc,
    /// Whether `SinkParam::sink_type` is append-only.
    is_append_only: bool,
    /// Sink name; passed to the protobuf descriptor lookup in place of `topic` when no `topic`
    /// is configured.
    name: String,
}
116
/// Sink writer that encodes every inserted row and publishes it through an MQTT client.
///
/// `MqttSinkWriter::new` spawns a background task that polls the client's event loop. The task
/// checks `stopped` between polls, and `Drop` sets that flag.
pub struct MqttSinkWriter {
    pub config: MqttConfig,
    /// Owns the MQTT client and the per-row topic-resolution state.
    payload_writer: MqttSinkPayloadWriter,
    /// Currently unused by the writer (hence the `dead_code` expectation).
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder selected from the sink's `ENCODE` option.
    encoder: RowEncoderWrapper,
    /// Stop flag shared with the event-loop task.
    stopped: Arc<AtomicBool>,
}
126
127/// Basic data types for use with the mqtt interface
128impl MqttConfig {
129    pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
130        let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
131            .map_err(|e| SinkError::Config(anyhow!(e)))?;
132        if config.r#type != SINK_TYPE_APPEND_ONLY {
133            Err(SinkError::Config(anyhow!(
134                "MQTT sink only supports append-only mode"
135            )))
136        } else {
137            Ok(config)
138        }
139    }
140}
141
142impl TryFrom<SinkParam> for MqttSink {
143    type Error = SinkError;
144
145    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
146        let schema = param.schema();
147        let config = MqttConfig::from_btreemap(param.properties)?;
148        Ok(Self {
149            config,
150            schema,
151            name: param.sink_name,
152            format_desc: param
153                .format_desc
154                .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
155            is_append_only: param.sink_type.is_append_only(),
156        })
157    }
158}
159
160impl Sink for MqttSink {
161    type Coordinator = DummySinkCommitCoordinator;
162    type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
163
164    const SINK_NAME: &'static str = MQTT_SINK;
165
166    async fn validate(&self) -> Result<()> {
167        if !self.is_append_only {
168            return Err(SinkError::Mqtt(anyhow!(
169                "MQTT sink only supports append-only mode"
170            )));
171        }
172
173        if let Some(field) = &self.config.topic_field {
174            let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
175        } else if self.config.topic.is_none() {
176            return Err(SinkError::Config(anyhow!(
177                "either topic or topic.field must be set"
178            )));
179        }
180
181        let _client = (self.config.common.build_client(0, 0))
182            .context("validate mqtt sink error")
183            .map_err(SinkError::Mqtt)?;
184
185        Ok(())
186    }
187
188    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
189        Ok(MqttSinkWriter::new(
190            self.config.clone(),
191            self.schema.clone(),
192            &self.format_desc,
193            &self.name,
194            writer_param.executor_id,
195        )
196        .await?
197        .into_log_sinker(usize::MAX))
198    }
199}
200
201impl MqttSinkWriter {
202    pub async fn new(
203        config: MqttConfig,
204        schema: Schema,
205        format_desc: &SinkFormatDesc,
206        name: &str,
207        id: u64,
208    ) -> Result<Self> {
209        let mut topic_index_path = vec![];
210        if let Some(field) = &config.topic_field {
211            topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
212        }
213
214        let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
215        let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
216        let encoder = match format_desc.format {
217            SinkFormat::AppendOnly => match format_desc.encode {
218                SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
219                    schema.clone(),
220                    None,
221                    DateHandlingMode::FromCe,
222                    TimestampHandlingMode::Milli,
223                    timestamptz_mode,
224                    TimeHandlingMode::Milli,
225                    jsonb_handling_mode,
226                )),
227                SinkEncode::Protobuf => {
228                    let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
229                        &format_desc.options,
230                        config.topic.as_deref().unwrap_or(name),
231                        None,
232                    )
233                    .await
234                    .map_err(|e| SinkError::Config(anyhow!(e)))?;
235                    let header = match sid {
236                        None => ProtoHeader::None,
237                        Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
238                    };
239                    RowEncoderWrapper::Proto(ProtoEncoder::new(
240                        schema.clone(),
241                        None,
242                        descriptor,
243                        header,
244                    )?)
245                }
246                _ => {
247                    return Err(SinkError::Config(anyhow!(
248                        "mqtt sink encode unsupported: {:?}",
249                        format_desc.encode,
250                    )));
251                }
252            },
253            _ => {
254                return Err(SinkError::Config(anyhow!(
255                    "MQTT sink only supports append-only mode"
256                )));
257            }
258        };
259        let qos = config.common.qos();
260
261        let (client, mut eventloop) = config
262            .common
263            .build_client(0, id)
264            .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
265
266        let stopped = Arc::new(AtomicBool::new(false));
267        let stopped_clone = stopped.clone();
268        tokio::spawn(async move {
269            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
270                match eventloop.poll().await {
271                    Ok(_) => (),
272                    Err(err) => match err {
273                        ConnectionError::Timeout(_) => (),
274                        ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
275                        | ConnectionError::Io(err)
276                            if err.kind() == std::io::ErrorKind::ConnectionAborted
277                                || err.kind() == std::io::ErrorKind::ConnectionReset =>
278                        {
279                            continue;
280                        }
281                        err => {
282                            tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
283                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
284                        }
285                    },
286                }
287            }
288        });
289
290        let payload_writer = MqttSinkPayloadWriter {
291            topic: config.topic.clone(),
292            client,
293            qos,
294            retain: config.retain,
295            topic_index_path,
296        };
297
298        Ok::<_, SinkError>(Self {
299            config: config.clone(),
300            payload_writer,
301            schema: schema.clone(),
302            stopped,
303            encoder,
304        })
305    }
306}
307
308impl AsyncTruncateSinkWriter for MqttSinkWriter {
309    async fn write_chunk<'a>(
310        &'a mut self,
311        chunk: StreamChunk,
312        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
313    ) -> Result<()> {
314        self.payload_writer.write_chunk(chunk, &self.encoder).await
315    }
316}
317
318impl Drop for MqttSinkWriter {
319    fn drop(&mut self) {
320        self.stopped
321            .store(true, std::sync::atomic::Ordering::Relaxed);
322    }
323}
324
/// Publishing state for one sink writer: the MQTT client plus the per-message settings.
struct MqttSinkPayloadWriter {
    // Connection to MQTT, one per executor (built in `MqttSinkWriter::new` with the executor id).
    client: rumqttc::v5::AsyncClient,
    // Static topic from the `topic` option. It is also the fallback when `topic.field` yields
    // no string.
    topic: Option<String>,
    // QoS applied to every publish.
    qos: QoS,
    // Retain flag applied to every publish.
    retain: bool,
    // Column/struct-field index path to the `topic.field` value; empty when it is not set.
    topic_index_path: Vec<usize>,
}
333
334impl MqttSinkPayloadWriter {
335    async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
336        for (op, row) in chunk.rows() {
337            if op != Op::Insert {
338                continue;
339            }
340
341            let topic = match get_topic_from_index_path(
342                &self.topic_index_path,
343                self.topic.as_deref(),
344                &row,
345            ) {
346                Some(s) => s,
347                None => {
348                    tracing::error!("topic field not found in row, skipping: {:?}", row);
349                    return Ok(());
350                }
351            };
352
353            let v = encoder.encode(row)?;
354
355            self.client
356                .publish(topic, self.qos, self.retain, v)
357                .await
358                .context("mqtt sink error")
359                .map_err(SinkError::Mqtt)?;
360        }
361
362        Ok(())
363    }
364}
365
366fn get_topic_from_index_path<'s>(
367    path: &[usize],
368    default_topic: Option<&'s str>,
369    row: &'s RowRef<'s>,
370) -> Option<&'s str> {
371    if let Some(topic) = default_topic
372        && path.is_empty()
373    {
374        Some(topic)
375    } else {
376        let mut iter = path.iter();
377        let scalar = iter
378            .next()
379            .and_then(|pos| row.datum_at(*pos))
380            .and_then(|d| {
381                iter.try_fold(d, |d, pos| match d {
382                    ScalarRefImpl::Struct(struct_ref) => {
383                        struct_ref.iter_fields_ref().nth(*pos).flatten()
384                    }
385                    _ => None,
386                })
387            });
388        match scalar {
389            Some(ScalarRefImpl::Utf8(s)) => Some(s),
390            _ => {
391                if let Some(topic) = default_topic {
392                    Some(topic)
393                } else {
394                    None
395                }
396            }
397        }
398    }
399}
400
401// This function returns the index path to the topic field in the schema, validating that the field exists and is of type string
402// the returnent path can be used to extract the topic field from a row. The path is a list of indexes to be used to navigate the row
403// to the topic field.
404fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
405    let mut iter = topic_field.split('.');
406    let mut path = vec![];
407    let dt =
408        iter.next()
409            .and_then(|field| {
410                // Extract the field from the schema
411                schema
412                    .fields()
413                    .iter()
414                    .enumerate()
415                    .find(|(_, f)| f.name == field)
416                    .map(|(pos, f)| {
417                        path.push(pos);
418                        &f.data_type
419                    })
420            })
421            .and_then(|dt| {
422                // Iterate over the next fields to extract the fields from the nested structs
423                iter.try_fold(dt, |dt, field| match dt {
424                    DataType::Struct(st) => {
425                        st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
426                            |(pos, (_, dt))| {
427                                path.push(pos);
428                                dt
429                            },
430                        )
431                    }
432                    _ => None,
433                })
434            });
435
436    match dt {
437        Some(DataType::Varchar) => Ok(path),
438        Some(dt) => Err(SinkError::Config(anyhow!(
439            "topic field `{}` must be of type string but got {:?}",
440            topic_field,
441            dt
442        ))),
443        None => Err(SinkError::Config(anyhow!(
444            "topic field `{}`  not found",
445            topic_field
446        ))),
447    }
448}
449
// Unit tests for topic-field resolution.
// - `get_topic_field_index_path` turns a dotted field name into an index path using the schema.
// - `get_topic_from_index_path` follows such a path through a concrete row.
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a one-element path; unknown names are rejected.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field inside a struct column resolves to [column, field]; unknown nested names
    // are rejected.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // NULL topic values (a NULL column, or a NULL struct field) fall back to the default topic,
    // or yield `None` when there is no default. As the test name suggests, `.` and an empty
    // struct slot are the pretty-format spellings of NULL here.
    #[test]
    fn test_null_values_extraction() {
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Multi-level structs: only a varchar leaf is accepted as a topic field. Row extraction
    // returns `None` when the reached value is not a string.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // NOTE: this chunk does not match `schema`. Its inner struct `<T>` has a single varchar
        // field, so the row-level paths below use index 0 for it rather than the schema's
        // index 2.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // topic.subtopic.string_field
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // topic.field
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // other_field
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}