risingwave_connector/sink/
mqtt.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::SinkWriterParam;
32use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
33use super::encoder::{
34    DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
35    TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
36};
37use super::writer::AsyncTruncateSinkWriterExt;
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::sink::log_store::DeliveryFutureManagerAddFuture;
42use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
43use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
44
/// Name of the MQTT sink connector; used as `MqttSink`'s `Sink::SINK_NAME`.
pub const MQTT_SINK: &str = "mqtt";

/// Options accepted by the MQTT sink, parsed from the sink properties by
/// `MqttConfig::from_btreemap`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Connection options (defined in `connector_common::MqttCommon`), flattened so they are
    // read from the same property map as the fields below.
    #[serde(flatten)]
    pub common: MqttCommon,

    // For this sink the topic is only published to. It is also the fallback when
    // `topic.field` yields no topic, and (for Protobuf) the name passed to the descriptor
    // lookup.
    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    pub topic: Option<String>,

    /// Whether the message should be retained by the broker
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type; `MqttConfig::from_btreemap` rejects any value other than
    // `SINK_TYPE_APPEND_ONLY` ("append-only").
    pub r#type: String,

    // If set, a dot-separated path (e.g. `a.b.c`) to a varchar column, possibly nested inside
    // structs, whose value is used as each row's topic. If `topic` is also set, it is used as
    // the fallback when that value is NULL.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
