risingwave_connector/sink/
mqtt.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::fmt::Debug;
16use std::collections::BTreeMap;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19
20use anyhow::{Context as _, anyhow};
21use risingwave_common::array::{Op, RowRef, StreamChunk};
22use risingwave_common::catalog::Schema;
23use risingwave_common::id::ActorId;
24use risingwave_common::row::Row;
25use risingwave_common::types::{DataType, ScalarRefImpl};
26use rumqttc::v5::ConnectionError;
27use rumqttc::v5::mqttbytes::QoS;
28use serde::Deserialize;
29use serde_with::serde_as;
30use thiserror_ext::AsReport;
31use with_options::WithOptions;
32
33use super::SinkWriterParam;
34use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc, SinkId};
35use super::encoder::{
36    DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo,
37    TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
38};
39use super::writer::AsyncTruncateSinkWriterExt;
40use crate::connector_common::MqttCommon;
41use crate::deserialize_bool_from_string;
42use crate::enforce_secret::EnforceSecret;
43use crate::sink::log_store::DeliveryFutureManagerAddFuture;
44use crate::sink::writer::{AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter};
45use crate::sink::{Result, SINK_TYPE_APPEND_ONLY, Sink, SinkError, SinkParam};
46
/// Name of the MQTT sink connector; exposed as `Sink::SINK_NAME` by [`MqttSink`].
pub const MQTT_SINK: &str = "mqtt";
48
/// Options accepted by the MQTT sink.
///
/// Instances are deserialized from the sink's string properties by
/// [`MqttConfig::from_btreemap`], which also enforces the append-only restriction.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MqttConfig {
    // Shared MQTT options; `flatten` means they are read from the same flat
    // property map as the sink-specific options below.
    #[serde(flatten)]
    pub common: MqttCommon,

    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    pub topic: Option<String>,

    /// Whether the message should be retained by the broker
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub retain: bool,

    // Sink type. Only the append-only type is accepted: `from_btreemap` rejects
    // any value other than `SINK_TYPE_APPEND_ONLY`.
    pub r#type: String,

    // If set, the value of this column is used as the topic for each row. Nested
    // struct fields are addressed with a dot-separated path (e.g. `a.b.c`). When
    // `topic` is also set, it is the fallback for rows whose field value cannot
    // be used.
    #[serde(rename = "topic.field")]
    pub topic_field: Option<String>,
}
69
impl EnforceSecret for MqttConfig {
    /// Checks a single property name by delegating to `MqttCommon::enforce_one`;
    /// the sink-specific options add no checks of their own.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MqttCommon::enforce_one(prop)
    }
}
75
/// The row encoders the MQTT sink can use.
///
/// The wrapper's [`RowEncoder`] implementation dispatches to the wrapped encoder
/// and serializes its output to bytes (`Vec<u8>`).
pub enum RowEncoderWrapper {
    /// Rows are encoded with a [`JsonEncoder`].
    Json(JsonEncoder),
    /// Rows are encoded with a [`ProtoEncoder`].
    Proto(ProtoEncoder),
}
80
81impl RowEncoder for RowEncoderWrapper {
82    type Output = Vec<u8>;
83
84    fn encode_cols(
85        &self,
86        row: impl Row,
87        col_indices: impl Iterator<Item = usize>,
88    ) -> Result<Self::Output> {
89        match self {
90            RowEncoderWrapper::Json(json) => json.encode_cols(row, col_indices)?.ser_to(),
91            RowEncoderWrapper::Proto(proto) => proto.encode_cols(row, col_indices)?.ser_to(),
92        }
93    }
94
95    fn schema(&self) -> &Schema {
96        match self {
97            RowEncoderWrapper::Json(json) => json.schema(),
98            RowEncoderWrapper::Proto(proto) => proto.schema(),
99        }
100    }
101
102    fn col_indices(&self) -> Option<&[usize]> {
103        match self {
104            RowEncoderWrapper::Json(json) => json.col_indices(),
105            RowEncoderWrapper::Proto(proto) => proto.col_indices(),
106        }
107    }
108
109    fn encode(&self, row: impl Row) -> Result<Self::Output> {
110        match self {
111            RowEncoderWrapper::Json(json) => json.encode(row)?.ser_to(),
112            RowEncoderWrapper::Proto(proto) => proto.encode(row)?.ser_to(),
113        }
114    }
115}
116
/// MQTT sink definition, built from a `SinkParam` via `TryFrom`.
#[derive(Clone, Debug)]
pub struct MqttSink {
    /// Parsed sink options.
    pub config: MqttConfig,
    // Schema obtained from `SinkParam::schema()`.
    schema: Schema,
    // The `FORMAT ... ENCODE ...` description; mandatory for this sink.
    format_desc: SinkFormatDesc,
    // Taken from the sink type of the `SinkParam`; `validate` rejects the sink when false.
    is_append_only: bool,
    // Sink name; handed to the writer, which passes it to the protobuf
    // descriptor lookup when no static topic is configured.
    name: String,
}
125
126impl EnforceSecret for MqttSink {
127    fn enforce_secret<'a>(
128        prop_iter: impl Iterator<Item = &'a str>,
129    ) -> crate::error::ConnectorResult<()> {
130        for prop in prop_iter {
131            MqttConfig::enforce_one(prop)?;
132        }
133        Ok(())
134    }
135}
136
/// Writer side of the MQTT sink: encodes rows and publishes them through an
/// MQTT client.
///
/// Creating a writer spawns a background task that polls the client's event
/// loop; dropping the writer raises `stopped` so that task can exit.
pub struct MqttSinkWriter {
    /// Sink options this writer was created with.
    pub config: MqttConfig,
    // Owns the MQTT client and the per-row topic resolution state.
    payload_writer: MqttSinkPayloadWriter,
    // Kept alongside the encoder; not read anywhere in this type.
    #[expect(dead_code)]
    schema: Schema,
    // Encoder selected from the sink's `FORMAT ... ENCODE ...` description.
    encoder: RowEncoderWrapper,
    // Shared with the background event-loop task; set to `true` on drop.
    stopped: Arc<AtomicBool>,
}
146
147/// Basic data types for use with the mqtt interface
148impl MqttConfig {
149    pub fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
150        let config = serde_json::from_value::<MqttConfig>(serde_json::to_value(values).unwrap())
151            .map_err(|e| SinkError::Config(anyhow!(e)))?;
152        if config.r#type != SINK_TYPE_APPEND_ONLY {
153            Err(SinkError::Config(anyhow!(
154                "MQTT sink only supports append-only mode"
155            )))
156        } else {
157            Ok(config)
158        }
159    }
160}
161
162impl TryFrom<SinkParam> for MqttSink {
163    type Error = SinkError;
164
165    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
166        let schema = param.schema();
167        let config = MqttConfig::from_btreemap(param.properties)?;
168        Ok(Self {
169            config,
170            schema,
171            name: param.sink_name,
172            format_desc: param
173                .format_desc
174                .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?,
175            is_append_only: param.sink_type.is_append_only(),
176        })
177    }
178}
179
180impl Sink for MqttSink {
181    type LogSinker = AsyncTruncateLogSinkerOf<MqttSinkWriter>;
182
183    const SINK_NAME: &'static str = MQTT_SINK;
184
185    async fn validate(&self) -> Result<()> {
186        if !self.is_append_only {
187            return Err(SinkError::Mqtt(anyhow!(
188                "MQTT sink only supports append-only mode"
189            )));
190        }
191
192        if let Some(field) = &self.config.topic_field {
193            let _ = get_topic_field_index_path(&self.schema, field.as_str())?;
194        } else if self.config.topic.is_none() {
195            return Err(SinkError::Config(anyhow!(
196                "either topic or topic.field must be set"
197            )));
198        }
199
200        let _client = (self.config.common.build_client(0.into(), 0))
201            .context("validate mqtt sink error")
202            .map_err(SinkError::Mqtt)?;
203
204        Ok(())
205    }
206
207    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
208        Ok(MqttSinkWriter::new(
209            self.config.clone(),
210            self.schema.clone(),
211            &self.format_desc,
212            &self.name,
213            writer_param.sink_id,
214            writer_param.actor_id,
215        )
216        .await?
217        .into_log_sinker(usize::MAX))
218    }
219}
220
221impl MqttSinkWriter {
222    pub async fn new(
223        config: MqttConfig,
224        schema: Schema,
225        format_desc: &SinkFormatDesc,
226        name: &str,
227        sink_id: SinkId,
228        actor_id: ActorId,
229    ) -> Result<Self> {
230        let mut topic_index_path = vec![];
231        if let Some(field) = &config.topic_field {
232            topic_index_path = get_topic_field_index_path(&schema, field.as_str())?;
233        }
234
235        let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?;
236        let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?;
237        let encoder = match format_desc.format {
238            SinkFormat::AppendOnly => match format_desc.encode {
239                SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new(
240                    schema.clone(),
241                    None,
242                    DateHandlingMode::FromCe,
243                    TimestampHandlingMode::Milli,
244                    timestamptz_mode,
245                    TimeHandlingMode::Milli,
246                    jsonb_handling_mode,
247                )),
248                SinkEncode::Protobuf => {
249                    let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor(
250                        &format_desc.options,
251                        config.topic.as_deref().unwrap_or(name),
252                        None,
253                    )
254                    .await
255                    .map_err(|e| SinkError::Config(anyhow!(e)))?;
256                    let header = match sid {
257                        None => ProtoHeader::None,
258                        Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid),
259                    };
260                    RowEncoderWrapper::Proto(ProtoEncoder::new(
261                        schema.clone(),
262                        None,
263                        descriptor,
264                        header,
265                    )?)
266                }
267                _ => {
268                    return Err(SinkError::Config(anyhow!(
269                        "mqtt sink encode unsupported: {:?}",
270                        format_desc.encode,
271                    )));
272                }
273            },
274            _ => {
275                return Err(SinkError::Config(anyhow!(
276                    "MQTT sink only supports append-only mode"
277                )));
278            }
279        };
280        let qos = config.common.qos();
281
282        let (client, mut eventloop) = config
283            .common
284            .build_client(actor_id, sink_id.as_raw_id())
285            .map_err(|e| SinkError::Mqtt(anyhow!(e)))?;
286
287        let stopped = Arc::new(AtomicBool::new(false));
288        let stopped_clone = stopped.clone();
289        tokio::spawn(async move {
290            while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
291                match eventloop.poll().await {
292                    Ok(_) => (),
293                    Err(err) => match err {
294                        ConnectionError::Timeout(_) => (),
295                        ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
296                        | ConnectionError::Io(err)
297                            if err.kind() == std::io::ErrorKind::ConnectionAborted
298                                || err.kind() == std::io::ErrorKind::ConnectionReset =>
299                        {
300                            continue;
301                        }
302                        err => {
303                            tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
304                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
305                        }
306                    },
307                }
308            }
309        });
310
311        let payload_writer = MqttSinkPayloadWriter {
312            topic: config.topic.clone(),
313            client,
314            qos,
315            retain: config.retain,
316            topic_index_path,
317        };
318
319        Ok::<_, SinkError>(Self {
320            config: config.clone(),
321            payload_writer,
322            schema: schema.clone(),
323            stopped,
324            encoder,
325        })
326    }
327}
328
329impl AsyncTruncateSinkWriter for MqttSinkWriter {
330    async fn write_chunk<'a>(
331        &'a mut self,
332        chunk: StreamChunk,
333        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
334    ) -> Result<()> {
335        self.payload_writer.write_chunk(chunk, &self.encoder).await
336    }
337}
338
339impl Drop for MqttSinkWriter {
340    fn drop(&mut self) {
341        self.stopped
342            .store(true, std::sync::atomic::Ordering::Relaxed);
343    }
344}
345
/// Publishes encoded rows to MQTT, choosing the topic for each row.
struct MqttSinkPayloadWriter {
    // connection to mqtt, one per executor
    client: rumqttc::v5::AsyncClient,
    // Static topic; also the fallback when the row's topic field yields no string.
    topic: Option<String>,
    // QoS passed to every publish call.
    qos: QoS,
    // Retain flag passed to every publish call.
    retain: bool,
    // Index path to the topic field inside a row (column index first, then
    // nested struct field indexes). Empty when no topic field is configured.
    topic_index_path: Vec<usize>,
}
354
355impl MqttSinkPayloadWriter {
356    async fn write_chunk(&mut self, chunk: StreamChunk, encoder: &RowEncoderWrapper) -> Result<()> {
357        for (op, row) in chunk.rows() {
358            if op != Op::Insert {
359                continue;
360            }
361
362            let topic = match get_topic_from_index_path(
363                &self.topic_index_path,
364                self.topic.as_deref(),
365                &row,
366            ) {
367                Some(s) => s,
368                None => {
369                    tracing::error!("topic field not found in row, skipping: {:?}", row);
370                    return Ok(());
371                }
372            };
373
374            let v = encoder.encode(row)?;
375
376            self.client
377                .publish(topic, self.qos, self.retain, v)
378                .await
379                .context("mqtt sink error")
380                .map_err(SinkError::Mqtt)?;
381        }
382
383        Ok(())
384    }
385}
386
387fn get_topic_from_index_path<'s>(
388    path: &[usize],
389    default_topic: Option<&'s str>,
390    row: &'s RowRef<'s>,
391) -> Option<&'s str> {
392    if let Some(topic) = default_topic
393        && path.is_empty()
394    {
395        Some(topic)
396    } else {
397        let mut iter = path.iter();
398        let scalar = iter
399            .next()
400            .and_then(|pos| row.datum_at(*pos))
401            .and_then(|d| {
402                iter.try_fold(d, |d, pos| match d {
403                    ScalarRefImpl::Struct(struct_ref) => {
404                        struct_ref.iter_fields_ref().nth(*pos).flatten()
405                    }
406                    _ => None,
407                })
408            });
409        match scalar {
410            Some(ScalarRefImpl::Utf8(s)) => Some(s),
411            _ => {
412                if let Some(topic) = default_topic {
413                    Some(topic)
414                } else {
415                    None
416                }
417            }
418        }
419    }
420}
421
422// This function returns the index path to the topic field in the schema, validating that the field exists and is of type string
423// the returnent path can be used to extract the topic field from a row. The path is a list of indexes to be used to navigate the row
424// to the topic field.
425fn get_topic_field_index_path(schema: &Schema, topic_field: &str) -> Result<Vec<usize>> {
426    let mut iter = topic_field.split('.');
427    let mut path = vec![];
428    let dt =
429        iter.next()
430            .and_then(|field| {
431                // Extract the field from the schema
432                schema
433                    .fields()
434                    .iter()
435                    .enumerate()
436                    .find(|(_, f)| f.name == field)
437                    .map(|(pos, f)| {
438                        path.push(pos);
439                        &f.data_type
440                    })
441            })
442            .and_then(|dt| {
443                // Iterate over the next fields to extract the fields from the nested structs
444                iter.try_fold(dt, |dt, field| match dt {
445                    DataType::Struct(st) => {
446                        st.iter().enumerate().find(|(_, (s, _))| *s == field).map(
447                            |(pos, (_, dt))| {
448                                path.push(pos);
449                                dt
450                            },
451                        )
452                    }
453                    _ => None,
454                })
455            });
456
457    match dt {
458        Some(DataType::Varchar) => Ok(path),
459        Some(dt) => Err(SinkError::Config(anyhow!(
460            "topic field `{}` must be of type string but got {:?}",
461            topic_field,
462            dt
463        ))),
464        None => Err(SinkError::Config(anyhow!(
465            "topic field `{}`  not found",
466            topic_field
467        ))),
468    }
469}
470
// Unit tests for the topic-field helpers: resolving a dotted field name to an
// index path, and extracting the topic string from a row with such a path.
#[cfg(test)]
mod test {
    use risingwave_common::array::{DataChunk, DataChunkTestExt, RowRef};
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use super::{get_topic_field_index_path, get_topic_from_index_path};

