risingwave_connector/source/nats/
mod.rs1pub mod enumerator;
16pub use enumerator::NatsSplitEnumerator;
17pub mod source;
18pub mod split;
19
20use std::collections::HashMap;
21use std::fmt::Display;
22use std::time::Duration;
23
24use anyhow::Context;
25use async_nats::jetstream::consumer::pull::Config;
26use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy};
27use serde::Deserialize;
28use serde_with::{DisplayFromStr, serde_as};
29use thiserror::Error;
30use with_options::WithOptions;
31
32use crate::connector_common::NatsCommon;
33use crate::enforce_secret::EnforceSecret;
34use crate::error::{ConnectorError, ConnectorResult};
35use crate::source::SourceProperties;
36use crate::source::nats::source::{NatsSplit, NatsSplitReader};
37use crate::{
38 deserialize_optional_string_seq_from_string, deserialize_optional_u64_seq_from_string,
39};
40
41#[derive(Debug, Clone, Error)]
42pub struct NatsJetStreamError(String);
43
44impl Display for NatsJetStreamError {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "{}", self.0)
47 }
48}
49
/// Connector name of the NATS source; used as `SourceProperties::SOURCE_NAME` for
/// `NatsProperties`.
pub const NATS_CONNECTOR: &str = "nats";
51
52pub struct AckPolicyWrapper;
53
54impl AckPolicyWrapper {
55 pub fn parse_str(s: &str) -> Result<AckPolicy, NatsJetStreamError> {
56 match s {
57 "none" => Ok(AckPolicy::None),
58 "all" => Ok(AckPolicy::All),
59 "explicit" => Ok(AckPolicy::Explicit),
60 _ => Err(NatsJetStreamError(format!(
61 "Invalid AckPolicy '{}', expect `none`, `all`, and `explicit`",
62 s
63 ))),
64 }
65 }
66}
67
68pub struct ReplayPolicyWrapper;
69
70impl ReplayPolicyWrapper {
71 pub fn parse_str(s: &str) -> Result<ReplayPolicy, NatsJetStreamError> {
72 match s {
73 "instant" => Ok(ReplayPolicy::Instant),
74 "original" => Ok(ReplayPolicy::Original),
75 _ => Err(NatsJetStreamError(format!(
76 "Invalid ReplayPolicy '{}', expect `instant` and `original`",
77 s
78 ))),
79 }
80 }
81}
82
/// WITH options accepted by the NATS JetStream source connector.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsProperties {
    /// Options defined by `connector_common::NatsCommon` (shared NATS connection settings).
    #[serde(flatten)]
    pub common: NatsCommon,

    /// `consumer.*` options; applied to the pull-consumer `Config` by `set_config`.
    #[serde(flatten)]
    pub nats_properties_consumer: NatsPropertiesConsumer,

    /// `scan.startup.mode`. Not interpreted in this module.
    #[serde(rename = "scan.startup.mode")]
    pub scan_startup_mode: Option<String>,

    /// `scan.startup.timestamp.millis` (alias `scan.startup.timestamp_millis`), given as a
    /// string and parsed into an `i64`. Not interpreted in this module.
    #[serde(
        rename = "scan.startup.timestamp.millis",
        alias = "scan.startup.timestamp_millis"
    )]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub start_timestamp_millis: Option<i64>,

    /// `stream`: the JetStream stream name. Required (no default).
    #[serde(rename = "stream")]
    pub stream: String,

    /// `consumer.durable_name`. Required (no default).
    #[serde(rename = "consumer.durable_name")]
    pub durable_consumer_name: String,

    /// Catch-all for options not matched by any field above; exposed through
    /// `UnknownFields::unknown_fields`.
    #[serde(flatten)]
    pub unknown_fields: HashMap<String, String>,
}
111
112impl EnforceSecret for NatsProperties {
113 fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
114 for prop in prop_iter {
115 NatsCommon::enforce_one(prop)?;
116 }
117 Ok(())
118 }
119}
120
121impl NatsProperties {
122 pub fn set_config(&self, c: &mut Config) -> ConnectorResult<()> {
123 self.nats_properties_consumer.set_config(c)
124 }
125}
126
/// `consumer.*` WITH options for the JetStream pull consumer.
///
/// Every field is optional. `NatsPropertiesConsumer::set_config` copies the set ones onto
/// the pull-consumer `Config`, except `deliver_subject` and `deliver_policy`, which it does
/// not apply.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsPropertiesConsumer {
    /// `consumer.deliver_subject`. Not applied by `set_config`.
    #[serde(rename = "consumer.deliver_subject")]
    pub deliver_subject: Option<String>,

    /// `consumer.name`.
    #[serde(rename = "consumer.name")]
    pub name: Option<String>,

    /// `consumer.description`.
    #[serde(rename = "consumer.description")]
    pub description: Option<String>,

    /// `consumer.deliver_policy`. Not applied by `set_config`.
    #[serde(rename = "consumer.deliver_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub deliver_policy: Option<String>,

    /// `consumer.ack_policy`: one of `none`, `all`, `explicit`. When unset, `set_config`
    /// uses `explicit`.
    #[serde(rename = "consumer.ack_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_policy: Option<String>,

    /// `consumer.ack_wait.sec`, in whole seconds.
    #[serde(rename = "consumer.ack_wait.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_wait: Option<u64>,

    /// `consumer.max_deliver`.
    #[serde(rename = "consumer.max_deliver")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_deliver: Option<i64>,

    /// `consumer.filter_subject`.
    #[serde(rename = "consumer.filter_subject")]
    pub filter_subject: Option<String>,

    /// `consumer.filter_subjects`, given as a comma-separated list (e.g. `a,b`).
    #[serde(
        rename = "consumer.filter_subjects",
        default,
        deserialize_with = "deserialize_optional_string_seq_from_string"
    )]
    pub filter_subjects: Option<Vec<String>>,

    /// `consumer.replay_policy`: one of `instant`, `original`.
    #[serde(rename = "consumer.replay_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub replay_policy: Option<String>,

    /// `consumer.rate_limit`.
    #[serde(rename = "consumer.rate_limit")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub rate_limit: Option<u64>,

    /// `consumer.sample_frequency`.
    #[serde(rename = "consumer.sample_frequency")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub sample_frequency: Option<u8>,

    /// `consumer.max_waiting`.
    #[serde(rename = "consumer.max_waiting")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_waiting: Option<i64>,

    /// `consumer.max_ack_pending`.
    #[serde(rename = "consumer.max_ack_pending")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_ack_pending: Option<i64>,

    /// `consumer.headers_only`.
    #[serde(rename = "consumer.headers_only")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub headers_only: Option<bool>,

    /// `consumer.max_batch`.
    #[serde(rename = "consumer.max_batch")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_batch: Option<i64>,

    /// `consumer.max_bytes`.
    #[serde(rename = "consumer.max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,

    /// `consumer.max_expires.sec`, in whole seconds.
    #[serde(rename = "consumer.max_expires.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_expires: Option<u64>,

    /// `consumer.inactive_threshold.sec`, in whole seconds.
    #[serde(rename = "consumer.inactive_threshold.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub inactive_threshold: Option<u64>,

    /// `consumer.num.replicas` (alias `consumer.num_replicas`).
    #[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub num_replicas: Option<usize>,

    /// `consumer.memory_storage`.
    #[serde(rename = "consumer.memory_storage")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub memory_storage: Option<bool>,

    /// `consumer.backoff.sec`: comma-separated list of whole seconds (e.g. `2,10,15`).
    #[serde(
        rename = "consumer.backoff.sec",
        default,
        deserialize_with = "deserialize_optional_u64_seq_from_string"
    )]
    pub backoff: Option<Vec<u64>>,
}
222
223impl NatsPropertiesConsumer {
224 pub fn set_config(&self, c: &mut Config) -> ConnectorResult<()> {
225 if let Some(v) = &self.name {
226 c.name = Some(v.clone())
227 }
228 if let Some(v) = &self.description {
229 c.description = Some(v.clone())
230 }
231 c.ack_policy = self.effective_ack_policy()?;
232 if let Some(v) = &self.ack_wait {
233 c.ack_wait = Duration::from_secs(*v)
234 }
235 if let Some(v) = &self.max_deliver {
236 c.max_deliver = *v
237 }
238 if let Some(v) = &self.filter_subject {
239 c.filter_subject = v.clone()
240 }
241 if let Some(v) = &self.filter_subjects {
242 c.filter_subjects = v.clone()
243 }
244 if let Some(v) = &self.replay_policy {
245 c.replay_policy = ReplayPolicyWrapper::parse_str(v)
246 .with_context(|| format!("invalid value for `consumer.replay_policy`: {v}"))
247 .map_err(ConnectorError::from)?
248 }
249 if let Some(v) = &self.rate_limit {
250 c.rate_limit = *v
251 }
252 if let Some(v) = &self.sample_frequency {
253 c.sample_frequency = *v
254 }
255 if let Some(v) = &self.max_waiting {
256 c.max_waiting = *v
257 }
258 if let Some(v) = &self.max_ack_pending {
259 c.max_ack_pending = *v
260 }
261 if let Some(v) = &self.headers_only {
262 c.headers_only = *v
263 }
264 if let Some(v) = &self.max_batch {
265 c.max_batch = *v
266 }
267 if let Some(v) = &self.max_bytes {
268 c.max_bytes = *v
269 }
270 if let Some(v) = &self.max_expires {
271 c.max_expires = Duration::from_secs(*v)
272 }
273 if let Some(v) = &self.inactive_threshold {
274 c.inactive_threshold = Duration::from_secs(*v)
275 }
276 if let Some(v) = &self.num_replicas {
277 c.num_replicas = *v
278 }
279 if let Some(v) = &self.memory_storage {
280 c.memory_storage = *v
281 }
282 if let Some(v) = &self.backoff {
283 c.backoff = v.iter().map(|&x| Duration::from_secs(x)).collect()
284 }
285 Ok(())
286 }
287
288 fn effective_ack_policy(&self) -> ConnectorResult<AckPolicy> {
289 match &self.ack_policy {
290 Some(policy) => Ok(AckPolicyWrapper::parse_str(policy)
291 .with_context(|| format!("invalid value for `consumer.ack_policy`: {policy}"))
292 .map_err(ConnectorError::from)?),
293 None => Ok(AckPolicy::Explicit),
294 }
295 }
296
297 pub fn get_ack_policy(&self) -> ConnectorResult<AckPolicy> {
298 self.effective_ack_policy()
299 }
300}
301
302impl SourceProperties for NatsProperties {
303 type Split = NatsSplit;
304 type SplitEnumerator = NatsSplitEnumerator;
305 type SplitReader = NatsSplitReader;
306
307 const SOURCE_NAME: &'static str = NATS_CONNECTOR;
308}
309
310impl crate::source::UnknownFields for NatsProperties {
311 fn unknown_fields(&self) -> HashMap<String, String> {
312 self.unknown_fields.clone()
313 }
314}
315
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use async_nats::jetstream::consumer::pull::Config;
    use maplit::btreemap;

    use super::*;

    /// Deserializes a flat string map into `NatsProperties` via `serde_json`.
    fn parse_nats_properties(config: BTreeMap<String, String>) -> NatsProperties {
        serde_json::from_value(
            serde_json::to_value(config).expect("failed to serialize NATS test config"),
        )
        .expect("failed to deserialize NATS test config into NatsProperties")
    }

    /// A config that sets every `consumer.*` option except `consumer.ack_policy`.
    fn base_config() -> BTreeMap<String, String> {
        btreemap! {
            "stream".to_owned() => "risingwave".to_owned(),

            "subject".to_owned() => "subject1".to_owned(),
            "server_url".to_owned() => "nats-server:4222".to_owned(),
            "connect_mode".to_owned() => "plain".to_owned(),
            "type".to_owned() => "append-only".to_owned(),

            "consumer.name".to_owned() => "foobar".to_owned(),
            "consumer.durable_name".to_owned() => "durable_foobar".to_owned(),
            "consumer.description".to_owned() => "A description".to_owned(),
            "consumer.ack_wait.sec".to_owned() => "10".to_owned(),
            "consumer.max_deliver".to_owned() => "10".to_owned(),
            "consumer.filter_subject".to_owned() => "subject".to_owned(),
            "consumer.filter_subjects".to_owned() => "subject1,subject2".to_owned(),
            "consumer.replay_policy".to_owned() => "instant".to_owned(),
            "consumer.rate_limit".to_owned() => "100".to_owned(),
            "consumer.sample_frequency".to_owned() => "1".to_owned(),
            "consumer.max_waiting".to_owned() => "5".to_owned(),
            "consumer.max_ack_pending".to_owned() => "100".to_owned(),
            "consumer.headers_only".to_owned() => "true".to_owned(),
            "consumer.max_batch".to_owned() => "10".to_owned(),
            "consumer.max_bytes".to_owned() => "1024".to_owned(),
            "consumer.max_expires.sec".to_owned() => "24".to_owned(),
            "consumer.inactive_threshold.sec".to_owned() => "10".to_owned(),
            "consumer.num_replicas".to_owned() => "3".to_owned(),
            "consumer.memory_storage".to_owned() => "true".to_owned(),
            "consumer.backoff.sec".to_owned() => "2,10,15".to_owned(),
            "durable_consumer_name".to_owned() => "test_durable_consumer".to_owned(),
        }
    }

    #[test]
    fn test_parse_config_consumer() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "all".to_owned());

        let props = parse_nats_properties(config);

        assert_eq!(
            props.nats_properties_consumer.name,
            Some("foobar".to_owned())
        );
        assert_eq!(props.durable_consumer_name, "durable_foobar".to_owned());
        assert_eq!(
            props.nats_properties_consumer.description,
            Some("A description".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.ack_policy,
            Some("all".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.ack_wait, Some(10));
        assert_eq!(props.nats_properties_consumer.max_deliver, Some(10));
        assert_eq!(
            props.nats_properties_consumer.filter_subject,
            Some("subject".to_owned())
        );
        assert_eq!(
            props.nats_properties_consumer.filter_subjects,
            Some(vec!["subject1".to_owned(), "subject2".to_owned()])
        );
        assert_eq!(
            props.nats_properties_consumer.replay_policy,
            Some("instant".to_owned())
        );
        assert_eq!(props.nats_properties_consumer.rate_limit, Some(100));
        assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1));
        assert_eq!(props.nats_properties_consumer.max_waiting, Some(5));
        assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100));
        assert_eq!(props.nats_properties_consumer.headers_only, Some(true));
        assert_eq!(props.nats_properties_consumer.max_batch, Some(10));
        assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024));
        assert_eq!(props.nats_properties_consumer.max_expires, Some(24));
        assert_eq!(props.nats_properties_consumer.inactive_threshold, Some(10));
        assert_eq!(props.nats_properties_consumer.num_replicas, Some(3));
        assert_eq!(props.nats_properties_consumer.memory_storage, Some(true));
        assert_eq!(
            props.nats_properties_consumer.backoff,
            Some(vec![2, 10, 15])
        );
    }

    #[test]
    fn test_set_config_applies_consumer_options() {
        let props = parse_nats_properties(base_config());

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(consumer_config.name, Some("foobar".to_owned()));
        assert_eq!(
            consumer_config.description,
            Some("A description".to_owned())
        );
        assert_eq!(consumer_config.ack_wait, Duration::from_secs(10));
        assert_eq!(consumer_config.max_deliver, 10);
        assert_eq!(consumer_config.filter_subject, "subject");
        assert_eq!(
            consumer_config.filter_subjects,
            vec!["subject1".to_owned(), "subject2".to_owned()]
        );
        assert!(matches!(
            consumer_config.replay_policy,
            ReplayPolicy::Instant
        ));
        assert_eq!(consumer_config.rate_limit, 100);
        assert_eq!(consumer_config.sample_frequency, 1);
        assert_eq!(consumer_config.max_waiting, 5);
        assert_eq!(consumer_config.max_ack_pending, 100);
        assert!(consumer_config.headers_only);
        assert_eq!(consumer_config.max_batch, 10);
        assert_eq!(consumer_config.max_bytes, 1024);
        assert_eq!(consumer_config.max_expires, Duration::from_secs(24));
        assert_eq!(consumer_config.inactive_threshold, Duration::from_secs(10));
        assert_eq!(consumer_config.num_replicas, 3);
        assert!(consumer_config.memory_storage);
        assert_eq!(
            consumer_config.backoff,
            vec![
                Duration::from_secs(2),
                Duration::from_secs(10),
                Duration::from_secs(15)
            ]
        );
    }

    #[test]
    fn test_default_ack_policy_is_explicit() {
        let props = parse_nats_properties(base_config());

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(
            props.nats_properties_consumer.get_ack_policy().unwrap(),
            AckPolicy::Explicit
        );
        assert_eq!(consumer_config.ack_policy, AckPolicy::Explicit);
    }

    #[test]
    fn test_ack_policy_none_is_preserved() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "none".to_owned());
        let props = parse_nats_properties(config);

        let mut consumer_config = Config::default();
        props.set_config(&mut consumer_config).unwrap();

        assert_eq!(
            props.nats_properties_consumer.get_ack_policy().unwrap(),
            AckPolicy::None
        );
        assert_eq!(consumer_config.ack_policy, AckPolicy::None);
    }

    #[test]
    fn test_invalid_ack_policy_returns_error() {
        let mut config = base_config();
        config.insert("consumer.ack_policy".to_owned(), "invalid".to_owned());
        let props = parse_nats_properties(config);

        let err = props.nats_properties_consumer.get_ack_policy().unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid AckPolicy 'invalid'"),
            "unexpected error: {err_message}"
        );

        let mut consumer_config = Config::default();
        let err = props.set_config(&mut consumer_config).unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid AckPolicy 'invalid'"),
            "unexpected error: {err_message}"
        );
    }

    #[test]
    fn test_invalid_replay_policy_returns_error() {
        let mut config = base_config();
        config.insert("consumer.replay_policy".to_owned(), "bogus".to_owned());
        let props = parse_nats_properties(config);

        let mut consumer_config = Config::default();
        let err = props.set_config(&mut consumer_config).unwrap_err();
        let err_message = format!("{err:#}");
        assert!(
            err_message.contains("Invalid ReplayPolicy 'bogus'"),
            "unexpected error: {err_message}"
        );
    }
}