risingwave_connector/source/cdc/enumerator/
mod.rs1use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use mysql_async::Row;
25use mysql_async::prelude::*;
26use prost::Message;
27use risingwave_common::global_jvm::Jvm;
28use risingwave_common::id::SourceId;
29use risingwave_common::util::addr::HostAddr;
30use risingwave_jni_core::call_static_method;
31use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
32use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
33use thiserror_ext::AsReport;
34use tiberius::Config;
35use tokio_postgres::types::PgLsn;
36
37use crate::connector_common::{SslMode, create_pg_client};
38use crate::error::ConnectorResult;
39use crate::sink::sqlserver::SqlServerClient;
40use crate::source::cdc::external::mysql::build_mysql_connection_pool;
41use crate::source::cdc::split::parse_sql_server_lsn_str;
42use crate::source::cdc::{
43 CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
44 SqlServer, table_schema_exclude_additional_columns,
45};
46use crate::source::monitor::metrics::EnumeratorMetrics;
47use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
48
49pub const DATABASE_SERVERS_KEY: &str = "database.servers";
50
51#[derive(Debug)]
52pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
53 source_id: SourceId,
55 worker_node_addrs: Vec<HostAddr>,
56 metrics: Arc<EnumeratorMetrics>,
57 properties: Arc<BTreeMap<String, String>>,
59 _phantom: PhantomData<T>,
60}
61
62#[async_trait]
63impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
64where
65 Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
66{
67 type Properties = CdcProperties<T>;
68 type Split = DebeziumCdcSplit<T>;
69
70 async fn new(
71 props: CdcProperties<T>,
72 context: SourceEnumeratorContextRef,
73 ) -> ConnectorResult<Self> {
74 let server_addrs = props
75 .properties
76 .get(DATABASE_SERVERS_KEY)
77 .map(|s| {
78 s.split(',')
79 .map(HostAddr::from_str)
80 .collect::<Result<Vec<_>, _>>()
81 })
82 .transpose()?
83 .unwrap_or_default();
84
85 assert_eq!(
86 props.get_source_type_pb(),
87 SourceType::from(T::source_type())
88 );
89
90 let jvm = Jvm::get_or_init()?;
91 let source_id = context.info.source_id;
92
93 let source_type_pb = props.get_source_type_pb();
95
96 let properties_arc = Arc::new(props.properties);
98 let properties_arc_for_validation = properties_arc.clone();
99 let table_schema_for_validation = props.table_schema;
100
101 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
102 execute_with_jni_env(jvm, |env| {
103 let validate_source_request = ValidateSourceRequest {
104 source_id: source_id.as_raw_id() as u64,
105 source_type: source_type_pb as _,
106 properties: (*properties_arc_for_validation).clone(),
107 table_schema: Some(table_schema_exclude_additional_columns(
108 &table_schema_for_validation,
109 )),
110 is_source_job: props.is_cdc_source_job,
111 is_backfill_table: props.is_backfill_table,
112 };
113
114 let validate_source_request_bytes =
115 env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
116
117 let validate_source_response_bytes = call_static_method!(
118 env,
119 {com.risingwave.connector.source.JniSourceValidateHandler},
120 {byte[] validate(byte[] validateSourceRequestBytes)},
121 &validate_source_request_bytes
122 )?;
123
124 let validate_source_response: ValidateSourceResponse = Message::decode(
125 risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
126 .deref(),
127 )?;
128
129 if let Some(error) = validate_source_response.error {
130 return Err(
131 anyhow!(error.error_message).context("source cannot pass validation")
132 );
133 }
134
135 Ok(())
136 })
137 })
138 .await
139 .context("failed to validate source")??;
140
141 tracing::debug!("validate cdc source properties success");
142 Ok(Self {
143 source_id,
144 worker_node_addrs: server_addrs,
145 metrics: context.metrics.clone(),
146 properties: properties_arc,
147 _phantom: PhantomData,
148 })
149 }
150
151 async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
152 Ok(self.list_cdc_splits())
153 }
154
155 async fn on_tick(&mut self) -> ConnectorResult<()> {
156 self.monitor_cdc().await
157 }
158}
159
160impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
161 fn sql_server_lsn_to_i64(lsn: &str) -> Option<i64> {
162 parse_sql_server_lsn_str(lsn).map(|v| v.min(i64::MAX as u128) as i64)
163 }
164
165 async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
166 match self.query_postgres_lsns().await {
168 Ok(Some((confirmed_flush_lsn, upstream_max_lsn, slot_name))) => {
169 let labels = [&self.source_id.to_string(), &slot_name.to_owned()];
170
171 self.metrics
172 .pg_cdc_upstream_max_lsn
173 .with_guarded_label_values(&labels)
174 .set(upstream_max_lsn as i64);
175
176 if let Some(lsn) = confirmed_flush_lsn {
177 self.metrics
178 .pg_cdc_confirmed_flush_lsn
179 .with_guarded_label_values(&labels)
180 .set(lsn as i64);
181 tracing::debug!(
182 "Updated confirmed_flush_lsn for source {} slot {}: {}",
183 self.source_id,
184 slot_name,
185 lsn
186 );
187 } else {
188 tracing::warn!(
189 "confirmed_flush_lsn is NULL for source {} slot {}",
190 self.source_id,
191 slot_name
192 );
193 }
194 }
195 Ok(None) => {
196 tracing::warn!(
197 "No replication slot found when querying LSNs for source {}",
198 self.source_id
199 );
200 }
201 Err(e) => {
202 tracing::error!(
203 "Failed to query PostgreSQL LSNs for source {}: {}",
204 self.source_id,
205 e.as_report()
206 );
207 }
208 };
209 Ok(())
210 }
211
212 async fn query_postgres_lsns(&self) -> ConnectorResult<Option<(Option<u64>, u64, &str)>> {
214 let hostname = self
216 .properties
217 .get("hostname")
218 .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
219 let port = self
220 .properties
221 .get("port")
222 .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
223 let user = self
224 .properties
225 .get("username")
226 .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
227 let password = self
228 .properties
229 .get("password")
230 .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
231 let database = self
232 .properties
233 .get("database.name")
234 .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
235
236 let ssl_mode = self
238 .properties
239 .get("ssl.mode")
240 .and_then(|s| s.parse().ok())
241 .unwrap_or(SslMode::Preferred);
242 let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
243
244 let slot_name = self
245 .properties
246 .get("slot.name")
247 .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
248
249 let client = create_pg_client(
251 user,
252 password,
253 hostname,
254 port,
255 database,
256 &ssl_mode,
257 &ssl_root_cert,
258 None, )
260 .await
261 .context("Failed to create PostgreSQL client")?;
262
263 let query = "SELECT confirmed_flush_lsn, pg_current_wal_lsn() \
264 FROM pg_replication_slots WHERE slot_name = $1";
265 let row = client
266 .query_opt(query, &[&slot_name])
267 .await
268 .context("PostgreSQL query LSNs error")?;
269 match row {
270 Some(row) => {
271 let confirmed_flush_lsn: Option<PgLsn> = row.get(0);
272 let upstream_max_lsn: PgLsn = row.get(1);
273 Ok(Some((
274 confirmed_flush_lsn.map(Into::into),
275 upstream_max_lsn.into(),
276 slot_name.as_str(),
277 )))
278 }
279 None => {
280 tracing::warn!("No replication slot found with name: {}", slot_name);
281 Ok(None)
282 }
283 }
284 }
285
286 async fn query_sql_server_lsns(&self) -> ConnectorResult<Option<(String, String)>> {
288 let hostname = self
289 .properties
290 .get("hostname")
291 .ok_or_else(|| anyhow!("hostname not found in CDC properties"))?;
292 let port = self
293 .properties
294 .get("port")
295 .ok_or_else(|| anyhow!("port not found in CDC properties"))?
296 .parse::<u16>()
297 .context("failed to parse port as u16")?;
298 let username = self
299 .properties
300 .get("username")
301 .ok_or_else(|| anyhow!("username not found in CDC properties"))?;
302 let password = self
303 .properties
304 .get("password")
305 .ok_or_else(|| anyhow!("password not found in CDC properties"))?;
306 let database = self
307 .properties
308 .get("database.name")
309 .ok_or_else(|| anyhow!("database.name not found in CDC properties"))?;
310
311 let mut config = Config::new();
312 config.host(hostname);
313 config.port(port);
314 config.database(database);
315 config.authentication(tiberius::AuthMethod::sql_server(username, password));
316 config.trust_cert();
317
318 let mut client = SqlServerClient::new_with_config(config).await?;
319 let row = client
320 .inner_client
321 .simple_query(
322 "SELECT \
323 sys.fn_cdc_get_max_lsn() AS max_lsn, \
324 (SELECT MIN(sys.fn_cdc_get_min_lsn(capture_instance)) FROM cdc.change_tables) AS min_lsn"
325 .to_owned(),
326 )
327 .await?
328 .into_row()
329 .await?
330 .ok_or_else(|| anyhow!("No result returned when querying SQL Server max/min LSN"))?;
331
332 let lsn_bytes_to_hex = |bytes: &[u8]| -> ConnectorResult<String> {
333 if bytes.len() != 10 {
334 return Err(anyhow!(
335 "SQL Server LSN should be 10 bytes, got {} bytes",
336 bytes.len()
337 )
338 .into());
339 }
340 let mut hex_string = String::with_capacity(22);
341 for byte in &bytes[0..4] {
342 hex_string.push_str(&format!("{:02x}", byte));
343 }
344 hex_string.push(':');
345 for byte in &bytes[4..8] {
346 hex_string.push_str(&format!("{:02x}", byte));
347 }
348 hex_string.push(':');
349 for byte in &bytes[8..10] {
350 hex_string.push_str(&format!("{:02x}", byte));
351 }
352 Ok(hex_string)
353 };
354
355 let max_lsn = row
356 .try_get::<&[u8], usize>(0)?
357 .map(lsn_bytes_to_hex)
358 .transpose()?
359 .ok_or_else(|| anyhow!("SQL Server max_lsn is NULL"))?;
360 let min_lsn = row
361 .try_get::<&[u8], usize>(1)?
362 .map(lsn_bytes_to_hex)
363 .transpose()?
364 .ok_or_else(|| anyhow!("SQL Server min_lsn is NULL"))?;
365
366 Ok(Some((min_lsn, max_lsn)))
367 }
368
369 async fn monitor_sql_server_lsns(&mut self) -> ConnectorResult<()> {
370 match self.query_sql_server_lsns().await {
371 Ok(Some((min_lsn, max_lsn))) => {
372 let source_id = self.source_id.to_string();
373
374 if let Some(value) = Self::sql_server_lsn_to_i64(&min_lsn) {
375 self.metrics
376 .sqlserver_cdc_upstream_min_lsn
377 .with_guarded_label_values(&[&source_id])
378 .set(value);
379 }
380
381 if let Some(value) = Self::sql_server_lsn_to_i64(&max_lsn) {
382 self.metrics
383 .sqlserver_cdc_upstream_max_lsn
384 .with_guarded_label_values(&[&source_id])
385 .set(value);
386 }
387 }
388 Ok(None) => {}
389 Err(e) => {
390 tracing::error!(
391 "Failed to query SQL Server LSNs for source {}: {}",
392 self.source_id,
393 e.as_report()
394 );
395 }
396 }
397
398 Ok(())
399 }
400}
401
402pub trait ListCdcSplits {
403 type CdcSourceType: CdcSourceTypeTrait;
404 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
406}
407
408#[async_trait]
410pub trait CdcMonitor {
411 async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
412}
413
414#[async_trait]
415impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
416 default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
417 Ok(())
418 }
419}
420
421impl DebeziumSplitEnumerator<Mysql> {
422 async fn monitor_mysql_binlog_files(&mut self) -> ConnectorResult<()> {
423 let hostname = self
425 .properties
426 .get("hostname")
427 .map(|s| s.as_str())
428 .ok_or_else(|| {
429 anyhow::anyhow!("missing required property 'hostname' for MySQL CDC source")
430 })?;
431 let port = self
432 .properties
433 .get("port")
434 .map(|s| s.as_str())
435 .ok_or_else(|| {
436 anyhow::anyhow!("missing required property 'port' for MySQL CDC source")
437 })?;
438
439 match self.query_binlog_files().await {
441 Ok(binlog_files) => {
442 if let Some((oldest_file, oldest_size)) = binlog_files.first() {
443 if let Some(seq) = Self::extract_binlog_seq(oldest_file) {
445 self.metrics
446 .mysql_cdc_binlog_file_seq_min
447 .with_guarded_label_values(&[hostname, port])
448 .set(seq as i64);
449 tracing::debug!(
450 "MySQL CDC source {} ({}:{}): oldest binlog = {}, seq = {}, size = {}",
451 self.source_id,
452 hostname,
453 port,
454 oldest_file,
455 seq,
456 oldest_size
457 );
458 }
459 }
460 if let Some((newest_file, newest_size)) = binlog_files.last() {
461 if let Some(seq) = Self::extract_binlog_seq(newest_file) {
463 self.metrics
464 .mysql_cdc_binlog_file_seq_max
465 .with_guarded_label_values(&[hostname, port])
466 .set(seq as i64);
467 tracing::debug!(
468 "MySQL CDC source {} ({}:{}): newest binlog = {}, seq = {}, size = {}",
469 self.source_id,
470 hostname,
471 port,
472 newest_file,
473 seq,
474 newest_size
475 );
476 }
477 }
478 tracing::debug!(
479 "MySQL CDC source {} ({}:{}): total {} binlog files",
480 self.source_id,
481 hostname,
482 port,
483 binlog_files.len()
484 );
485 }
486 Err(e) => {
487 tracing::error!(
488 "Failed to query binlog files for MySQL CDC source {} ({}:{}): {}",
489 self.source_id,
490 hostname,
491 port,
492 e.as_report()
493 );
494 }
495 }
496 Ok(())
497 }
498
499 fn extract_binlog_seq(filename: &str) -> Option<u64> {
502 filename.rsplit('.').next()?.parse::<u64>().ok()
503 }
504
505 async fn query_binlog_files(&self) -> ConnectorResult<Vec<(String, u64)>> {
507 let hostname = self
509 .properties
510 .get("hostname")
511 .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
512 let port = self
513 .properties
514 .get("port")
515 .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?
516 .parse::<u16>()
517 .context("failed to parse port as u16")?;
518 let username = self
519 .properties
520 .get("username")
521 .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
522 let password = self
523 .properties
524 .get("password")
525 .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
526 let database = self
527 .properties
528 .get("database.name")
529 .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
530
531 let ssl_mode = self
533 .properties
534 .get("ssl.mode")
535 .and_then(|s| s.parse().ok())
536 .unwrap_or(SslMode::Preferred);
537
538 let pool =
540 build_mysql_connection_pool(hostname, port, username, password, database, ssl_mode);
541 let mut conn = pool
542 .get_conn()
543 .await
544 .context("Failed to connect to MySQL")?;
545
546 let rows: Vec<Row> = conn
551 .query("SHOW BINARY LOGS")
552 .await
553 .context("Failed to execute SHOW BINARY LOGS")?;
554 let query_result = rows
555 .into_iter()
556 .map(|mut row| -> ConnectorResult<(String, u64)> {
557 let log_name = row
558 .take_opt::<String, _>(0)
559 .transpose()
560 .context("SHOW BINARY LOGS: failed to decode Log_name")?
561 .ok_or_else(|| anyhow!("SHOW BINARY LOGS: missing Log_name column"))?;
562 let file_size = row
563 .take_opt::<u64, _>(1)
564 .transpose()
565 .context("SHOW BINARY LOGS: failed to decode File_size")?
566 .ok_or_else(|| anyhow!("SHOW BINARY LOGS: missing File_size column"))?;
567 Ok((log_name, file_size))
568 })
569 .collect::<ConnectorResult<Vec<_>>>()?;
570
571 drop(conn);
572 pool.disconnect().await.ok();
573
574 Ok(query_result)
575 }
576}
577
578impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
579 type CdcSourceType = Mysql;
580
581 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
582 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
584 self.source_id.as_raw_id(),
585 None,
586 None,
587 )]
588 }
589}
590
591#[async_trait]
592impl CdcMonitor for DebeziumSplitEnumerator<Mysql> {
593 async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
594 self.monitor_mysql_binlog_files().await?;
596 Ok(())
597 }
598}
599
600impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
601 type CdcSourceType = Postgres;
602
603 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
604 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
606 self.source_id.as_raw_id(),
607 None,
608 None,
609 )]
610 }
611}
612
613#[async_trait]
614impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
615 async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
616 self.monitor_postgres_confirmed_flush_lsn().await?;
618 Ok(())
619 }
620}
621
622impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
623 type CdcSourceType = Citus;
624
625 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
626 self.worker_node_addrs
627 .iter()
628 .enumerate()
629 .map(|(id, addr)| {
630 DebeziumCdcSplit::<Self::CdcSourceType>::new(
631 id as u32,
632 None,
633 Some(addr.to_string()),
634 )
635 })
636 .collect_vec()
637 }
638}
639impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
640 type CdcSourceType = Mongodb;
641
642 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
643 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
645 self.source_id.as_raw_id(),
646 None,
647 None,
648 )]
649 }
650}
651
652impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
653 type CdcSourceType = SqlServer;
654
655 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
656 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
657 self.source_id.as_raw_id(),
658 None,
659 None,
660 )]
661 }
662}
663
664#[async_trait]
665impl CdcMonitor for DebeziumSplitEnumerator<SqlServer> {
666 async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
667 self.monitor_sql_server_lsns().await
668 }
669}