risingwave_connector/source/nats/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod enumerator;
16pub use enumerator::NatsSplitEnumerator;
17pub mod source;
18pub mod split;
19
20use std::collections::HashMap;
21use std::fmt::Display;
22use std::time::Duration;
23
24use async_nats::jetstream::consumer::pull::Config;
25use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy};
26use serde::Deserialize;
27use serde_with::{DisplayFromStr, serde_as};
28use thiserror::Error;
29use with_options::WithOptions;
30
31use crate::connector_common::NatsCommon;
32use crate::error::{ConnectorError, ConnectorResult};
33use crate::source::SourceProperties;
34use crate::source::nats::source::{NatsSplit, NatsSplitReader};
35use crate::{
36    deserialize_optional_string_seq_from_string, deserialize_optional_u64_seq_from_string,
37};
38
39#[derive(Debug, Clone, Error)]
40pub struct NatsJetStreamError(String);
41
42impl Display for NatsJetStreamError {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self.0)
45    }
46}
47
48pub const NATS_CONNECTOR: &str = "nats";
49
50pub struct AckPolicyWrapper;
51
52impl AckPolicyWrapper {
53    pub fn parse_str(s: &str) -> Result<AckPolicy, NatsJetStreamError> {
54        match s {
55            "none" => Ok(AckPolicy::None),
56            "all" => Ok(AckPolicy::All),
57            "explicit" => Ok(AckPolicy::Explicit),
58            _ => Err(NatsJetStreamError(format!(
59                "Invalid AckPolicy '{}', expect `none`, `all`, and `explicit`",
60                s
61            ))),
62        }
63    }
64}
65
66pub struct ReplayPolicyWrapper;
67
68impl ReplayPolicyWrapper {
69    pub fn parse_str(s: &str) -> Result<ReplayPolicy, NatsJetStreamError> {
70        match s {
71            "instant" => Ok(ReplayPolicy::Instant),
72            "original" => Ok(ReplayPolicy::Original),
73            _ => Err(NatsJetStreamError(format!(
74                "Invalid ReplayPolicy '{}', expect `instant` and `original`",
75                s
76            ))),
77        }
78    }
79}
80
81#[serde_as]
82#[derive(Clone, Debug, Deserialize, WithOptions)]
83pub struct NatsProperties {
84    #[serde(flatten)]
85    pub common: NatsCommon,
86
87    #[serde(flatten)]
88    pub nats_properties_consumer: NatsPropertiesConsumer,
89
90    #[serde(rename = "scan.startup.mode")]
91    pub scan_startup_mode: Option<String>,
92
93    #[serde(
94        rename = "scan.startup.timestamp.millis",
95        alias = "scan.startup.timestamp_millis"
96    )]
97    #[serde_as(as = "Option<DisplayFromStr>")]
98    pub start_timestamp_millis: Option<i64>,
99
100    #[serde(rename = "stream")]
101    pub stream: String,
102
103    #[serde(rename = "consumer.durable_name")]
104    pub durable_consumer_name: String,
105
106    #[serde(flatten)]
107    pub unknown_fields: HashMap<String, String>,
108}
109
110impl NatsProperties {
111    pub fn set_config(&self, c: &mut Config) {
112        self.nats_properties_consumer.set_config(c);
113    }
114}
115
116/// Properties for the async-nats library.
117/// See <https://docs.rs/async-nats/latest/async_nats/jetstream/consumer/struct.Config.html>
118#[serde_as]
119#[derive(Clone, Debug, Deserialize, WithOptions)]
120pub struct NatsPropertiesConsumer {
121    #[serde(rename = "consumer.deliver_subject")]
122    pub deliver_subject: Option<String>,
123
124    #[serde(rename = "consumer.name")]
125    pub name: Option<String>,
126
127    #[serde(rename = "consumer.description")]
128    pub description: Option<String>,
129
130    #[serde(rename = "consumer.deliver_policy")]
131    #[serde_as(as = "Option<DisplayFromStr>")]
132    pub deliver_policy: Option<String>,
133
134    #[serde(rename = "consumer.ack_policy")]
135    #[serde_as(as = "Option<DisplayFromStr>")]
136    pub ack_policy: Option<String>,
137
138    #[serde(rename = "consumer.ack_wait.sec")]
139    #[serde_as(as = "Option<DisplayFromStr>")]
140    pub ack_wait: Option<u64>,
141
142    #[serde(rename = "consumer.max_deliver")]
143    #[serde_as(as = "Option<DisplayFromStr>")]
144    pub max_deliver: Option<i64>,
145
146    #[serde(rename = "consumer.filter_subject")]
147    pub filter_subject: Option<String>,
148
149    #[serde(
150        rename = "consumer.filter_subjects",
151        default,
152        deserialize_with = "deserialize_optional_string_seq_from_string"
153    )]
154    pub filter_subjects: Option<Vec<String>>,
155
156    #[serde(rename = "consumer.replay_policy")]
157    #[serde_as(as = "Option<DisplayFromStr>")]
158    pub replay_policy: Option<String>,
159
160    #[serde(rename = "consumer.rate_limit")]
161    #[serde_as(as = "Option<DisplayFromStr>")]
162    pub rate_limit: Option<u64>,
163
164    #[serde(rename = "consumer.sample_frequency")]
165    #[serde_as(as = "Option<DisplayFromStr>")]
166    pub sample_frequency: Option<u8>,
167
168    #[serde(rename = "consumer.max_waiting")]
169    #[serde_as(as = "Option<DisplayFromStr>")]
170    pub max_waiting: Option<i64>,
171
172    #[serde(rename = "consumer.max_ack_pending")]
173    #[serde_as(as = "Option<DisplayFromStr>")]
174    pub max_ack_pending: Option<i64>,
175
176    #[serde(rename = "consumer.headers_only")]
177    #[serde_as(as = "Option<DisplayFromStr>")]
178    pub headers_only: Option<bool>,
179
180    #[serde(rename = "consumer.max_batch")]
181    #[serde_as(as = "Option<DisplayFromStr>")]
182    pub max_batch: Option<i64>,
183
184    #[serde(rename = "consumer.max_bytes")]
185    #[serde_as(as = "Option<DisplayFromStr>")]
186    pub max_bytes: Option<i64>,
187
188    #[serde(rename = "consumer.max_expires.sec")]
189    #[serde_as(as = "Option<DisplayFromStr>")]
190    pub max_expires: Option<u64>,
191
192    #[serde(rename = "consumer.inactive_threshold.sec")]
193    #[serde_as(as = "Option<DisplayFromStr>")]
194    pub inactive_threshold: Option<u64>,
195
196    #[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")]
197    #[serde_as(as = "Option<DisplayFromStr>")]
198    pub num_replicas: Option<usize>,
199
200    #[serde(rename = "consumer.memory_storage")]
201    #[serde_as(as = "Option<DisplayFromStr>")]
202    pub memory_storage: Option<bool>,
203
204    #[serde(
205        rename = "consumer.backoff.sec",
206        default,
207        deserialize_with = "deserialize_optional_u64_seq_from_string"
208    )]
209    pub backoff: Option<Vec<u64>>,
210}
211
212impl NatsPropertiesConsumer {
213    pub fn set_config(&self, c: &mut Config) {
214        if let Some(v) = &self.name {
215            c.name = Some(v.clone())
216        }
217        if let Some(v) = &self.description {
218            c.description = Some(v.clone())
219        }
220        if let Some(v) = &self.ack_policy {
221            c.ack_policy = AckPolicyWrapper::parse_str(v).unwrap()
222        }
223        if let Some(v) = &self.ack_wait {
224            c.ack_wait = Duration::from_secs(*v)
225        }
226        if let Some(v) = &self.max_deliver {
227            c.max_deliver = *v
228        }
229        if let Some(v) = &self.filter_subject {
230            c.filter_subject = v.clone()
231        }
232        if let Some(v) = &self.filter_subjects {
233            c.filter_subjects = v.clone()
234        }
235        if let Some(v) = &self.replay_policy {
236            c.replay_policy = ReplayPolicyWrapper::parse_str(v).unwrap()
237        }
238        if let Some(v) = &self.rate_limit {
239            c.rate_limit = *v
240        }
241        if let Some(v) = &self.sample_frequency {
242            c.sample_frequency = *v
243        }
244        if let Some(v) = &self.max_waiting {
245            c.max_waiting = *v
246        }
247        if let Some(v) = &self.max_ack_pending {
248            c.max_ack_pending = *v
249        }
250        if let Some(v) = &self.headers_only {
251            c.headers_only = *v
252        }
253        if let Some(v) = &self.max_batch {
254            c.max_batch = *v
255        }
256        if let Some(v) = &self.max_bytes {
257            c.max_bytes = *v
258        }
259        if let Some(v) = &self.max_expires {
260            c.max_expires = Duration::from_secs(*v)
261        }
262        if let Some(v) = &self.inactive_threshold {
263            c.inactive_threshold = Duration::from_secs(*v)
264        }
265        if let Some(v) = &self.num_replicas {
266            c.num_replicas = *v
267        }
268        if let Some(v) = &self.memory_storage {
269            c.memory_storage = *v
270        }
271        if let Some(v) = &self.backoff {
272            c.backoff = v.iter().map(|&x| Duration::from_secs(x)).collect()
273        }
274    }
275
276    pub fn get_ack_policy(&self) -> ConnectorResult<AckPolicy> {
277        match &self.ack_policy {
278            Some(policy) => Ok(AckPolicyWrapper::parse_str(policy).map_err(ConnectorError::from)?),
279            None => Ok(AckPolicy::None),
280        }
281    }
282}
283
284impl SourceProperties for NatsProperties {
285    type Split = NatsSplit;
286    type SplitEnumerator = NatsSplitEnumerator;
287    type SplitReader = NatsSplitReader;
288
289    const SOURCE_NAME: &'static str = NATS_CONNECTOR;
290}
291
292impl crate::source::UnknownFields for NatsProperties {
293    fn unknown_fields(&self) -> HashMap<String, String> {
294        self.unknown_fields.clone()
295    }
296}
297
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use maplit::btreemap;

    use super::*;

    /// Every supported `consumer.*` option must round-trip from the string-valued
    /// option map into the typed fields.
    #[test]
    fn test_parse_config_consumer() {
        let config: BTreeMap<String, String> = btreemap! {
            "stream".to_owned() => "risingwave".to_owned(),

            // NATS common
            "subject".to_owned() => "subject1".to_owned(),
            "server_url".to_owned() => "nats-server:4222".to_owned(),
            "connect_mode".to_owned() => "plain".to_owned(),
            "type".to_owned() => "append-only".to_owned(),

            // NATS properties consumer
            "consumer.name".to_owned() => "foobar".to_owned(),
            "consumer.durable_name".to_owned() => "durable_foobar".to_owned(),
            "consumer.description".to_owned() => "A description".to_owned(),
            "consumer.ack_policy".to_owned() => "all".to_owned(),
            "consumer.ack_wait.sec".to_owned() => "10".to_owned(),
            "consumer.max_deliver".to_owned() => "10".to_owned(),
            "consumer.filter_subject".to_owned() => "subject".to_owned(),
            "consumer.filter_subjects".to_owned() => "subject1,subject2".to_owned(),
            "consumer.replay_policy".to_owned() => "instant".to_owned(),
            "consumer.rate_limit".to_owned() => "100".to_owned(),
            "consumer.sample_frequency".to_owned() => "1".to_owned(),
            "consumer.max_waiting".to_owned() => "5".to_owned(),
            "consumer.max_ack_pending".to_owned() => "100".to_owned(),
            "consumer.headers_only".to_owned() => "true".to_owned(),
            "consumer.max_batch".to_owned() => "10".to_owned(),
            "consumer.max_bytes".to_owned() => "1024".to_owned(),
            "consumer.max_expires.sec".to_owned() => "24".to_owned(),
            "consumer.inactive_threshold.sec".to_owned() => "10".to_owned(),
            "consumer.num_replicas".to_owned() => "3".to_owned(),
            "consumer.memory_storage".to_owned() => "true".to_owned(),
            "consumer.backoff.sec".to_owned() => "2,10,15".to_owned(),

            // Not a recognised option key (the durable name is read from
            // `consumer.durable_name`); it must be captured by `unknown_fields`.
            "durable_consumer_name".to_owned() => "test_durable_consumer".to_owned(),
        };

        let props: NatsProperties =
            serde_json::from_value(serde_json::to_value(config).unwrap()).unwrap();

        assert_eq!(props.stream, "risingwave");
        assert_eq!(
            props.nats_properties_consumer.name,
            Some("foobar".to_owned())
        );
        assert_eq!(props.durable_consumer_name, "durable_foobar".to_owned());
        assert_eq!(
            props.nats_properties_consumer.description,
            Some("A description".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.ack_policy,
            Some("all".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.ack_wait, Some(10));
        assert_eq!(props.nats_properties_consumer.max_deliver, Some(10));
        assert_eq!(
            props.nats_properties_consumer.filter_subject,
            Some("subject".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.filter_subjects,
            Some(vec!["subject1".to_owned(), "subject2".to_owned()])
        );
        assert_eq!(
            props.nats_properties_consumer.replay_policy,
            Some("instant".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.rate_limit, Some(100));
        assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1));
        assert_eq!(props.nats_properties_consumer.max_waiting, Some(5));
        assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100));
        assert_eq!(props.nats_properties_consumer.headers_only, Some(true));
        assert_eq!(props.nats_properties_consumer.max_batch, Some(10));
        assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024));
        assert_eq!(props.nats_properties_consumer.max_expires, Some(24));
        assert_eq!(props.nats_properties_consumer.inactive_threshold, Some(10));
        assert_eq!(props.nats_properties_consumer.num_replicas, Some(3));
        assert_eq!(props.nats_properties_consumer.memory_storage, Some(true));
        assert_eq!(
            props.nats_properties_consumer.backoff,
            Some(vec![2, 10, 15])
        );
        assert_eq!(
            props
                .unknown_fields
                .get("durable_consumer_name")
                .map(String::as_str),
            Some("test_durable_consumer")
        );
    }

    #[test]
    fn test_parse_ack_policy() {
        assert!(matches!(
            AckPolicyWrapper::parse_str("none"),
            Ok(AckPolicy::None)
        ));
        assert!(matches!(
            AckPolicyWrapper::parse_str("all"),
            Ok(AckPolicy::All)
        ));
        assert!(matches!(
            AckPolicyWrapper::parse_str("explicit"),
            Ok(AckPolicy::Explicit)
        ));
        // Matching is exact and case-sensitive.
        assert!(AckPolicyWrapper::parse_str("Explicit").is_err());
        assert!(AckPolicyWrapper::parse_str("").is_err());
    }

    #[test]
    fn test_parse_replay_policy() {
        assert!(matches!(
            ReplayPolicyWrapper::parse_str("instant"),
            Ok(ReplayPolicy::Instant)
        ));
        assert!(matches!(
            ReplayPolicyWrapper::parse_str("original"),
            Ok(ReplayPolicy::Original)
        ));
        assert!(ReplayPolicyWrapper::parse_str("Original").is_err());
        assert!(ReplayPolicyWrapper::parse_str("").is_err());
    }

    #[test]
    fn test_get_ack_policy_defaults_to_none() {
        let consumer: NatsPropertiesConsumer = serde_json::from_str("{}").unwrap();
        assert!(matches!(consumer.get_ack_policy(), Ok(AckPolicy::None)));
    }
}