risingwave_connector/sink/
mqtt.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::id::ActorId;
23use risingwave_common::row::Row;
24use risingwave_common::types::{DataType, ScalarRefImpl};
25use rumqttc::v5::ConnectionError;
26use rumqttc::v5::mqttbytes::QoS;
27use serde::Deserialize;
28use serde_with::serde_as;
29use thiserror_ext::AsReport;
30use with_options::WithOptions;
31
32use super::SinkWriterParam;
33use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc, SinkId};
34use super::encoder::{
35    DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
36    TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
37};
38use super::writer::AsyncTruncateSinkWriterExt;
39use crate::connector_common::MqttCommon;
40use crate::deserialize_bool_from_string;
41use crate::enforce_secret::EnforceSecret;
42use crate::sink::log_store::DeliveryFutureManagerAddFuture;
43use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
44use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
45
/// Name of the MQTT sink; used as `MqttSink::SINK_NAME`.
pub const MQTT_SINK: &str = "mqtt";

// Options accepted by the MQTT sink, deserialized from the sink's property map
// (see `MqttConfig::from_btreemap`).
// NOTE(review): `///` field docs are presumably consumed by the `WithOptions` derive to
// generate option documentation, so review notes on fields below use plain `//` comments.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Connection options shared with the other MQTT connector code, flattened into the same map.
    #[serde(flatten)]
    pub common: MqttCommon,

    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    // In this sink it is the static publish topic, and the fallback when `topic.field`
    // yields no string value for a row.
    pub topic: Option<String>,

    /// Whether the message should be retained by the broker
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type; `MqttConfig::from_btreemap` rejects anything other than
    // `SINK_TYPE_APPEND_ONLY` ("append-only").
    pub r#type: String,

    // Optional dot-separated path (e.g. `a.b`) to a string column whose per-row value is
    // used as the topic name. If `topic` is also set, it is used as a fallback when the
    // column value is missing or NULL.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}

