risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use mysql_async::prelude::*;
25use prost::Message;
26use risingwave_common::global_jvm::Jvm;
27use risingwave_common::id::SourceId;
28use risingwave_common::util::addr::HostAddr;
29use risingwave_jni_core::call_static_method;
30use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
31use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
32use thiserror_ext::AsReport;
33use tokio_postgres::types::PgLsn;
34
35use crate::connector_common::{SslMode, create_pg_client};
36use crate::error::ConnectorResult;
37use crate::source::cdc::external::mysql::build_mysql_connection_pool;
38use crate::source::cdc::{
39    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
40    SqlServer, table_schema_exclude_additional_columns,
41};
42use crate::source::monitor::metrics::EnumeratorMetrics;
43use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
44
/// Property key whose value is a comma-separated list of database server addresses.
///
/// When present, each entry is parsed as a `HostAddr` while the enumerator is constructed.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
46
/// Split enumerator for Debezium-based CDC sources, parameterized by the upstream
/// source type `T`.
#[derive(Debug)]
pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
    /// The `source_id` in the catalog
    source_id: SourceId,
    /// Addresses parsed from the `database.servers` property; empty when the property
    /// is absent.
    worker_node_addrs: Vec<HostAddr>,
    /// Metrics handle used to publish the CDC monitoring gauges.
    metrics: Arc<EnumeratorMetrics>,
    /// Properties specified in the WITH clause by user for database connection
    properties: Arc<BTreeMap<String, String>>,
    /// Ties the enumerator to its CDC source type without storing a `T`.
    _phantom: PhantomData<T>,
}
57
58#[async_trait]
59impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
60where
61    Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
62{
63    type Properties = CdcProperties<T>;
64    type Split = DebeziumCdcSplit<T>;
65
66    async fn new(
67        props: CdcProperties<T>,
68        context: SourceEnumeratorContextRef,
69    ) -> ConnectorResult<Self> {
70        let server_addrs = props
71            .properties
72            .get(DATABASE_SERVERS_KEY)
73            .map(|s| {
74                s.split(',')
75                    .map(HostAddr::from_str)
76                    .collect::<Result<Vec<_>, _>>()
77            })
78            .transpose()?
79            .unwrap_or_default();
80
81        assert_eq!(
82            props.get_source_type_pb(),
83            SourceType::from(T::source_type())
84        );
85
86        let jvm = Jvm::get_or_init()?;
87        let source_id = context.info.source_id;
88
89        // Extract fields before moving props
90        let source_type_pb = props.get_source_type_pb();
91
92        // Create Arc once and share it
93        let properties_arc = Arc::new(props.properties);
94        let properties_arc_for_validation = properties_arc.clone();
95        let table_schema_for_validation = props.table_schema;
96
97        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
98            execute_with_jni_env(jvm, |env| {
99                let validate_source_request = ValidateSourceRequest {
100                    source_id: source_id.as_raw_id() as u64,
101                    source_type: source_type_pb as _,
102                    properties: (*properties_arc_for_validation).clone(),
103                    table_schema: Some(table_schema_exclude_additional_columns(
104                        &table_schema_for_validation,
105                    )),
106                    is_source_job: props.is_cdc_source_job,
107                    is_backfill_table: props.is_backfill_table,
108                };
109
110                let validate_source_request_bytes =
111                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
112
113                let validate_source_response_bytes = call_static_method!(
114                    env,
115                    {com.risingwave.connector.source.JniSourceValidateHandler},
116                    {byte[] validate(byte[] validateSourceRequestBytes)},
117                    &validate_source_request_bytes
118                )?;
119
120                let validate_source_response: ValidateSourceResponse = Message::decode(
121                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
122                        .deref(),
123                )?;
124
125                if let Some(error) = validate_source_response.error {
126                    return Err(
127                        anyhow!(error.error_message).context("source cannot pass validation")
128                    );
129                }
130
131                Ok(())
132            })
133        })
134        .await
135        .context("failed to validate source")??;
136
137        tracing::debug!("validate cdc source properties success");
138        Ok(Self {
139            source_id,
140            worker_node_addrs: server_addrs,
141            metrics: context.metrics.clone(),
142            properties: properties_arc,
143            _phantom: PhantomData,
144        })
145    }
146
147    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
148        Ok(self.list_cdc_splits())
149    }
150
151    async fn on_tick(&mut self) -> ConnectorResult<()> {
152        self.monitor_cdc().await
153    }
154}
155
156impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
157    async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
158        // Query confirmed flush LSN and update metrics
159        match self.query_confirmed_flush_lsn().await {
160            Ok(Some((lsn, slot_name))) => {
161                // Update metrics
162                self.metrics
163                    .pg_cdc_confirmed_flush_lsn
164                    .with_guarded_label_values(&[
165                        &self.source_id.to_string(),
166                        &slot_name.to_owned(),
167                    ])
168                    .set(lsn as i64);
169                tracing::debug!(
170                    "Updated confirm_flush_lsn for source {} slot {}: {}",
171                    self.source_id,
172                    slot_name,
173                    lsn
174                );
175            }
176            Ok(None) => {
177                tracing::warn!("No confirmed_flush_lsn found for source {}", self.source_id);
178            }
179            Err(e) => {
180                tracing::error!(
181                    "Failed to query confirmed_flush_lsn for source {}: {}",
182                    self.source_id,
183                    e.as_report()
184                );
185            }
186        }
187        Ok(())
188    }
189
190    /// Query confirmed flush LSN from PostgreSQL, return the slot name and the confirmed flush LSN.
191    async fn query_confirmed_flush_lsn(&self) -> ConnectorResult<Option<(u64, &str)>> {
192        // Extract connection parameters from CDC properties
193        let hostname = self
194            .properties
195            .get("hostname")
196            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
197        let port = self
198            .properties
199            .get("port")
200            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
201        let user = self
202            .properties
203            .get("username")
204            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
205        let password = self
206            .properties
207            .get("password")
208            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
209        let database = self
210            .properties
211            .get("database.name")
212            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
213
214        // Get SSL mode from properties, default to Preferred if not specified
215        let ssl_mode = self
216            .properties
217            .get("ssl.mode")
218            .and_then(|s| serde_json::from_value(serde_json::Value::String(s.clone())).ok())
219            .unwrap_or(SslMode::Preferred);
220        let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
221
222        let slot_name = self
223            .properties
224            .get("slot.name")
225            .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
226
227        // Create PostgreSQL client
228        let client = create_pg_client(
229            user,
230            password,
231            hostname,
232            port,
233            database,
234            &ssl_mode,
235            &ssl_root_cert,
236        )
237        .await
238        .context("Failed to create PostgreSQL client")?;
239
240        let query = "SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = $1";
241        let row = client
242            .query_opt(query, &[&slot_name])
243            .await
244            .context("PostgreSQL query confirmed flush lsn error")?;
245
246        match row {
247            Some(row) => {
248                let confirm_flush_lsn: Option<PgLsn> = row.get(0);
249                if let Some(lsn) = confirm_flush_lsn {
250                    Ok(Some((lsn.into(), slot_name.as_str())))
251                } else {
252                    Ok(None)
253                }
254            }
255            None => {
256                tracing::warn!("No replication slot found with name: {}", slot_name);
257                Ok(None)
258            }
259        }
260    }
261}
262
/// Produces the CDC splits of an enumerator for one concrete source type.
pub trait ListCdcSplits {
    /// The CDC source type the produced splits belong to.
    type CdcSourceType: CdcSourceTypeTrait;
    /// Generates the splits of the source.
    ///
    /// Most implementations in this file return exactly one split; the Citus
    /// implementation returns one split per worker node address.
    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
}
268
/// Trait for CDC-specific monitoring behavior
#[async_trait]
pub trait CdcMonitor {
    /// Runs one round of source-specific monitoring; invoked from
    /// `SplitEnumerator::on_tick`.
    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
}
274
/// Fallback monitoring for every CDC source type: does nothing.
///
/// The method is marked `default` so that the MySQL and PostgreSQL impls in this
/// file can override it.
#[async_trait]
impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
        Ok(())
    }
}
281
282impl DebeziumSplitEnumerator<Mysql> {
283    async fn monitor_mysql_binlog_files(&mut self) -> ConnectorResult<()> {
284        // Get hostname and port for metrics labels
285        let hostname = self
286            .properties
287            .get("hostname")
288            .map(|s| s.as_str())
289            .ok_or_else(|| {
290                anyhow::anyhow!("missing required property 'hostname' for MySQL CDC source")
291            })?;
292        let port = self
293            .properties
294            .get("port")
295            .map(|s| s.as_str())
296            .ok_or_else(|| {
297                anyhow::anyhow!("missing required property 'port' for MySQL CDC source")
298            })?;
299
300        // Query binlog files and update metrics
301        match self.query_binlog_files().await {
302            Ok(binlog_files) => {
303                if let Some((oldest_file, oldest_size)) = binlog_files.first() {
304                    // Extract sequence number from filename (e.g., "binlog.000001" -> 1)
305                    if let Some(seq) = Self::extract_binlog_seq(oldest_file) {
306                        self.metrics
307                            .mysql_cdc_binlog_file_seq_min
308                            .with_guarded_label_values(&[hostname, port])
309                            .set(seq as i64);
310                        tracing::debug!(
311                            "MySQL CDC source {} ({}:{}): oldest binlog = {}, seq = {}, size = {}",
312                            self.source_id,
313                            hostname,
314                            port,
315                            oldest_file,
316                            seq,
317                            oldest_size
318                        );
319                    }
320                }
321                if let Some((newest_file, newest_size)) = binlog_files.last() {
322                    // Extract sequence number from filename
323                    if let Some(seq) = Self::extract_binlog_seq(newest_file) {
324                        self.metrics
325                            .mysql_cdc_binlog_file_seq_max
326                            .with_guarded_label_values(&[hostname, port])
327                            .set(seq as i64);
328                        tracing::debug!(
329                            "MySQL CDC source {} ({}:{}): newest binlog = {}, seq = {}, size = {}",
330                            self.source_id,
331                            hostname,
332                            port,
333                            newest_file,
334                            seq,
335                            newest_size
336                        );
337                    }
338                }
339                tracing::debug!(
340                    "MySQL CDC source {} ({}:{}): total {} binlog files",
341                    self.source_id,
342                    hostname,
343                    port,
344                    binlog_files.len()
345                );
346            }
347            Err(e) => {
348                tracing::error!(
349                    "Failed to query binlog files for MySQL CDC source {} ({}:{}): {}",
350                    self.source_id,
351                    hostname,
352                    port,
353                    e.as_report()
354                );
355            }
356        }
357        Ok(())
358    }
359
360    /// Extract sequence number from binlog filename
361    /// e.g., "binlog.000001" -> Some(1), "mysql-bin.000123" -> Some(123)
362    fn extract_binlog_seq(filename: &str) -> Option<u64> {
363        filename.rsplit('.').next()?.parse::<u64>().ok()
364    }
365
366    /// Query binlog files from MySQL, returns Vec<(filename, size)>
367    async fn query_binlog_files(&self) -> ConnectorResult<Vec<(String, u64)>> {
368        // Extract connection parameters from CDC properties
369        let hostname = self
370            .properties
371            .get("hostname")
372            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
373        let port = self
374            .properties
375            .get("port")
376            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?
377            .parse::<u16>()
378            .context("failed to parse port as u16")?;
379        let username = self
380            .properties
381            .get("username")
382            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
383        let password = self
384            .properties
385            .get("password")
386            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
387        let database = self
388            .properties
389            .get("database.name")
390            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
391
392        // Get SSL mode configuration (default to Disabled if not specified)
393        let ssl_mode = self
394            .properties
395            .get("ssl.mode")
396            .and_then(|s| serde_json::from_value(serde_json::Value::String(s.clone())).ok())
397            .unwrap_or(SslMode::Preferred);
398
399        // Build MySQL connection pool with proper SSL configuration
400        let pool =
401            build_mysql_connection_pool(hostname, port, username, password, database, ssl_mode);
402        let mut conn = pool
403            .get_conn()
404            .await
405            .context("Failed to connect to MySQL")?;
406
407        // Query binlog files using SHOW BINARY LOGS
408        // Note: MySQL 8.0+ returns 3 columns: Log_name, File_size, Encrypted
409        let query_result: Vec<(String, u64)> = conn
410            .query_map(
411                "SHOW BINARY LOGS",
412                |(log_name, file_size, _encrypted): (String, u64, String)| (log_name, file_size),
413            )
414            .await
415            .context("Failed to execute SHOW BINARY LOGS")?;
416
417        drop(conn);
418        pool.disconnect().await.ok();
419
420        Ok(query_result)
421    }
422}
423
424impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
425    type CdcSourceType = Mysql;
426
427    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
428        // CDC source only supports single split
429        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
430            self.source_id.as_raw_id(),
431            None,
432            None,
433        )]
434    }
435}
436
437#[async_trait]
438impl CdcMonitor for DebeziumSplitEnumerator<Mysql> {
439    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
440        // For MySQL CDC, query the upstream MySQL binlog files and monitor them.
441        self.monitor_mysql_binlog_files().await?;
442        Ok(())
443    }
444}
445
446impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
447    type CdcSourceType = Postgres;
448
449    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
450        // CDC source only supports single split
451        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
452            self.source_id.as_raw_id(),
453            None,
454            None,
455        )]
456    }
457}
458
459#[async_trait]
460impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
461    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
462        // For PostgreSQL CDC, query the upstream Postgres confirmed flush lsn and monitor it.
463        self.monitor_postgres_confirmed_flush_lsn().await?;
464        Ok(())
465    }
466}
467
468impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
469    type CdcSourceType = Citus;
470
471    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
472        self.worker_node_addrs
473            .iter()
474            .enumerate()
475            .map(|(id, addr)| {
476                DebeziumCdcSplit::<Self::CdcSourceType>::new(
477                    id as u32,
478                    None,
479                    Some(addr.to_string()),
480                )
481            })
482            .collect_vec()
483    }
484}
485impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
486    type CdcSourceType = Mongodb;
487
488    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
489        // CDC source only supports single split
490        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
491            self.source_id.as_raw_id(),
492            None,
493            None,
494        )]
495    }
496}
497
498impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
499    type CdcSourceType = SqlServer;
500
501    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
502        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
503            self.source_id.as_raw_id(),
504            None,
505            None,
506        )]
507    }
508}