risingwave_connector/sink/
mqtt.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use core::fmt::Debug;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicBool;
18
19use anyhow::{Context as _, anyhow};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use rumqttc::v5::ConnectionError;
25use rumqttc::v5::mqttbytes::QoS;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use thiserror_ext::AsReport;
29use with_options::WithOptions;
30
31use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc};
32use super::encoder::{
33    DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
34    TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
35};
36use super::writer::AsyncTruncateSinkWriterExt;
37use super::{DummySinkCommitCoordinator, SinkWriterParam};
38use crate::connector_common::MqttCommon;
39use crate::deserialize_bool_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::sink::log_store::DeliveryFutureManagerAddFuture;
42use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
43use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
44
/// Connector name of the MQTT sink; exposed as `Sink::SINK_NAME` for `MqttSink`.
pub const MQTT_SINK: &str = "mqtt";
46
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
// Options of the MQTT sink. Instances are deserialized from the sink's string
// properties by `MqttConfig::from_btreemap`, which also rejects any `type`
// other than append-only.
pub struct MqttConfig {
    #[serde(flatten)]
    pub common: MqttCommon,

    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    pub topic: Option<String>,

    /// Whether the message should be retained by the broker
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type. Only "append-only" is accepted; see `MqttConfig::from_btreemap`.
    pub r#type: String,