    // A top-level varchar column resolves to a one-element path and its value is
    // returned as the topic; an unknown column name is an error.
    #[test]
    fn test_single_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(DataType::Varchar, "topic")]);
        let path = get_topic_field_index_path(&schema, "topic").unwrap();
        assert_eq!(path, vec![0]);

        let chunk = DataChunk::from_pretty(
            "T
            test",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "other_field");
        assert!(result.is_err());
    }

    // A varchar field inside a struct column resolves to [column, field] and is
    // extracted from the struct value; an unknown nested name is an error.
    #[test]
    fn test_nested_field_extraction() {
        let schema = Schema::new(vec![Field::with_name(
            DataType::Struct(StructType::new(vec![
                ("field", DataType::Int32),
                ("subtopic", DataType::Varchar),
            ])),
            "topic",
        )]);
        let path = get_topic_field_index_path(&schema, "topic.subtopic").unwrap();
        assert_eq!(path, vec![0, 1]);

        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,test)",
        );

        let row = RowRef::new(&chunk, 0);

        assert_eq!(get_topic_from_index_path(&path, None, &row), Some("test"));

        let result = get_topic_field_index_path(&schema, "topic.other_field");
        assert!(result.is_err());
    }

    // Null topic values fall back to the default topic, or yield `None` when no
    // default is given; checked for a top-level and for a nested field.
    #[test]
    fn test_null_values_extraction() {
        let path = vec![0];
        let chunk = DataChunk::from_pretty(
            "T
            .",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);

        let path = vec![0, 1];
        let chunk = DataChunk::from_pretty(
            "<i,T>
            (1,)",
        );
        let row = RowRef::new(&chunk, 0);
        assert_eq!(
            get_topic_from_index_path(&path, Some("default"), &row),
            Some("default")
        );
        assert_eq!(get_topic_from_index_path(&path, None, &row), None);
    }

    // Two levels of struct nesting plus a sibling top-level column. Only varchar
    // leaves are accepted as topic fields.
    #[test]
    fn test_multiple_levels() {
        let schema = Schema::new(vec![
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("field", DataType::Int32),
                    (
                        "subtopic",
                        DataType::Struct(StructType::new(vec![
                            ("int_field", DataType::Int32),
                            ("boolean_field", DataType::Boolean),
                            ("string_field", DataType::Varchar),
                        ])),
                    ),
                ])),
                "topic",
            ),
            Field::with_name(DataType::Varchar, "other_field"),
        ]);

        let path = get_topic_field_index_path(&schema, "topic.subtopic.string_field").unwrap();
        assert_eq!(path, vec![0, 1, 2]);

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.boolean_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.subtopic.int_field").is_err());

        assert!(get_topic_field_index_path(&schema, "topic.field").is_err());

        let path = get_topic_field_index_path(&schema, "other_field").unwrap();
        assert_eq!(path, vec![1]);

        // The chunk's inner struct holds a single varchar, unlike the schema above
        // (where the string is the third field), so the extraction paths below are
        // written by hand rather than taken from `get_topic_field_index_path`.
        let chunk = DataChunk::from_pretty(
            "<i,<T>> T
            (1,(test)) other",
        );

        let row = RowRef::new(&chunk, 0);

        // topic.subtopic.string_field
        assert_eq!(
            get_topic_from_index_path(&[0, 1, 0], None, &row),
            Some("test")
        );

        // topic.field
        assert_eq!(get_topic_from_index_path(&[0, 0], None, &row), None);

        // other_field
        assert_eq!(get_topic_from_index_path(&[1], None, &row), Some("other"));
    }
}