risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use mysql_async::Row;
25use mysql_async::prelude::*;
26use prost::Message;
27use risingwave_common::global_jvm::Jvm;
28use risingwave_common::id::SourceId;
29use risingwave_common::util::addr::HostAddr;
30use risingwave_jni_core::call_static_method;
31use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
32use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
33use thiserror_ext::AsReport;
34use tiberius::Config;
35use tokio_postgres::types::PgLsn;
36
37use crate::connector_common::{SslMode, create_pg_client};
38use crate::error::ConnectorResult;
39use crate::sink::sqlserver::SqlServerClient;
40use crate::source::cdc::external::mysql::build_mysql_connection_pool;
41use crate::source::cdc::split::parse_sql_server_lsn_str;
42use crate::source::cdc::{
43    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
44    SqlServer, table_schema_exclude_additional_columns,
45};
46use crate::source::monitor::metrics::EnumeratorMetrics;
47use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
48
/// Property key whose value is a comma-separated list of server addresses.
///
/// `DebeziumSplitEnumerator::new` splits the value on `,` and parses each entry as a
/// `HostAddr`; a missing key yields an empty address list.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
50
51#[derive(Debug)]
52pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
53    /// The `source_id` in the catalog
54    source_id: SourceId,
55    worker_node_addrs: Vec<HostAddr>,
56    metrics: Arc<EnumeratorMetrics>,
57    /// Properties specified in the WITH clause by user for database connection
58    properties: Arc<BTreeMap<String, String>>,
59    _phantom: PhantomData<T>,
60}
61
62#[async_trait]
63impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
64where
65    Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
66{
67    type Properties = CdcProperties<T>;
68    type Split = DebeziumCdcSplit<T>;
69
70    async fn new(
71        props: CdcProperties<T>,
72        context: SourceEnumeratorContextRef,
73    ) -> ConnectorResult<Self> {
74        let server_addrs = props
75            .properties
76            .get(DATABASE_SERVERS_KEY)
77            .map(|s| {
78                s.split(',')
79                    .map(HostAddr::from_str)
80                    .collect::<Result<Vec<_>, _>>()
81            })
82            .transpose()?
83            .unwrap_or_default();
84
85        assert_eq!(
86            props.get_source_type_pb(),
87            SourceType::from(T::source_type())
88        );
89
90        let jvm = Jvm::get_or_init()?;
91        let source_id = context.info.source_id;
92
93        // Extract fields before moving props
94        let source_type_pb = props.get_source_type_pb();
95
96        // Create Arc once and share it
97        let properties_arc = Arc::new(props.properties);
98        let properties_arc_for_validation = properties_arc.clone();
99        let table_schema_for_validation = props.table_schema;
100
101        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
102            execute_with_jni_env(jvm, |env| {
103                let validate_source_request = ValidateSourceRequest {
104                    source_id: source_id.as_raw_id() as u64,
105                    source_type: source_type_pb as _,
106                    properties: (*properties_arc_for_validation).clone(),
107                    table_schema: Some(table_schema_exclude_additional_columns(
108                        &table_schema_for_validation,
109                    )),
110                    is_source_job: props.is_cdc_source_job,
111                    is_backfill_table: props.is_backfill_table,
112                };
113
114                let validate_source_request_bytes =
115                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
116
117                let validate_source_response_bytes = call_static_method!(
118                    env,
119                    {com.risingwave.connector.source.JniSourceValidateHandler},
120                    {byte[] validate(byte[] validateSourceRequestBytes)},
121                    &validate_source_request_bytes
122                )?;
123
124                let validate_source_response: ValidateSourceResponse = Message::decode(
125                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
126                        .deref(),
127                )?;
128
129                if let Some(error) = validate_source_response.error {
130                    return Err(
131                        anyhow!(error.error_message).context("source cannot pass validation")
132                    );
133                }
134
135                Ok(())
136            })
137        })
138        .await
139        .context("failed to validate source")??;
140
141        tracing::debug!("validate cdc source properties success");
142        Ok(Self {
143            source_id,
144            worker_node_addrs: server_addrs,
145            metrics: context.metrics.clone(),
146            properties: properties_arc,
147            _phantom: PhantomData,
148        })
149    }
150
151    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
152        Ok(self.list_cdc_splits())
153    }
154
155    async fn on_tick(&mut self) -> ConnectorResult<()> {
156        self.monitor_cdc().await
157    }
158}
159
160impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
161    fn sql_server_lsn_to_i64(lsn: &str) -> Option<i64> {
162        parse_sql_server_lsn_str(lsn).map(|v| v.min(i64::MAX as u128) as i64)
163    }
164
165    async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
166        // Query upstream LSNs and update metrics.
167        match self.query_postgres_lsns().await {
168            Ok(Some((confirmed_flush_lsn, upstream_max_lsn, slot_name))) => {
169                let labels = [&self.source_id.to_string(), &slot_name.to_owned()];
170
171                self.metrics
172                    .pg_cdc_upstream_max_lsn
173                    .with_guarded_label_values(&labels)
174                    .set(upstream_max_lsn as i64);
175
176                if let Some(lsn) = confirmed_flush_lsn {
177                    self.metrics
178                        .pg_cdc_confirmed_flush_lsn
179                        .with_guarded_label_values(&labels)
180                        .set(lsn as i64);
181                    tracing::debug!(
182                        "Updated confirmed_flush_lsn for source {} slot {}: {}",
183                        self.source_id,
184                        slot_name,
185                        lsn
186                    );
187                } else {
188                    tracing::warn!(
189                        "confirmed_flush_lsn is NULL for source {} slot {}",
190                        self.source_id,
191                        slot_name
192                    );
193                }
194            }
195            Ok(None) => {
196                tracing::warn!(
197                    "No replication slot found when querying LSNs for source {}",
198                    self.source_id
199                );
200            }
201            Err(e) => {
202                tracing::error!(
203                    "Failed to query PostgreSQL LSNs for source {}: {}",
204                    self.source_id,
205                    e.as_report()
206                );
207            }
208        };
209        Ok(())
210    }
211
212    /// Query LSNs from PostgreSQL, return (`confirmed_flush_lsn`, `upstream_max_lsn`, `slot_name`).
213    async fn query_postgres_lsns(&self) -> ConnectorResult<Option<(Option<u64>, u64, &str)>> {
214        // Extract connection parameters from CDC properties
215        let hostname = self
216            .properties
217            .get("hostname")
218            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
219        let port = self
220            .properties
221            .get("port")
222            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
223        let user = self
224            .properties
225            .get("username")
226            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
227        let password = self
228            .properties
229            .get("password")
230            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
231        let database = self
232            .properties
233            .get("database.name")
234            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
235
236        // Get SSL mode from properties, default to Preferred if not specified
237        let ssl_mode = self
238            .properties
239            .get("ssl.mode")
240            .and_then(|s| s.parse().ok())
241            .unwrap_or(SslMode::Preferred);
242        let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
243
244        let slot_name = self
245            .properties
246            .get("slot.name")
247            .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
248
249        // Create PostgreSQL client
250        let client = create_pg_client(
251            user,
252            password,
253            hostname,
254            port,
255            database,
256            &ssl_mode,
257            &ssl_root_cert,
258            None, // No TCP keepalive for CDC enumerator
259        )
260        .await
261        .context("Failed to create PostgreSQL client")?;
262
263        let query = "SELECT confirmed_flush_lsn, pg_current_wal_lsn() \
264            FROM pg_replication_slots WHERE slot_name = $1";
265        let row = client
266            .query_opt(query, &[&slot_name])
267            .await
268            .context("PostgreSQL query LSNs error")?;
269        match row {
270            Some(row) => {
271                let confirmed_flush_lsn: Option<PgLsn> = row.get(0);
272                let upstream_max_lsn: PgLsn = row.get(1);
273                Ok(Some((
274                    confirmed_flush_lsn.map(Into::into),
275                    upstream_max_lsn.into(),
276                    slot_name.as_str(),
277                )))
278            }
279            None => {
280                tracing::warn!("No replication slot found with name: {}", slot_name);
281                Ok(None)
282            }
283        }
284    }
285
286    /// Query min/max LSNs from SQL Server CDC.
287    async fn query_sql_server_lsns(&self) -> ConnectorResult<Option<(String, String)>> {
288        let hostname = self
289            .properties
290            .get("hostname")
291            .ok_or_else(|| anyhow!("hostname not found in CDC properties"))?;
292        let port = self
293            .properties
294            .get("port")
295            .ok_or_else(|| anyhow!("port not found in CDC properties"))?
296            .parse::<u16>()
297            .context("failed to parse port as u16")?;
298        let username = self
299            .properties
300            .get("username")
301            .ok_or_else(|| anyhow!("username not found in CDC properties"))?;
302        let password = self
303            .properties
304            .get("password")
305            .ok_or_else(|| anyhow!("password not found in CDC properties"))?;
306        let database = self
307            .properties
308            .get("database.name")
309            .ok_or_else(|| anyhow!("database.name not found in CDC properties"))?;
310
311        let mut config = Config::new();
312        config.host(hostname);
313        config.port(port);
314        config.database(database);
315        config.authentication(tiberius::AuthMethod::sql_server(username, password));
316        config.trust_cert();
317
318        let mut client = SqlServerClient::new_with_config(config).await?;
319        let row = client
320            .inner_client
321            .simple_query(
322                "SELECT \
323                    sys.fn_cdc_get_max_lsn() AS max_lsn, \
324                    (SELECT MIN(sys.fn_cdc_get_min_lsn(capture_instance)) FROM cdc.change_tables) AS min_lsn"
325                    .to_owned(),
326            )
327            .await?
328            .into_row()
329            .await?
330            .ok_or_else(|| anyhow!("No result returned when querying SQL Server max/min LSN"))?;
331
332        let lsn_bytes_to_hex = |bytes: &[u8]| -> ConnectorResult<String> {
333            if bytes.len() != 10 {
334                return Err(anyhow!(
335                    "SQL Server LSN should be 10 bytes, got {} bytes",
336                    bytes.len()
337                )
338                .into());
339            }
340            let mut hex_string = String::with_capacity(22);
341            for byte in &bytes[0..4] {
342                hex_string.push_str(&format!("{:02x}", byte));
343            }
344            hex_string.push(':');
345            for byte in &bytes[4..8] {
346                hex_string.push_str(&format!("{:02x}", byte));
347            }
348            hex_string.push(':');
349            for byte in &bytes[8..10] {
350                hex_string.push_str(&format!("{:02x}", byte));
351            }
352            Ok(hex_string)
353        };
354
355        let max_lsn = row
356            .try_get::<&[u8], usize>(0)?
357            .map(lsn_bytes_to_hex)
358            .transpose()?
359            .ok_or_else(|| anyhow!("SQL Server max_lsn is NULL"))?;
360        let min_lsn = row
361            .try_get::<&[u8], usize>(1)?
362            .map(lsn_bytes_to_hex)
363            .transpose()?
364            .ok_or_else(|| anyhow!("SQL Server min_lsn is NULL"))?;
365
366        Ok(Some((min_lsn, max_lsn)))
367    }
368
369    async fn monitor_sql_server_lsns(&mut self) -> ConnectorResult<()> {
370        match self.query_sql_server_lsns().await {
371            Ok(Some((min_lsn, max_lsn))) => {
372                let source_id = self.source_id.to_string();
373
374                if let Some(value) = Self::sql_server_lsn_to_i64(&min_lsn) {
375                    self.metrics
376                        .sqlserver_cdc_upstream_min_lsn
377                        .with_guarded_label_values(&[&source_id])
378                        .set(value);
379                }
380
381                if let Some(value) = Self::sql_server_lsn_to_i64(&max_lsn) {
382                    self.metrics
383                        .sqlserver_cdc_upstream_max_lsn
384                        .with_guarded_label_values(&[&source_id])
385                        .set(value);
386                }
387            }
388            Ok(None) => {}
389            Err(e) => {
390                tracing::error!(
391                    "Failed to query SQL Server LSNs for source {}: {}",
392                    self.source_id,
393                    e.as_report()
394                );
395            }
396        }
397
398        Ok(())
399    }
400}
401
/// Source-specific split listing for a CDC enumerator.
pub trait ListCdcSplits {
    /// The CDC source type the produced splits are parameterized by.
    type CdcSourceType: CdcSourceTypeTrait;
    /// Generates a single split for shared source.
    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
}
407
/// Trait for CDC-specific monitoring behavior
///
/// `DebeziumSplitEnumerator` calls [`CdcMonitor::monitor_cdc`] from its `on_tick`.
#[async_trait]
pub trait CdcMonitor {
    /// Runs one round of monitoring for the upstream database.
    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
}
413
414#[async_trait]
415impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
416    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
417        Ok(())
418    }
419}
420
421impl DebeziumSplitEnumerator<Mysql> {
422    async fn monitor_mysql_binlog_files(&mut self) -> ConnectorResult<()> {
423        // Get hostname and port for metrics labels
424        let hostname = self
425            .properties
426            .get("hostname")
427            .map(|s| s.as_str())
428            .ok_or_else(|| {
429                anyhow::anyhow!("missing required property 'hostname' for MySQL CDC source")
430            })?;
431        let port = self
432            .properties
433            .get("port")
434            .map(|s| s.as_str())
435            .ok_or_else(|| {
436                anyhow::anyhow!("missing required property 'port' for MySQL CDC source")
437            })?;
438
439        // Query binlog files and update metrics
440        match self.query_binlog_files().await {
441            Ok(binlog_files) => {
442                if let Some((oldest_file, oldest_size)) = binlog_files.first() {
443                    // Extract sequence number from filename (e.g., "binlog.000001" -> 1)
444                    if let Some(seq) = Self::extract_binlog_seq(oldest_file) {
445                        self.metrics
446                            .mysql_cdc_binlog_file_seq_min
447                            .with_guarded_label_values(&[hostname, port])
448                            .set(seq as i64);
449                        tracing::debug!(
450                            "MySQL CDC source {} ({}:{}): oldest binlog = {}, seq = {}, size = {}",
451                            self.source_id,
452                            hostname,
453                            port,
454                            oldest_file,
455                            seq,
456                            oldest_size
457                        );
458                    }
459                }
460                if let Some((newest_file, newest_size)) = binlog_files.last() {
461                    // Extract sequence number from filename
462                    if let Some(seq) = Self::extract_binlog_seq(newest_file) {
463                        self.metrics
464                            .mysql_cdc_binlog_file_seq_max
465                            .with_guarded_label_values(&[hostname, port])
466                            .set(seq as i64);
467                        tracing::debug!(
468                            "MySQL CDC source {} ({}:{}): newest binlog = {}, seq = {}, size = {}",
469                            self.source_id,
470                            hostname,
471                            port,
472                            newest_file,
473                            seq,
474                            newest_size
475                        );
476                    }
477                }
478                tracing::debug!(
479                    "MySQL CDC source {} ({}:{}): total {} binlog files",
480                    self.source_id,
481                    hostname,
482                    port,
483                    binlog_files.len()
484                );
485            }
486            Err(e) => {
487                tracing::error!(
488                    "Failed to query binlog files for MySQL CDC source {} ({}:{}): {}",
489                    self.source_id,
490                    hostname,
491                    port,
492                    e.as_report()
493                );
494            }
495        }
496        Ok(())
497    }
498
499    /// Extract sequence number from binlog filename
500    /// e.g., "binlog.000001" -> Some(1), "mysql-bin.000123" -> Some(123)
501    fn extract_binlog_seq(filename: &str) -> Option<u64> {
502        filename.rsplit('.').next()?.parse::<u64>().ok()
503    }
504
505    /// Query binlog files from MySQL, returns Vec<(filename, size)>
506    async fn query_binlog_files(&self) -> ConnectorResult<Vec<(String, u64)>> {
507        // Extract connection parameters from CDC properties
508        let hostname = self
509            .properties
510            .get("hostname")
511            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
512        let port = self
513            .properties
514            .get("port")
515            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?
516            .parse::<u16>()
517            .context("failed to parse port as u16")?;
518        let username = self
519            .properties
520            .get("username")
521            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
522        let password = self
523            .properties
524            .get("password")
525            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
526        let database = self
527            .properties
528            .get("database.name")
529            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
530
531        // Get SSL mode configuration (default to Disabled if not specified)
532        let ssl_mode = self
533            .properties
534            .get("ssl.mode")
535            .and_then(|s| s.parse().ok())
536            .unwrap_or(SslMode::Preferred);
537
538        // Build MySQL connection pool with proper SSL configuration
539        let pool =
540            build_mysql_connection_pool(hostname, port, username, password, database, ssl_mode);
541        let mut conn = pool
542            .get_conn()
543            .await
544            .context("Failed to connect to MySQL")?;
545
546        // Query binlog files using SHOW BINARY LOGS.
547        // MySQL 8.0+ may return 3 columns (Log_name, File_size, Encrypted), while some variants
548        // only return the first 2. Decode the row manually so we don't panic on column-count
549        // differences.
550        let rows: Vec<Row> = conn
551            .query("SHOW BINARY LOGS")
552            .await
553            .context("Failed to execute SHOW BINARY LOGS")?;
554        let query_result = rows
555            .into_iter()
556            .map(|mut row| -> ConnectorResult<(String, u64)> {
557                let log_name = row
558                    .take_opt::<String, _>(0)
559                    .transpose()
560                    .context("SHOW BINARY LOGS: failed to decode Log_name")?
561                    .ok_or_else(|| anyhow!("SHOW BINARY LOGS: missing Log_name column"))?;
562                let file_size = row
563                    .take_opt::<u64, _>(1)
564                    .transpose()
565                    .context("SHOW BINARY LOGS: failed to decode File_size")?
566                    .ok_or_else(|| anyhow!("SHOW BINARY LOGS: missing File_size column"))?;
567                Ok((log_name, file_size))
568            })
569            .collect::<ConnectorResult<Vec<_>>>()?;
570
571        drop(conn);
572        pool.disconnect().await.ok();
573
574        Ok(query_result)
575    }
576}
577
578impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
579    type CdcSourceType = Mysql;
580
581    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
582        // CDC source only supports single split
583        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
584            self.source_id.as_raw_id(),
585            None,
586            None,
587        )]
588    }
589}
590
591#[async_trait]
592impl CdcMonitor for DebeziumSplitEnumerator<Mysql> {
593    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
594        // For MySQL CDC, query the upstream MySQL binlog files and monitor them.
595        self.monitor_mysql_binlog_files().await?;
596        Ok(())
597    }
598}
599
600impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
601    type CdcSourceType = Postgres;
602
603    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
604        // CDC source only supports single split
605        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
606            self.source_id.as_raw_id(),
607            None,
608            None,
609        )]
610    }
611}
612
613#[async_trait]
614impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
615    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
616        // For PostgreSQL CDC, query the upstream Postgres confirmed flush lsn and monitor it.
617        self.monitor_postgres_confirmed_flush_lsn().await?;
618        Ok(())
619    }
620}
621
622impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
623    type CdcSourceType = Citus;
624
625    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
626        self.worker_node_addrs
627            .iter()
628            .enumerate()
629            .map(|(id, addr)| {
630                DebeziumCdcSplit::<Self::CdcSourceType>::new(
631                    id as u32,
632                    None,
633                    Some(addr.to_string()),
634                )
635            })
636            .collect_vec()
637    }
638}
639impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
640    type CdcSourceType = Mongodb;
641
642    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
643        // CDC source only supports single split
644        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
645            self.source_id.as_raw_id(),
646            None,
647            None,
648        )]
649    }
650}
651
652impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
653    type CdcSourceType = SqlServer;
654
655    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
656        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
657            self.source_id.as_raw_id(),
658            None,
659            None,
660        )]
661    }
662}
663
664#[async_trait]
665impl CdcMonitor for DebeziumSplitEnumerator<SqlServer> {
666    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
667        self.monitor_sql_server_lsns().await
668    }
669}