    // If set, the value of this (possibly nested, dot-separated) field is used as the
    // topic name of each message. If `topic` is also set, it is used as a fallback
    // when the field value cannot be resolved to a string.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
67
68impl EnforceSecret for MqttConfig {
69    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
70        MqttCommon::enforce_one(prop)
71    }
72}
73
/// The row encoders supported by the MQTT sink, unified behind a single type so the
/// writer can hold either one. Its `RowEncoder` impl serializes rows to `Vec<u8>`.
pub enum RowEncoderWrapper {
    /// Encodes rows with a [`JsonEncoder`].
    Json(JsonEncoder),
    /// Encodes rows with a [`ProtoEncoder`].
    Proto(ProtoEncoder),
}
78
79impl RowEncoder for RowEncoderWrapper {
80    type Output = Vec<u8>;
81
82    fn encode_cols(
83        &self,
84        row: impl Row,
85        col_indices: impl Iterator<Item = usize>,
86    ) -> Result<Self::Output> {
87        match self {
88            RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
89            RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
90        }
91    }
92
93    fn schema(&self) -> &Schema {
94        match self {
95            RowEncoderWrapper::Json(json) => json.schema(),
96            RowEncoderWrapper::Proto(proto) => proto.schema(),
97        }
98    }
99
100    fn col_indices(&self) -> Option<&[usize]> {
101        match self {
102            RowEncoderWrapper::Json(json) => json.col_indices(),
103            RowEncoderWrapper::Proto(proto) => proto.col_indices(),
104        }
105    }
106
107    fn encode(&self, row: impl Row) -> Result<Self::Output> {
108        match self {
109            RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
110            RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
111        }
112    }
113}
114
/// The MQTT sink. Built from a `SinkParam` via `TryFrom`, it validates its
/// configuration and creates an `MqttSinkWriter` for each log sinker.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink options.
    pub config: MqttConfig,
    /// Schema of the rows written to the sink.
    schema: Schema,
    /// The `FORMAT ... ENCODE ...` description used to choose the row encoder.
    format_desc: SinkFormatDesc,
    /// Whether the sink type is append-only; `validate` rejects anything else.
    is_append_only: bool,
    /// Name of the sink; used in place of the topic when fetching the protobuf
    /// descriptor if no `topic` is configured.
    name: String,
}
123
124impl EnforceSecret for MqttSink {
125    fn enforce_secret<'a>(
126        prop_iter: impl Iterator<Item = &'a str>,
127    ) -> crate::error::ConnectorResult<()> {
128        for prop in prop_iter {
129            MqttConfig::enforce_one(prop)?;
130        }
131        Ok(())
132    }
133}
134
/// Sink writer for MQTT: encodes each row and publishes it through the payload writer.
pub struct MqttSinkWriter {
    /// Sink options this writer was created with.
    pub config: MqttConfig,
    /// Holds the MQTT client and the topic/QoS/retain settings used for publishing.
    payload_writer: MqttSinkPayloadWriter,
    #[expect(dead_code)]
    schema: Schema,
    /// Encoder turning a row into the message payload.
    encoder: RowEncoderWrapper,
    /// Stop flag shared with the background task polling the MQTT event loop;
    /// set to `true` when the writer is dropped.
    stopped: Arc<AtomicBool>,
}
144
145/// Basic data types for use with the mqtt interface
146impl MqttConfig {
147    pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
148        let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
149            .map_err(|e| SinkError::Config(anyhow!(e)))?;
150        if config.r#type != SINK_TYPE_APPEND_ONLY {
151            Err(SinkError::Config(anyhow!(
152                "MQTT sink only supports append-only mode"
153            )))
154        } else {
155            Ok(config)
156        }
157    }
158}
159
160impl TryFrom<SinkParam> for MqttSink {
161    type Error = SinkError;
162
163    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
164        let schema = param.schema();
165        let config = MqttConfig::from_btreemap(param.properties)?;
166        Ok(Self {
167            config,
168            schema,
169            name: param.sink_name,
170            format_desc: param
171                .format_desc
172                .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
173            is_append_only: param.sink_type.is_append_only(),
174        })
175    }
176}
177
178impl Sink for MqttSink {
179    type Coordinator = DummySinkCommitCoordinator;
180    type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
181
182    const SINK_NAME: &'static str = MQTT_SINK;
183
184    async fn validate(&self) -> Result<()> {
185        if !self.is_append_only {
186            return Err(SinkError::Mqtt(anyhow!(
187                "MQTT sink only supports append-only mode"
188            )));
189        }
190
191        if let Some(field) = &self.config.topic_field {
192            let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
193        } else if self.config.topic.is_none() {
194            return Err(SinkError::Config(anyhow!(
195                "either topic or topic.field must be set"
196            )));
197        }
198
199        let _client = (self.config.common.build_client(0, 0))
200            .context("validate mqtt sink error")
201            .map_err(SinkError::Mqtt)?;
202
203        Ok(())
204    }
205
206    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
207        Ok(MqttSinkWriter::new(
208            self.config.clone(),
209            self.schema.clone(),
210            &self.format_desc,
211            &self.name,
212            writer_param.executor_id,
213        )
214        .await?
215        .into_log_sinker(usize::MAX))
216    }
217}
218
219impl MqttSinkWriter {
220    pub async fn new(
221        config: MqttConfig,
222        schema: Schema,
223        format_desc: &SinkFormatDesc,
224        name: &str,
225        id: u64,
226    ) -> Result<Self> {
227        let mut topic_index_path = vec![];
228        if let Some(field) = &config.topic_field {
229            topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
230        }
231
232        let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
233        let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
234        let encoder = match format_desc.format {
235            SinkFormat::AppendOnly => match format_desc.encode {
236                SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
237                    schema.clone(),
238                    None,
239                    DateHandlingMode::FromCe,
240                    TimestampHandlingMode::Milli,
241                    timestamptz_mode,
242                    TimeHandlingMode::Milli,
243                    jsonb_handling_mode,
244                )),
245                SinkEncode::Protobuf => {
246                    let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
247                        &format_desc.options,
248                        config.topic.as_deref().unwrap_or(name),
249                        None,
250                    )
251                    .await
252                    .map_err(|e| SinkError::Config(anyhow!(e)))?;
253                    let header = match sid {
254                        None => ProtoHeader::None,
255                        Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
256                    };
257                    RowEncoderWrapper::Proto(ProtoEncoder::new(
258                        schema.clone(),
259                        None,
260                        descriptor,
261                        header,
262                    )?)
263                }
264                _ => {
265                    return Err(SinkError::Config(anyhow!(
266                        "mqtt sink encode unsupported: {:?}",
267                        format_desc.encode,
268                    )));
269                }
270            },
271            _ => {
272                return Err(SinkError::Config(anyhow!(
273                    "MQTT sink only supports append-only mode"
274                )));
275            }
276        };
277        let qos = config.common.qos();
278
279        let (client, mut eventloop) = config
280            .common
281            .build_client(0, id)
282            .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
283
284        let stopped = Arc::new(AtomicBool::new(false));
285        let stopped_clone = stopped.clone();
286        tokio::spawn(async move {
287            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
288                match eventloop.poll().await {
289                    Ok(_) => (),
290                    Err(err) => match err {
291                        ConnectionError::Timeout(_) => (),
292                        ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
293                        | ConnectionError::Io(err)
294                            if err.kind() == std::io::ErrorKind::ConnectionAborted
295                                || err.kind() == std::io::ErrorKind::ConnectionReset =>
296                        {
297                            continue;
298                        }
299                        err => {
300                            tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
301                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
302                        }
303                    },
304                }
305            }
306        });
307
308        let payload_writer = MqttSinkPayloadWriter {
309            topic: config.topic.clone(),
310            client,
311            qos,
312            retain: config.retain,
313            topic_index_path,
314        };
315
316        Ok::<_, SinkError>(Self {
317            config: config.clone(),
318            payload_writer,
319            schema: schema.clone(),
320            stopped,
321            encoder,
322        })
323    }
324}
325
326impl AsyncTruncateSinkWriter for MqttSinkWriter {
327    async fn write_chunk<'a>(
328        &'a mut self,
329        chunk: StreamChunk,
330        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
331    ) -> Result<()> {
332        self.payload_writer.write_chunk(chunk, &self.encoder).await
333    }
334}
335
336impl Drop for MqttSinkWriter {
337    fn drop(&mut self) {
338        self.stopped
339            .store(true, std::sync::atomic::Ordering::Relaxed);
340    }
341}
342
/// Publishes encoded rows to MQTT with the configured topic, QoS and retain flag.
struct MqttSinkPayloadWriter {
    // connection to mqtt, one per executor
    client: rumqttc::v5::AsyncClient,
    // Configured topic; the fallback when a topic field is used but yields no string.
    topic: Option<String>,
    // Quality of service passed to every publish.
    qos: QoS,
    // Retain flag passed to every publish.
    retain: bool,
    // Index path to the topic field inside a row; empty when no topic field is configured.
    topic_index_path: Vec<usize>,
}
351
352impl MqttSinkPayloadWriter {
353    async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
354        for (op, row) in chunk.rows() {
355            if op != Op::Insert {
356                continue;
357            }
358
359            let topic = match get_topic_from_index_path(
360                &self.topic_index_path,
361                self.topic.as_deref(),
362                &row,
363            ) {
364                Some(s) => s,
365                None => {
366                    tracing::error!("topic field not found in row, skipping: {:?}", row);
367                    return Ok(());
368                }
369            };
370
371            let v = encoder.encode(row)?;
372
373            self.client
374                .publish(topic, self.qos, self.retain, v)
375                .await
376                .context("mqtt sink error")
377                .map_err(SinkError::Mqtt)?;
378        }
379
380        Ok(())
381    }
382}
383
384fn get_topic_from_index_path<'s>(
385    path: &[usize],
386    default_topic: Option<&'s str>,
387    row: &'s RowRef<'s>,
388) -> Option<&'s str> {
389    if let Some(topic) = default_topic
390        && path.is_empty()
391    {
392        Some(topic)
393    } else {
394        let mut iter = path.iter();
395        let scalar = iter
396            .next()
397            .and_then(|pos| row.datum_at(*pos))
398            .and_then(|d| {
399                iter.try_fold(d, |d, pos| match d {
400                    ScalarRefImpl::Struct(struct_ref) => {
401                        struct_ref.iter_fields_ref().nth(*pos).flatten()
402                    }
403                    _ => None,
404                })
405            });
406        match scalar {
407            Some(ScalarRefImpl::Utf8(s)) => Some(s),
408            _ => {
409                if let Some(topic) = default_topic {
410                    Some(topic)
411                } else {
412                    None
413                }
414            }
415        }
416    }
417}
418
419// This function returns the index path to the topic field in the schema, validating that the field exists and is of type string
420// the returnent path can be used to extract the topic field from a row. The path is a list of indexes to be used to navigate the row
421// to the topic field.
422fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
423    let mut iter = topic_field.split('.');
424    let mut path = vec![];
425    let dt =
426        iter.next()
427            .and_then(|field| {
428                // Extract the field from the schema
429                schema
430                    .fields()
431                    .iter()
432                    .enumerate()
433                    .find(|(_, f)| f.name == field)
434                    .map(|(pos, f)| {
435                        path.push(pos);
436                        &f.data_type
437                    })
438            })
439            .and_then(|dt| {
440                // Iterate over the next fields to extract the fields from the nested structs
441                iter.try_fold(dt, |dt, field| match dt {
442                    DataType::Struct(st) => {
443                        st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
444                            |(pos, (_, dt))| {
445                                path.push(pos);
446                                dt
447                            },
448                        )
449                    }
450                    _ => None,
451                })
452            });
453
454    match dt {
455        Some(DataType::Varchar) => Ok(path),
456        Some(dt) => Err(SinkError::Config(anyhow!(
457            "topic field `{}` must be of type string but got {:?}",
458            topic_field,
459            dt
460        ))),
461        None => Err(SinkError::Config(anyhow!(
462            "topic field `{}`  not found",
463            topic_field
464        ))),
465    }
466}
467
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    /// A top-level string column resolves to a one-element path and its value is
    /// extracted from a row; an unknown column name is rejected.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let index_path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(index_path, vec![0]);

        let data = DataChunk::from_pretty(
            "T
            test",
        );
        let first_row = RowRef::new(&data, 0);
        assert_eq!(
            get_topic_from_index_path(&index_path, None, &first_row),
            Some("test")
        );

        assert!(get_topic_field_index_path(&schema, "other_field").is_err());
    }

    /// A string field nested one level inside a struct column is resolved and
    /// extracted; an unknown nested name is rejected.
    #[test]
    fn test_nested_field_extraction() {
        let topic_struct = DataType::Struct(StructType::new(vec![
            ("field", DataType::Int32),
            ("subtopic", DataType::Varchar),
        ]));
        let schema = Schema::new(vec![Field::with_name(topic_struct, "topic")]);
        let index_path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(index_path, vec![0, 1]);

        let data = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );
        let first_row = RowRef::new(&data, 0);
        assert_eq!(
            get_topic_from_index_path(&index_path, None, &first_row),
            Some("test")
        );

        assert!(get_topic_field_index_path(&schema, "topic.other_field").is_err());
    }

    /// A null topic value yields the default topic when there is one, `None` otherwise.
    #[test]
    fn test_null_values_extraction() {
        // Null in a top-level column.
        let index_path = vec![0];
        let data = DataChunk::from_pretty(
            "T
            .",
        );
        let first_row = RowRef::new(&data, 0);
        assert_eq!(
            get_topic_from_index_path(&index_path, Some("default"), &first_row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&index_path, None, &first_row), None);

        // Null in a nested struct field.
        let index_path = vec![0, 1];
        let data = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let first_row = RowRef::new(&data, 0);
        assert_eq!(
            get_topic_from_index_path(&index_path, Some("default"), &first_row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&index_path, None, &first_row), None);
    }

    /// Paths through several struct levels: only string leaves are accepted by the
    /// schema lookup, and only string values are extracted from a row.
    #[test]
    fn test_multiple_levels() {
        let subtopic_struct = DataType::Struct(StructType::new(vec![
            ("int_field", DataType::Int32),
            ("boolean_field", DataType::Boolean),
            ("string_field", DataType::Varchar),
        ]));
        let topic_struct = DataType::Struct(StructType::new(vec![
            ("field", DataType::Int32),
            ("subtopic", subtopic_struct),
        ]));
        let schema = Schema::new(vec![
            Field::with_name(topic_struct, "topic"),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let index_path =
            get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(index_path, vec![0, 1, 2]);

        // Non-string leaves are rejected.
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());
        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());
        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let index_path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(index_path, vec![1]);

        let data = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );
        let first_row = RowRef::new(&data, 0);

        // topic.subtopic.string_field
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &first_row),
            Some("test")
        );

        // topic.field
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &first_row), None);

        // other_field
        assert_eq!(
            get_topic_from_index_path(&[1], None, &first_row),
            Some("other")
        );
    }
}