67
impl EnforceSecret for MqttConfig {
    /// Delegates the per-property secret check to `MqttCommon::enforce_one`, i.e. the
    /// connection options flattened into this config.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MqttCommon::enforce_one(prop)
    }
}
73
/// Row encoder used by the MQTT sink. The variant is chosen from the sink's `ENCODE` in
/// `MqttSinkWriter::new`; its `RowEncoder` impl serializes every row to bytes.
pub enum RowEncoderWrapper {
    /// Selected for `SinkEncode::Json`.
    Json(JsonEncoder),
    /// Selected for `SinkEncode::Protobuf`.
    Proto(ProtoEncoder),
}
78
79impl RowEncoder for RowEncoderWrapper {
80    type Output = Vec<u8>;
81
82    fn encode_cols(
83        &self,
84        row: impl Row,
85        col_indices: impl Iterator<Item = usize>,
86    ) -> Result<Self::Output> {
87        match self {
88            RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
89            RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
90        }
91    }
92
93    fn schema(&self) -> &Schema {
94        match self {
95            RowEncoderWrapper::Json(json) => json.schema(),
96            RowEncoderWrapper::Proto(proto) => proto.schema(),
97        }
98    }
99
100    fn col_indices(&self) -> Option<&[usize]> {
101        match self {
102            RowEncoderWrapper::Json(json) => json.col_indices(),
103            RowEncoderWrapper::Proto(proto) => proto.col_indices(),
104        }
105    }
106
107    fn encode(&self, row: impl Row) -> Result<Self::Output> {
108        match self {
109            RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
110            RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
111        }
112    }
113}
114
/// The MQTT sink: parsed configuration plus the catalog information needed to validate the
/// sink and create writers.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink properties.
    pub config: MqttConfig,
    /// Schema of the rows written to the sink; used to resolve `topic.field` and to build
    /// the row encoder.
    schema: Schema,
    /// The sink's `FORMAT ... ENCODE ...` descriptor (required by `TryFrom<SinkParam>`).
    format_desc: SinkFormatDesc,
    /// Whether the sink was declared append-only; `validate` rejects the sink otherwise.
    is_append_only: bool,
    /// Sink name; passed to the Protobuf descriptor lookup when no `topic` is set.
    name: String,
}
123
124impl EnforceSecret for MqttSink {
125    fn enforce_secret<'a>(
126        prop_iter: impl Iterator<Item = &'a str>,
127    ) -> crate::error::ConnectorResult<()> {
128        for prop in prop_iter {
129            MqttConfig::enforce_one(prop)?;
130        }
131        Ok(())
132    }
133}
134
/// Sink writer created by `MqttSink::new_log_sinker`: encodes inserted rows and publishes
/// them through its MQTT client.
pub struct MqttSinkWriter {
    /// Sink configuration this writer was created with.
    pub config: MqttConfig,
    /// Owns the MQTT client plus the topic/QoS/retain settings used for every publish.
    payload_writer: MqttSinkPayloadWriter,
    // Stored but never read, hence the `expect(dead_code)`.
    #[expect(dead_code)]
    schema: Schema,
    /// Turns each row into the message payload.
    encoder: RowEncoderWrapper,
    /// Shared with the event-loop polling task spawned in `new`; set to `true` on drop so
    /// that task exits.
    stopped: Arc<AtomicBool>,
}
144
145/// Basic data types for use with the mqtt interface
146impl MqttConfig {
147    pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
148        let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
149            .map_err(|e| SinkError::Config(anyhow!(e)))?;
150        if config.r#type != SINK_TYPE_APPEND_ONLY {
151            Err(SinkError::Config(anyhow!(
152                "MQTT sink only supports append-only mode"
153            )))
154        } else {
155            Ok(config)
156        }
157    }
158}
159
160impl TryFrom<SinkParam> for MqttSink {
161    type Error = SinkError;
162
163    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
164        let schema = param.schema();
165        let config = MqttConfig::from_btreemap(param.properties)?;
166        Ok(Self {
167            config,
168            schema,
169            name: param.sink_name,
170            format_desc: param
171                .format_desc
172                .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
173            is_append_only: param.sink_type.is_append_only(),
174        })
175    }
176}
177
178impl Sink for MqttSink {
179    type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
180
181    const SINK_NAME: &'static str = MQTT_SINK;
182
183    async fn validate(&self) -> Result<()> {
184        if !self.is_append_only {
185            return Err(SinkError::Mqtt(anyhow!(
186                "MQTT sink only supports append-only mode"
187            )));
188        }
189
190        if let Some(field) = &self.config.topic_field {
191            let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
192        } else if self.config.topic.is_none() {
193            return Err(SinkError::Config(anyhow!(
194                "either topic or topic.field must be set"
195            )));
196        }
197
198        let _client = (self.config.common.build_client(0, 0))
199            .context("validate mqtt sink error")
200            .map_err(SinkError::Mqtt)?;
201
202        Ok(())
203    }
204
205    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
206        Ok(MqttSinkWriter::new(
207            self.config.clone(),
208            self.schema.clone(),
209            &self.format_desc,
210            &self.name,
211            writer_param.executor_id,
212        )
213        .await?
214        .into_log_sinker(usize::MAX))
215    }
216}
217
218impl MqttSinkWriter {
219    pub async fn new(
220        config: MqttConfig,
221        schema: Schema,
222        format_desc: &SinkFormatDesc,
223        name: &str,
224        id: u64,
225    ) -> Result<Self> {
226        let mut topic_index_path = vec![];
227        if let Some(field) = &config.topic_field {
228            topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
229        }
230
231        let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
232        let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
233        let encoder = match format_desc.format {
234            SinkFormat::AppendOnly => match format_desc.encode {
235                SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
236                    schema.clone(),
237                    None,
238                    DateHandlingMode::FromCe,
239                    TimestampHandlingMode::Milli,
240                    timestamptz_mode,
241                    TimeHandlingMode::Milli,
242                    jsonb_handling_mode,
243                )),
244                SinkEncode::Protobuf => {
245                    let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
246                        &format_desc.options,
247                        config.topic.as_deref().unwrap_or(name),
248                        None,
249                    )
250                    .await
251                    .map_err(|e| SinkError::Config(anyhow!(e)))?;
252                    let header = match sid {
253                        None => ProtoHeader::None,
254                        Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
255                    };
256                    RowEncoderWrapper::Proto(ProtoEncoder::new(
257                        schema.clone(),
258                        None,
259                        descriptor,
260                        header,
261                    )?)
262                }
263                _ => {
264                    return Err(SinkError::Config(anyhow!(
265                        "mqtt sink encode unsupported: {:?}",
266                        format_desc.encode,
267                    )));
268                }
269            },
270            _ => {
271                return Err(SinkError::Config(anyhow!(
272                    "MQTT sink only supports append-only mode"
273                )));
274            }
275        };
276        let qos = config.common.qos();
277
278        let (client, mut eventloop) = config
279            .common
280            .build_client(0, id)
281            .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
282
283        let stopped = Arc::new(AtomicBool::new(false));
284        let stopped_clone = stopped.clone();
285        tokio::spawn(async move {
286            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
287                match eventloop.poll().await {
288                    Ok(_) => (),
289                    Err(err) => match err {
290                        ConnectionError::Timeout(_) => (),
291                        ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
292                        | ConnectionError::Io(err)
293                            if err.kind() == std::io::ErrorKind::ConnectionAborted
294                                || err.kind() == std::io::ErrorKind::ConnectionReset =>
295                        {
296                            continue;
297                        }
298                        err => {
299                            tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
300                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
301                        }
302                    },
303                }
304            }
305        });
306
307        let payload_writer = MqttSinkPayloadWriter {
308            topic: config.topic.clone(),
309            client,
310            qos,
311            retain: config.retain,
312            topic_index_path,
313        };
314
315        Ok::<_, SinkError>(Self {
316            config: config.clone(),
317            payload_writer,
318            schema: schema.clone(),
319            stopped,
320            encoder,
321        })
322    }
323}
324
325impl AsyncTruncateSinkWriter for MqttSinkWriter {
326    async fn write_chunk<'a>(
327        &'a mut self,
328        chunk: StreamChunk,
329        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
330    ) -> Result<()> {
331        self.payload_writer.write_chunk(chunk, &self.encoder).await
332    }
333}
334
335impl Drop for MqttSinkWriter {
336    fn drop(&mut self) {
337        self.stopped
338            .store(true, std::sync::atomic::Ordering::Relaxed);
339    }
340}
341
/// Publishes encoded rows through an MQTT client, choosing the topic for each row.
struct MqttSinkPayloadWriter {
    // connection to mqtt, one per executor (built with the executor id in
    // `MqttSinkWriter::new`)
    client: rumqttc::v5::AsyncClient,
    /// Static topic: used directly when `topic_index_path` is empty, otherwise only as the
    /// fallback when a row yields no topic.
    topic: Option<String>,
    /// QoS applied to every publish.
    qos: QoS,
    /// Retain flag applied to every publish.
    retain: bool,
    /// Index path to the `topic.field` value in each row; empty when `topic.field` is unset.
    topic_index_path: Vec<usize>,
}
350
351impl MqttSinkPayloadWriter {
352    async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
353        for (op, row) in chunk.rows() {
354            if op != Op::Insert {
355                continue;
356            }
357
358            let topic = match get_topic_from_index_path(
359                &self.topic_index_path,
360                self.topic.as_deref(),
361                &row,
362            ) {
363                Some(s) => s,
364                None => {
365                    tracing::error!("topic field not found in row, skipping: {:?}", row);
366                    return Ok(());
367                }
368            };
369
370            let v = encoder.encode(row)?;
371
372            self.client
373                .publish(topic, self.qos, self.retain, v)
374                .await
375                .context("mqtt sink error")
376                .map_err(SinkError::Mqtt)?;
377        }
378
379        Ok(())
380    }
381}
382
383fn get_topic_from_index_path<'s>(
384    path: &[usize],
385    default_topic: Option<&'s str>,
386    row: &'s RowRef<'s>,
387) -> Option<&'s str> {
388    if let Some(topic) = default_topic
389        && path.is_empty()
390    {
391        Some(topic)
392    } else {
393        let mut iter = path.iter();
394        let scalar = iter
395            .next()
396            .and_then(|pos| row.datum_at(*pos))
397            .and_then(|d| {
398                iter.try_fold(d, |d, pos| match d {
399                    ScalarRefImpl::Struct(struct_ref) => {
400                        struct_ref.iter_fields_ref().nth(*pos).flatten()
401                    }
402                    _ => None,
403                })
404            });
405        match scalar {
406            Some(ScalarRefImpl::Utf8(s)) => Some(s),
407            _ => {
408                if let Some(topic) = default_topic {
409                    Some(topic)
410                } else {
411                    None
412                }
413            }
414        }
415    }
416}
417
418// This function returns the index path to the topic field in the schema, validating that the field exists and is of type string
419// the returnent path can be used to extract the topic field from a row. The path is a list of indexes to be used to navigate the row
420// to the topic field.
421fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
422    let mut iter = topic_field.split('.');
423    let mut path = vec![];
424    let dt =
425        iter.next()
426            .and_then(|field| {
427                // Extract the field from the schema
428                schema
429                    .fields()
430                    .iter()
431                    .enumerate()
432                    .find(|(_, f)| f.name == field)
433                    .map(|(pos, f)| {
434                        path.push(pos);
435                        &f.data_type
436                    })
437            })
438            .and_then(|dt| {
439                // Iterate over the next fields to extract the fields from the nested structs
440                iter.try_fold(dt, |dt, field| match dt {
441                    DataType::Struct(st) => {
442                        st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
443                            |(pos, (_, dt))| {
444                                path.push(pos);
445                                dt
446                            },
447                        )
448                    }
449                    _ => None,
450                })
451            });
452
453    match dt {
454        Some(DataType::Varchar) => Ok(path),
455        Some(dt) => Err(SinkError::Config(anyhow!(
456            "topic field `{}` must be of type string but got {:?}",
457            topic_field,
458            dt
459        ))),
460        None => Err(SinkError::Config(anyhow!(
461            "topic field `{}`  not found",
462            topic_field
463        ))),
464    }
465}
466
// Unit tests for the topic-path helpers: resolving `topic.field` against a schema
// (`get_topic_field_index_path`) and reading the topic out of a row
// (`get_topic_from_index_path`).
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a one-element path; unknown names are errors.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field inside a struct column resolves to a two-element path.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // NULL topic values fall back to the default topic, or yield `None` without one.
    #[test]
    fn test_null_values_extraction() {
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        // Same for a NULL field nested inside a struct.
        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Paths through several struct levels, plus rejection of non-varchar targets.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // NOTE: this fixture row does not follow `schema` exactly: its inner struct has a
        // single string field, so the paths below index into the fixture's own shape.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // topic.subtopic.string_field (index 0 of the fixture's one-field inner struct)
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // topic.field: an int value, so no topic is extracted
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // other_field
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}