risingwave_connector/source/nats/
mod.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod enumerator;
16pub use enumerator::NatsSplitEnumerator;
17pub mod source;
18pub mod split;
19
20use std::collections::HashMap;
21use std::fmt::Display;
22use std::time::Duration;
23
24use anyhow::Context;
25use async_nats::jetstream::consumer::pull::Config;
26use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy};
27use serde::Deserialize;
28use serde_with::{DisplayFromStr, serde_as};
29use thiserror::Error;
30use with_options::WithOptions;
31
32use crate::connector_common::NatsCommon;
33use crate::enforce_secret::EnforceSecret;
34use crate::error::{ConnectorError, ConnectorResult};
35use crate::source::SourceProperties;
36use crate::source::nats::source::{NatsSplit, NatsSplitReader};
37use crate::{
38    deserialize_optional_string_seq_from_string, deserialize_optional_u64_seq_from_string,
39};
40
41#[derive(Debug, Clone, Error)]
42pub struct NatsJetStreamError(String);
43
44impl Display for NatsJetStreamError {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{}", self.0)
47    }
48}
49
/// Connector identifier of the NATS source; also exposed as
/// `SourceProperties::SOURCE_NAME` for [`NatsProperties`].
pub const NATS_CONNECTOR: &str = "nats";
51
52pub struct AckPolicyWrapper;
53
54impl AckPolicyWrapper {
55    pub fn parse_str(s: &str) -> Result<AckPolicy, NatsJetStreamError> {
56        match s {
57            "none" => Ok(AckPolicy::None),
58            "all" => Ok(AckPolicy::All),
59            "explicit" => Ok(AckPolicy::Explicit),
60            _ => Err(NatsJetStreamError(format!(
61                "Invalid AckPolicy '{}', expect `none`, `all`, and `explicit`",
62                s
63            ))),
64        }
65    }
66}
67
68pub struct ReplayPolicyWrapper;
69
70impl ReplayPolicyWrapper {
71    pub fn parse_str(s: &str) -> Result<ReplayPolicy, NatsJetStreamError> {
72        match s {
73            "instant" => Ok(ReplayPolicy::Instant),
74            "original" => Ok(ReplayPolicy::Original),
75            _ => Err(NatsJetStreamError(format!(
76                "Invalid ReplayPolicy '{}', expect `instant` and `original`",
77                s
78            ))),
79        }
80    }
81}
82
/// Connector options for the NATS JetStream source.
///
/// Deserialized from a string-keyed option map; the numeric startup timestamp is parsed from
/// its textual form via `DisplayFromStr`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsProperties {
    /// Shared NATS options; see [`NatsCommon`].
    #[serde(flatten)]
    pub common: NatsCommon,

    /// JetStream consumer options (`consumer.*` keys); see [`NatsPropertiesConsumer`].
    #[serde(flatten)]
    pub nats_properties_consumer: NatsPropertiesConsumer,

    /// Raw `scan.startup.mode` value; it is not interpreted anywhere in this file.
    #[serde(rename = "scan.startup.mode")]
    pub scan_startup_mode: Option<String>,

    /// Startup timestamp in milliseconds, parsed from a string. Also accepted under the alias
    /// `scan.startup.timestamp_millis`.
    #[serde(
        rename = "scan.startup.timestamp.millis",
        alias = "scan.startup.timestamp_millis"
    )]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub start_timestamp_millis: Option<i64>,

    /// Stream name (required).
    #[serde(rename = "stream")]
    pub stream: String,

    /// Durable consumer name, read from `consumer.durable_name` (required).
    #[serde(rename = "consumer.durable_name")]
    pub durable_consumer_name: String,

    /// Options not matched by any other field; returned by `UnknownFields::unknown_fields`.
    // NOTE(review): serde hands flattened fields the leftover keys in declaration order, and a
    // flattened map receives everything still left, so this must stay the last
    // `#[serde(flatten)]` field or it would also capture `common`/consumer keys.
    #[serde(flatten)]
    pub unknown_fields: HashMap<String, String>,
}
111
112impl EnforceSecret for NatsProperties {
113    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
114        for prop in prop_iter {
115            NatsCommon::enforce_one(prop)?;
116        }
117        Ok(())
118    }
119}
120
121impl NatsProperties {
122    pub fn set_config(&self, c: &mut Config) -> ConnectorResult<()> {
123        self.nats_properties_consumer.set_config(c)
124    }
125}
126
/// JetStream pull-consumer options, all read from `consumer.*` keys and all optional.
///
/// [`NatsPropertiesConsumer::set_config`] copies each option that is present onto the
/// `async-nats` consumer `Config` field of the same name and leaves absent options at the
/// `Config`'s existing value. Exceptions are noted on the individual fields.
/// See <https://docs.rs/async-nats/latest/async_nats/jetstream/consumer/struct.Config.html>
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsPropertiesConsumer {
    /// Parsed but not applied by `set_config`.
    #[serde(rename = "consumer.deliver_subject")]
    pub deliver_subject: Option<String>,

    #[serde(rename = "consumer.name")]
    pub name: Option<String>,

    #[serde(rename = "consumer.description")]
    pub description: Option<String>,

    /// Parsed but not applied by `set_config`.
    // NOTE(review): confirm whether another code path reads this value.
    #[serde(rename = "consumer.deliver_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub deliver_policy: Option<String>,

    /// `none`, `all` or `explicit`. Unlike the other options, the ack policy is always written
    /// to the `Config`; `explicit` is used when this is unset.
    #[serde(rename = "consumer.ack_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_policy: Option<String>,

    /// Ack wait, in seconds.
    #[serde(rename = "consumer.ack_wait.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_wait: Option<u64>,

    #[serde(rename = "consumer.max_deliver")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_deliver: Option<i64>,

    #[serde(rename = "consumer.filter_subject")]
    pub filter_subject: Option<String>,

    /// Comma-separated list of subjects, e.g. `subject1,subject2`.
    #[serde(
        rename = "consumer.filter_subjects",
        default,
        deserialize_with = "deserialize_optional_string_seq_from_string"
    )]
    pub filter_subjects: Option<Vec<String>>,

    /// `instant` or `original`.
    #[serde(rename = "consumer.replay_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub replay_policy: Option<String>,

    #[serde(rename = "consumer.rate_limit")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub rate_limit: Option<u64>,

    #[serde(rename = "consumer.sample_frequency")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub sample_frequency: Option<u8>,

    #[serde(rename = "consumer.max_waiting")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_waiting: Option<i64>,

    #[serde(rename = "consumer.max_ack_pending")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_ack_pending: Option<i64>,

    #[serde(rename = "consumer.headers_only")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub headers_only: Option<bool>,

    #[serde(rename = "consumer.max_batch")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_batch: Option<i64>,

    #[serde(rename = "consumer.max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,

    /// Max expires, in seconds.
    #[serde(rename = "consumer.max_expires.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_expires: Option<u64>,

    /// Inactive threshold, in seconds.
    #[serde(rename = "consumer.inactive_threshold.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub inactive_threshold: Option<u64>,

    /// Also accepted under the key `consumer.num_replicas`.
    #[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub num_replicas: Option<usize>,

    #[serde(rename = "consumer.memory_storage")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub memory_storage: Option<bool>,

    /// Comma-separated list of redelivery delays in seconds, e.g. `2,10,15`.
    #[serde(
        rename = "consumer.backoff.sec",
        default,
        deserialize_with = "deserialize_optional_u64_seq_from_string"
    )]
    pub backoff: Option<Vec<u64>>,
}
222
223impl NatsPropertiesConsumer {
224    pub fn set_config(&self, c: &mut Config) -> ConnectorResult<()> {
225        if let Some(v) = &self.name {
226            c.name = Some(v.clone())
227        }
228        if let Some(v) = &self.description {
229            c.description = Some(v.clone())
230        }
231        c.ack_policy = self.effective_ack_policy()?;
232        if let Some(v) = &self.ack_wait {
233            c.ack_wait = Duration::from_secs(*v)
234        }
235        if let Some(v) = &self.max_deliver {
236            c.max_deliver = *v
237        }
238        if let Some(v) = &self.filter_subject {
239            c.filter_subject = v.clone()
240        }
241        if let Some(v) = &self.filter_subjects {
242            c.filter_subjects = v.clone()
243        }
244        if let Some(v) = &self.replay_policy {
245            c.replay_policy = ReplayPolicyWrapper::parse_str(v)
246                .with_context(|| format!("invalid value for `consumer.replay_policy`: {v}"))
247                .map_err(ConnectorError::from)?
248        }
249        if let Some(v) = &self.rate_limit {
250            c.rate_limit = *v
251        }
252        if let Some(v) = &self.sample_frequency {
253            c.sample_frequency = *v
254        }
255        if let Some(v) = &self.max_waiting {
256            c.max_waiting = *v
257        }
258        if let Some(v) = &self.max_ack_pending {
259            c.max_ack_pending = *v
260        }
261        if let Some(v) = &self.headers_only {
262            c.headers_only = *v
263        }
264        if let Some(v) = &self.max_batch {
265            c.max_batch = *v
266        }
267        if let Some(v) = &self.max_bytes {
268            c.max_bytes = *v
269        }
270        if let Some(v) = &self.max_expires {
271            c.max_expires = Duration::from_secs(*v)
272        }
273        if let Some(v) = &self.inactive_threshold {
274            c.inactive_threshold = Duration::from_secs(*v)
275        }
276        if let Some(v) = &self.num_replicas {
277            c.num_replicas = *v
278        }
279        if let Some(v) = &self.memory_storage {
280            c.memory_storage = *v
281        }
282        if let Some(v) = &self.backoff {
283            c.backoff = v.iter().map(|&x| Duration::from_secs(x)).collect()
284        }
285        Ok(())
286    }
287
288    fn effective_ack_policy(&self) -> ConnectorResult<AckPolicy> {
289        match &self.ack_policy {
290            Some(policy) => Ok(AckPolicyWrapper::parse_str(policy)
291                .with_context(|| format!("invalid value for `consumer.ack_policy`: {policy}"))
292                .map_err(ConnectorError::from)?),
293            None => Ok(AckPolicy::Explicit),
294        }
295    }
296
297    pub fn get_ack_policy(&self) -> ConnectorResult<AckPolicy> {
298        self.effective_ack_policy()
299    }
300}
301
302impl SourceProperties for NatsProperties {
303    type Split = NatsSplit;
304    type SplitEnumerator = NatsSplitEnumerator;
305    type SplitReader = NatsSplitReader;
306
307    const SOURCE_NAME: &'static str = NATS_CONNECTOR;
308}
309
310impl crate::source::UnknownFields for NatsProperties {
311    fn unknown_fields(&self) -> HashMap<String, String> {
312        self.unknown_fields.clone()
313    }
314}
315
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use async_nats::jetstream::consumer::pull::Config;
    use maplit::btreemap;

    use super::*;

    /// Deserializes a string-to-string option map into [`NatsProperties`].
    ///
    /// Panics on failure; callers only pass maps that are expected to parse.
    fn parse_nats_properties(config: BTreeMap<String, String>) -> NatsProperties {
        serde_json::from_value(
            serde_json::to_value(config).expect("failed to serialize NATS test config"),
        )
        .expect("failed to deserialize NATS test config into NatsProperties")
    }

    /// Option map that sets every consumer option except `consumer.ack_policy`, which the
    /// individual tests choose themselves.
    fn base_config() -> BTreeMap<String, String> {
        btreemap! {
            "stream".to_owned() => "risingwave".to_owned(),

            // NATS common
            "subject".to_owned() => "subject1".to_owned(),
            "server_url".to_owned() => "nats-server:4222".to_owned(),
            "connect_mode".to_owned() => "plain".to_owned(),
            "type".to_owned() => "append-only".to_owned(),

            // NATS properties consumer
            "consumer.name".to_owned() => "foobar".to_owned(),
            "consumer.durable_name".to_owned() => "durable_foobar".to_owned(),
            "consumer.description".to_owned() => "A description".to_owned(),
            "consumer.ack_wait.sec".to_owned() => "10".to_owned(),
            "consumer.max_deliver".to_owned() => "10".to_owned(),
            "consumer.filter_subject".to_owned() => "subject".to_owned(),
            "consumer.filter_subjects".to_owned() => "subject1,subject2".to_owned(),
            "consumer.replay_policy".to_owned() => "instant".to_owned(),
            "consumer.rate_limit".to_owned() => "100".to_owned(),
            "consumer.sample_frequency".to_owned() => "1".to_owned(),
            "consumer.max_waiting".to_owned() => "5".to_owned(),
            "consumer.max_ack_pending".to_owned() => "100".to_owned(),
            "consumer.headers_only".to_owned() => "true".to_owned(),
            "consumer.max_batch".to_owned() => "10".to_owned(),
            "consumer.max_bytes".to_owned() => "1024".to_owned(),
            "consumer.max_expires.sec".to_owned() => "24".to_owned(),
            "consumer.inactive_threshold.sec".to_owned() => "10".to_owned(),
            "consumer.num_replicas".to_owned() => "3".to_owned(),
            "consumer.memory_storage".to_owned() => "true".to_owned(),
            "consumer.backoff.sec".to_owned() => "2,10,15".to_owned(),
        }
    }

    #[test]
    fn test_parse_config_consumer() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "all".to_owned());

        let props = parse_nats_properties(config);

        assert_eq!(props.stream, "risingwave".to_owned());
        assert_eq!(
            props.nats_properties_consumer.name,
            Some("foobar".to_owned())
        );
        assert_eq!(props.durable_consumer_name, "durable_foobar".to_owned());
        assert_eq!(
            props.nats_properties_consumer.description,
            Some("A description".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.ack_policy,
            Some("all".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.ack_wait, Some(10));
        assert_eq!(props.nats_properties_consumer.max_deliver, Some(10));
        assert_eq!(
            props.nats_properties_consumer.filter_subject,
            Some("subject".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.filter_subjects,
            Some(vec!["subject1".to_owned(), "subject2".to_owned()])
        );
        assert_eq!(
            props.nats_properties_consumer.replay_policy,
            Some("instant".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.rate_limit, Some(100));
        assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1));
        assert_eq!(props.nats_properties_consumer.max_waiting, Some(5));
        assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100));
        assert_eq!(props.nats_properties_consumer.headers_only, Some(true));
        assert_eq!(props.nats_properties_consumer.max_batch, Some(10));
        assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024));
        assert_eq!(props.nats_properties_consumer.max_expires, Some(24));
        assert_eq!(props.nats_properties_consumer.inactive_threshold, Some(10));
        assert_eq!(props.nats_properties_consumer.num_replicas, Some(3));
        assert_eq!(props.nats_properties_consumer.memory_storage, Some(true));
        assert_eq!(
            props.nats_properties_consumer.backoff,
            Some(vec![2, 10, 15])
        );
    }

    /// Every parsed option must actually reach the pull-consumer `Config`.
    #[test]
    fn test_set_config_applies_consumer_options() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "all".to_owned());
        let props = parse_nats_properties(config);

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(consumer_config.name, Some("foobar".to_owned()));
        assert_eq!(
            consumer_config.description,
            Some("A description".to_owned())
        );
        assert_eq!(consumer_config.ack_policy, AckPolicy::All);
        assert_eq!(consumer_config.ack_wait, Duration::from_secs(10));
        assert_eq!(consumer_config.max_deliver, 10);
        assert_eq!(consumer_config.filter_subject, "subject".to_owned());
        assert_eq!(
            consumer_config.filter_subjects,
            vec!["subject1".to_owned(), "subject2".to_owned()]
        );
        assert!(matches!(
            consumer_config.replay_policy,
            ReplayPolicy::Instant
        ));
        assert_eq!(consumer_config.rate_limit, 100);
        assert_eq!(consumer_config.sample_frequency, 1);
        assert_eq!(consumer_config.max_waiting, 5);
        assert_eq!(consumer_config.max_ack_pending, 100);
        assert!(consumer_config.headers_only);
        assert_eq!(consumer_config.max_batch, 10);
        assert_eq!(consumer_config.max_bytes, 1024);
        assert_eq!(consumer_config.max_expires, Duration::from_secs(24));
        assert_eq!(consumer_config.inactive_threshold, Duration::from_secs(10));
        assert_eq!(consumer_config.num_replicas, 3);
        assert!(consumer_config.memory_storage);
        assert_eq!(
            consumer_config.backoff,
            vec![
                Duration::from_secs(2),
                Duration::from_secs(10),
                Duration::from_secs(15)
            ]
        );
    }

    #[test]
    fn test_default_ack_policy_is_explicit() {
        let props = parse_nats_properties(base_config());

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(
            props.nats_properties_consumer.get_ack_policy().unwrap(),
            AckPolicy::Explicit
        );
        assert_eq!(consumer_config.ack_policy, AckPolicy::Explicit);
    }

    #[test]
    fn test_ack_policy_none_is_preserved() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "none".to_owned());
        let props = parse_nats_properties(config);

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(
            props.nats_properties_consumer.get_ack_policy().unwrap(),
            AckPolicy::None
        );
        assert_eq!(consumer_config.ack_policy, AckPolicy::None);
    }

    #[test]
    fn test_invalid_ack_policy_returns_error() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "invalid".to_owned());
        let props = parse_nats_properties(config);

        let err = props.nats_properties_consumer.get_ack_policy().unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid AckPolicy 'invalid'"),
            "unexpected error: {err_message}"
        );

        let mut consumer_config = Config::default();
        let err = props.set_config(&mut consumer_config).unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid AckPolicy 'invalid'"),
            "unexpected error: {err_message}"
        );
    }

    #[test]
    fn test_invalid_replay_policy_returns_error() {
        let mut config = base_config();
        config.insert("consumer.replay_policy".to_owned(), "bogus".to_owned());
        let props = parse_nats_properties(config);

        let mut consumer_config = Config::default();
        let err = props.set_config(&mut consumer_config).unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid ReplayPolicy 'bogus'"),
            "unexpected error: {err_message}"
        );
    }

    /// Unrecognised keys are kept, and keys consumed by declared or flattened fields are not
    /// duplicated into the catch-all map.
    #[test]
    fn test_unknown_fields_collects_only_unmatched_keys() {
        let mut config = base_config();
        config.insert("some.unrecognised.option".to_owned(), "value".to_owned());
        let props = parse_nats_properties(config);

        let unknown = crate::source::UnknownFields::unknown_fields(&props);
        assert_eq!(
            unknown.get("some.unrecognised.option").map(String::as_str),
            Some("value")
        );
        for consumed in [
            "stream",
            "consumer.durable_name",
            "consumer.name",
            "consumer.backoff.sec",
        ] {
            assert!(
                !unknown.contains_key(consumed),
                "{consumed} leaked into unknown_fields"
            );
        }
    }
}