impl EnforceSecret for MqttConfig {
    // All secret-enforcement rules come from the shared `MqttCommon` options.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MqttCommon::enforce_one(prop)
    }
}
74
75pub enum RowEncoderWrapper {
76    Json(JsonEncoder),
77    Proto(ProtoEncoder),
78}
79
80impl RowEncoder for RowEncoderWrapper {
81    type Output = Vec<u8>;
82
83    fn encode_cols(
84        &self,
85        row: impl Row,
86        col_indices: impl Iterator<Item = usize>,
87    ) -> Result<Self::Output> {
88        match self {
89            RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
90            RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
91        }
92    }
93
94    fn schema(&self) -> &Schema {
95        match self {
96            RowEncoderWrapper::Json(json) => json.schema(),
97            RowEncoderWrapper::Proto(proto) => proto.schema(),
98        }
99    }
100
101    fn col_indices(&self) -> Option<&[usize]> {
102        match self {
103            RowEncoderWrapper::Json(json) => json.col_indices(),
104            RowEncoderWrapper::Proto(proto) => proto.col_indices(),
105        }
106    }
107
108    fn encode(&self, row: impl Row) -> Result<Self::Output> {
109        match self {
110            RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
111            RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
112        }
113    }
114}
115
116#[derive(Clone, Debug)]
117pub struct MqttSink {
118    pub config: MqttConfig,
119    schema: Schema,
120    format_desc: SinkFormatDesc,
121    is_append_only: bool,
122    name: String,
123}
124
125impl EnforceSecret for MqttSink {
126    fn enforce_secret<'a>(
127        prop_iter: impl Iterator<Item = &'a str>,
128    ) -> crate::error::ConnectorResult<()> {
129        for prop in prop_iter {
130            MqttConfig::enforce_one(prop)?;
131        }
132        Ok(())
133    }
134}
135
/// Writer half of the MQTT sink: encodes each inserted row and publishes it.
pub struct MqttSinkWriter {
    pub config: MqttConfig,
    // Performs the actual `publish` calls and owns the MQTT client handle.
    payload_writer: MqttSinkPayloadWriter,
    #[expect(dead_code)]
    schema: Schema,
    // Encoder selected from the sink's FORMAT/ENCODE in `MqttSinkWriter::new`.
    encoder: RowEncoderWrapper,
    // Set to `true` on drop so the background event-loop task spawned in `new` exits.
    stopped: Arc<AtomicBool>,
}
145
146/// Basic data types for use with the mqtt interface
147impl MqttConfig {
148    pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
149        let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
150            .map_err(|e| SinkError::Config(anyhow!(e)))?;
151        if config.r#type != SINK_TYPE_APPEND_ONLY {
152            Err(SinkError::Config(anyhow!(
153                "MQTT sink only supports append-only mode"
154            )))
155        } else {
156            Ok(config)
157        }
158    }
159}
160
161impl TryFrom<SinkParam> for MqttSink {
162    type Error = SinkError;
163
164    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
165        let schema = param.schema();
166        let config = MqttConfig::from_btreemap(param.properties)?;
167        Ok(Self {
168            config,
169            schema,
170            name: param.sink_name,
171            format_desc: param
172                .format_desc
173                .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
174            is_append_only: param.sink_type.is_append_only(),
175        })
176    }
177}
178
179impl Sink for MqttSink {
180    type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
181
182    const SINK_NAME: &'static str = MQTT_SINK;
183
184    async fn validate(&self) -> Result<()> {
185        if !self.is_append_only {
186            return Err(SinkError::Mqtt(anyhow!(
187                "MQTT sink only supports append-only mode"
188            )));
189        }
190
191        if let Some(field) = &self.config.topic_field {
192            let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
193        } else if self.config.topic.is_none() {
194            return Err(SinkError::Config(anyhow!(
195                "either topic or topic.field must be set"
196            )));
197        }
198
199        let _client = (self.config.common.build_client(0.into(), 0))
200            .context("validate mqtt sink error")
201            .map_err(SinkError::Mqtt)?;
202
203        Ok(())
204    }
205
206    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
207        Ok(MqttSinkWriter::new(
208            self.config.clone(),
209            self.schema.clone(),
210            &self.format_desc,
211            &self.name,
212            writer_param.sink_id,
213            writer_param.actor_id,
214        )
215        .await?
216        .into_log_sinker(usize::MAX))
217    }
218}
219
220impl MqttSinkWriter {
221    pub async fn new(
222        config: MqttConfig,
223        schema: Schema,
224        format_desc: &SinkFormatDesc,
225        name: &str,
226        sink_id: SinkId,
227        actor_id: ActorId,
228    ) -> Result<Self> {
229        let mut topic_index_path = vec![];
230        if let Some(field) = &config.topic_field {
231            topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
232        }
233
234        let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
235        let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
236        let encoder = match format_desc.format {
237            SinkFormat::AppendOnly => match format_desc.encode {
238                SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
239                    schema.clone(),
240                    None,
241                    DateHandlingMode::FromCe,
242                    TimestampHandlingMode::Milli,
243                    timestamptz_mode,
244                    TimeHandlingMode::Milli,
245                    jsonb_handling_mode,
246                )),
247                SinkEncode::Protobuf => {
248                    let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
249                        &format_desc.options,
250                        config.topic.as_deref().unwrap_or(name),
251                        None,
252                    )
253                    .await
254                    .map_err(|e| SinkError::Config(anyhow!(e)))?;
255                    let header = match sid {
256                        None => ProtoHeader::None,
257                        Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
258                    };
259                    RowEncoderWrapper::Proto(ProtoEncoder::new(
260                        schema.clone(),
261                        None,
262                        descriptor,
263                        header,
264                    )?)
265                }
266                _ => {
267                    return Err(SinkError::Config(anyhow!(
268                        "mqtt sink encode unsupported: {:?}",
269                        format_desc.encode,
270                    )));
271                }
272            },
273            _ => {
274                return Err(SinkError::Config(anyhow!(
275                    "MQTT sink only supports append-only mode"
276                )));
277            }
278        };
279        let qos = config.common.qos();
280
281        let (client, mut eventloop) = config
282            .common
283            .build_client(actor_id, sink_id.as_raw_id())
284            .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
285
286        let stopped = Arc::new(AtomicBool::new(false));
287        let stopped_clone = stopped.clone();
288        tokio::spawn(async move {
289            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
290                match eventloop.poll().await {
291                    Ok(_) => (),
292                    Err(err) => match err {
293                        ConnectionError::Timeout(_) => (),
294                        ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
295                        | ConnectionError::Io(err)
296                            if err.kind() == std::io::ErrorKind::ConnectionAborted
297                                || err.kind() == std::io::ErrorKind::ConnectionReset =>
298                        {
299                            continue;
300                        }
301                        err => {
302                            tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
303                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
304                        }
305                    },
306                }
307            }
308        });
309
310        let payload_writer = MqttSinkPayloadWriter {
311            topic: config.topic.clone(),
312            client,
313            qos,
314            retain: config.retain,
315            topic_index_path,
316        };
317
318        Ok::<_, SinkError>(Self {
319            config: config.clone(),
320            payload_writer,
321            schema: schema.clone(),
322            stopped,
323            encoder,
324        })
325    }
326}
327
328impl AsyncTruncateSinkWriter for MqttSinkWriter {
329    async fn write_chunk<'a>(
330        &'a mut self,
331        chunk: StreamChunk,
332        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
333    ) -> Result<()> {
334        self.payload_writer.write_chunk(chunk, &self.encoder).await
335    }
336}
337
338impl Drop for MqttSinkWriter {
339    fn drop(&mut self) {
340        self.stopped
341            .store(true, std::sync::atomic::Ordering::Relaxed);
342    }
343}
344
345struct MqttSinkPayloadWriter {
346    // connection to mqtt, one per executor
347    client: rumqttc::v5::AsyncClient,
348    topic: Option<String>,
349    qos: QoS,
350    retain: bool,
351    topic_index_path: Vec<usize>,
352}
353
354impl MqttSinkPayloadWriter {
355    async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
356        for (op, row) in chunk.rows() {
357            if op != Op::Insert {
358                continue;
359            }
360
361            let topic = match get_topic_from_index_path(
362                &self.topic_index_path,
363                self.topic.as_deref(),
364                &row,
365            ) {
366                Some(s) => s,
367                None => {
368                    tracing::error!("topic field not found in row, skipping: {:?}", row);
369                    return Ok(());
370                }
371            };
372
373            let v = encoder.encode(row)?;
374
375            self.client
376                .publish(topic, self.qos, self.retain, v)
377                .await
378                .context("mqtt sink error")
379                .map_err(SinkError::Mqtt)?;
380        }
381
382        Ok(())
383    }
384}
385
386fn get_topic_from_index_path<'s>(
387    path: &[usize],
388    default_topic: Option<&'s str>,
389    row: &'s RowRef<'s>,
390) -> Option<&'s str> {
391    if let Some(topic) = default_topic
392        && path.is_empty()
393    {
394        Some(topic)
395    } else {
396        let mut iter = path.iter();
397        let scalar = iter
398            .next()
399            .and_then(|pos| row.datum_at(*pos))
400            .and_then(|d| {
401                iter.try_fold(d, |d, pos| match d {
402                    ScalarRefImpl::Struct(struct_ref) => {
403                        struct_ref.iter_fields_ref().nth(*pos).flatten()
404                    }
405                    _ => None,
406                })
407            });
408        match scalar {
409            Some(ScalarRefImpl::Utf8(s)) => Some(s),
410            _ => {
411                if let Some(topic) = default_topic {
412                    Some(topic)
413                } else {
414                    None
415                }
416            }
417        }
418    }
419}
420
421// This function returns the index path to the topic field in the schema, validating that the field exists and is of type string
422// the returnent path can be used to extract the topic field from a row. The path is a list of indexes to be used to navigate the row
423// to the topic field.
424fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
425    let mut iter = topic_field.split('.');
426    let mut path = vec![];
427    let dt =
428        iter.next()
429            .and_then(|field| {
430                // Extract the field from the schema
431                schema
432                    .fields()
433                    .iter()
434                    .enumerate()
435                    .find(|(_, f)| f.name == field)
436                    .map(|(pos, f)| {
437                        path.push(pos);
438                        &f.data_type
439                    })
440            })
441            .and_then(|dt| {
442                // Iterate over the next fields to extract the fields from the nested structs
443                iter.try_fold(dt, |dt, field| match dt {
444                    DataType::Struct(st) => {
445                        st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
446                            |(pos, (_, dt))| {
447                                path.push(pos);
448                                dt
449                            },
450                        )
451                    }
452                    _ => None,
453                })
454            });
455
456    match dt {
457        Some(DataType::Varchar) => Ok(path),
458        Some(dt) => Err(SinkError::Config(anyhow!(
459            "topic field `{}` must be of type string but got {:?}",
460            topic_field,
461            dt
462        ))),
463        None => Err(SinkError::Config(anyhow!(
464            "topic field `{}`  not found",
465            topic_field
466        ))),
467    }
468}
469
// Unit tests for resolving `topic.field` against a schema (`get_topic_field_index_path`)
// and extracting the topic from a row (`get_topic_from_index_path`).
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a one-element path; an unknown name errors.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A dotted name descends into a struct column; an unknown nested field errors.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // NULL topic values fall back to the default topic, or yield `None` without one,
    // both at the top level and inside a struct.
    #[test]
    fn test_null_values_extraction() {
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Multi-level struct paths; non-varchar leaf fields are rejected.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // This fixture's nested struct holds only the string field, so the row paths below
        // are written by hand (e.g. [0, 1, 0]) rather than taken from the schema above
        // (which gives [0, 1, 2]).
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // topic.subtopic.string_field (index 0 inside this fixture's nested struct)
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // topic.field is an int, so no string topic is found
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // other_field
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}