risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use mysql_async::prelude::*;
25use prost::Message;
26use risingwave_common::global_jvm::Jvm;
27use risingwave_common::id::SourceId;
28use risingwave_common::util::addr::HostAddr;
29use risingwave_jni_core::call_static_method;
30use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
31use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
32use thiserror_ext::AsReport;
33use tokio_postgres::types::PgLsn;
34
35use crate::connector_common::{SslMode, create_pg_client};
36use crate::error::ConnectorResult;
37use crate::source::cdc::external::mysql::build_mysql_connection_pool;
38use crate::source::cdc::{
39    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
40    SqlServer, table_schema_exclude_additional_columns,
41};
42use crate::source::monitor::metrics::EnumeratorMetrics;
43use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
44
/// Property key whose value is a comma-separated list of server addresses; each entry is
/// parsed into a `HostAddr` when the enumerator is constructed.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
46
/// Split enumerator for Debezium-based CDC sources, parameterized by the CDC source type `T`.
#[derive(Debug)]
pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
    /// The `source_id` in the catalog
    source_id: SourceId,
    /// Addresses parsed from the `database.servers` property (empty when the property is
    /// absent); the Citus implementation emits one split per address.
    worker_node_addrs: Vec<HostAddr>,
    /// Metrics handle taken from the enumerator context.
    metrics: Arc<EnumeratorMetrics>,
    /// Properties specified in the WITH clause by user for database connection
    properties: Arc<BTreeMap<String, String>>,
    /// Marker tying the enumerator to `T` without storing a value of that type.
    _phantom: PhantomData<T>,
}
57
58#[async_trait]
59impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
60where
61    Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
62{
63    type Properties = CdcProperties<T>;
64    type Split = DebeziumCdcSplit<T>;
65
66    async fn new(
67        props: CdcProperties<T>,
68        context: SourceEnumeratorContextRef,
69    ) -> ConnectorResult<Self> {
70        let server_addrs = props
71            .properties
72            .get(DATABASE_SERVERS_KEY)
73            .map(|s| {
74                s.split(',')
75                    .map(HostAddr::from_str)
76                    .collect::<Result<Vec<_>, _>>()
77            })
78            .transpose()?
79            .unwrap_or_default();
80
81        assert_eq!(
82            props.get_source_type_pb(),
83            SourceType::from(T::source_type())
84        );
85
86        let jvm = Jvm::get_or_init()?;
87        let source_id = context.info.source_id;
88
89        // Extract fields before moving props
90        let source_type_pb = props.get_source_type_pb();
91
92        // Create Arc once and share it
93        let properties_arc = Arc::new(props.properties);
94        let properties_arc_for_validation = properties_arc.clone();
95        let table_schema_for_validation = props.table_schema;
96
97        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
98            execute_with_jni_env(jvm, |env| {
99                let validate_source_request = ValidateSourceRequest {
100                    source_id: source_id.as_raw_id() as u64,
101                    source_type: source_type_pb as _,
102                    properties: (*properties_arc_for_validation).clone(),
103                    table_schema: Some(table_schema_exclude_additional_columns(
104                        &table_schema_for_validation,
105                    )),
106                    is_source_job: props.is_cdc_source_job,
107                    is_backfill_table: props.is_backfill_table,
108                };
109
110                let validate_source_request_bytes =
111                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
112
113                let validate_source_response_bytes = call_static_method!(
114                    env,
115                    {com.risingwave.connector.source.JniSourceValidateHandler},
116                    {byte[] validate(byte[] validateSourceRequestBytes)},
117                    &validate_source_request_bytes
118                )?;
119
120                let validate_source_response: ValidateSourceResponse = Message::decode(
121                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
122                        .deref(),
123                )?;
124
125                if let Some(error) = validate_source_response.error {
126                    return Err(
127                        anyhow!(error.error_message).context("source cannot pass validation")
128                    );
129                }
130
131                Ok(())
132            })
133        })
134        .await
135        .context("failed to validate source")??;
136
137        tracing::debug!("validate cdc source properties success");
138        Ok(Self {
139            source_id,
140            worker_node_addrs: server_addrs,
141            metrics: context.metrics.clone(),
142            properties: properties_arc,
143            _phantom: PhantomData,
144        })
145    }
146
147    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
148        Ok(self.list_cdc_splits())
149    }
150
151    async fn on_tick(&mut self) -> ConnectorResult<()> {
152        self.monitor_cdc().await
153    }
154}
155
156impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
157    async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
158        // Query upstream LSNs and update metrics.
159        match self.query_postgres_lsns().await {
160            Ok(Some((confirmed_flush_lsn, upstream_max_lsn, slot_name))) => {
161                let labels = [&self.source_id.to_string(), &slot_name.to_owned()];
162
163                self.metrics
164                    .pg_cdc_upstream_max_lsn
165                    .with_guarded_label_values(&labels)
166                    .set(upstream_max_lsn as i64);
167
168                if let Some(lsn) = confirmed_flush_lsn {
169                    self.metrics
170                        .pg_cdc_confirmed_flush_lsn
171                        .with_guarded_label_values(&labels)
172                        .set(lsn as i64);
173                    tracing::debug!(
174                        "Updated confirmed_flush_lsn for source {} slot {}: {}",
175                        self.source_id,
176                        slot_name,
177                        lsn
178                    );
179                } else {
180                    tracing::warn!(
181                        "confirmed_flush_lsn is NULL for source {} slot {}",
182                        self.source_id,
183                        slot_name
184                    );
185                }
186            }
187            Ok(None) => {
188                tracing::warn!(
189                    "No replication slot found when querying LSNs for source {}",
190                    self.source_id
191                );
192            }
193            Err(e) => {
194                tracing::error!(
195                    "Failed to query PostgreSQL LSNs for source {}: {}",
196                    self.source_id,
197                    e.as_report()
198                );
199            }
200        };
201        Ok(())
202    }
203
204    /// Query LSNs from PostgreSQL, return (`confirmed_flush_lsn`, `upstream_max_lsn`, `slot_name`).
205    async fn query_postgres_lsns(&self) -> ConnectorResult<Option<(Option<u64>, u64, &str)>> {
206        // Extract connection parameters from CDC properties
207        let hostname = self
208            .properties
209            .get("hostname")
210            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
211        let port = self
212            .properties
213            .get("port")
214            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
215        let user = self
216            .properties
217            .get("username")
218            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
219        let password = self
220            .properties
221            .get("password")
222            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
223        let database = self
224            .properties
225            .get("database.name")
226            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
227
228        // Get SSL mode from properties, default to Preferred if not specified
229        let ssl_mode = self
230            .properties
231            .get("ssl.mode")
232            .and_then(|s| s.parse().ok())
233            .unwrap_or(SslMode::Preferred);
234        let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
235
236        let slot_name = self
237            .properties
238            .get("slot.name")
239            .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
240
241        // Create PostgreSQL client
242        let client = create_pg_client(
243            user,
244            password,
245            hostname,
246            port,
247            database,
248            &ssl_mode,
249            &ssl_root_cert,
250            None, // No TCP keepalive for CDC enumerator
251        )
252        .await
253        .context("Failed to create PostgreSQL client")?;
254
255        let query = "SELECT confirmed_flush_lsn, pg_current_wal_lsn() \
256            FROM pg_replication_slots WHERE slot_name = $1";
257        let row = client
258            .query_opt(query, &[&slot_name])
259            .await
260            .context("PostgreSQL query LSNs error")?;
261        match row {
262            Some(row) => {
263                let confirmed_flush_lsn: Option<PgLsn> = row.get(0);
264                let upstream_max_lsn: PgLsn = row.get(1);
265                Ok(Some((
266                    confirmed_flush_lsn.map(Into::into),
267                    upstream_max_lsn.into(),
268                    slot_name.as_str(),
269                )))
270            }
271            None => {
272                tracing::warn!("No replication slot found with name: {}", slot_name);
273                Ok(None)
274            }
275        }
276    }
277}
278
/// Produces the CDC splits of an enumerator for one concrete CDC source type.
pub trait ListCdcSplits {
    /// The CDC source type the produced splits belong to.
    type CdcSourceType: CdcSourceTypeTrait;
    /// Generates the splits of the source. Most implementations in this file return a single
    /// split for the shared source; the Citus one returns one split per worker node address.
    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
}
284
/// Trait for CDC-specific monitoring behavior
#[async_trait]
pub trait CdcMonitor {
    /// Runs one round of source-specific monitoring; the enumerator calls this from `on_tick`
    /// and returns its result unchanged.
    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
}
290
/// Fallback `CdcMonitor` implementation for every CDC source type.
///
/// `monitor_cdc` is declared `default` so that the MySQL and PostgreSQL implementations in
/// this file can provide their own version.
#[async_trait]
impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
    /// No-op: performs no monitoring and always returns `Ok(())`.
    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
        Ok(())
    }
}
297
298impl DebeziumSplitEnumerator<Mysql> {
299    async fn monitor_mysql_binlog_files(&mut self) -> ConnectorResult<()> {
300        // Get hostname and port for metrics labels
301        let hostname = self
302            .properties
303            .get("hostname")
304            .map(|s| s.as_str())
305            .ok_or_else(|| {
306                anyhow::anyhow!("missing required property 'hostname' for MySQL CDC source")
307            })?;
308        let port = self
309            .properties
310            .get("port")
311            .map(|s| s.as_str())
312            .ok_or_else(|| {
313                anyhow::anyhow!("missing required property 'port' for MySQL CDC source")
314            })?;
315
316        // Query binlog files and update metrics
317        match self.query_binlog_files().await {
318            Ok(binlog_files) => {
319                if let Some((oldest_file, oldest_size)) = binlog_files.first() {
320                    // Extract sequence number from filename (e.g., "binlog.000001" -> 1)
321                    if let Some(seq) = Self::extract_binlog_seq(oldest_file) {
322                        self.metrics
323                            .mysql_cdc_binlog_file_seq_min
324                            .with_guarded_label_values(&[hostname, port])
325                            .set(seq as i64);
326                        tracing::debug!(
327                            "MySQL CDC source {} ({}:{}): oldest binlog = {}, seq = {}, size = {}",
328                            self.source_id,
329                            hostname,
330                            port,
331                            oldest_file,
332                            seq,
333                            oldest_size
334                        );
335                    }
336                }
337                if let Some((newest_file, newest_size)) = binlog_files.last() {
338                    // Extract sequence number from filename
339                    if let Some(seq) = Self::extract_binlog_seq(newest_file) {
340                        self.metrics
341                            .mysql_cdc_binlog_file_seq_max
342                            .with_guarded_label_values(&[hostname, port])
343                            .set(seq as i64);
344                        tracing::debug!(
345                            "MySQL CDC source {} ({}:{}): newest binlog = {}, seq = {}, size = {}",
346                            self.source_id,
347                            hostname,
348                            port,
349                            newest_file,
350                            seq,
351                            newest_size
352                        );
353                    }
354                }
355                tracing::debug!(
356                    "MySQL CDC source {} ({}:{}): total {} binlog files",
357                    self.source_id,
358                    hostname,
359                    port,
360                    binlog_files.len()
361                );
362            }
363            Err(e) => {
364                tracing::error!(
365                    "Failed to query binlog files for MySQL CDC source {} ({}:{}): {}",
366                    self.source_id,
367                    hostname,
368                    port,
369                    e.as_report()
370                );
371            }
372        }
373        Ok(())
374    }
375
376    /// Extract sequence number from binlog filename
377    /// e.g., "binlog.000001" -> Some(1), "mysql-bin.000123" -> Some(123)
378    fn extract_binlog_seq(filename: &str) -> Option<u64> {
379        filename.rsplit('.').next()?.parse::<u64>().ok()
380    }
381
382    /// Query binlog files from MySQL, returns Vec<(filename, size)>
383    async fn query_binlog_files(&self) -> ConnectorResult<Vec<(String, u64)>> {
384        // Extract connection parameters from CDC properties
385        let hostname = self
386            .properties
387            .get("hostname")
388            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
389        let port = self
390            .properties
391            .get("port")
392            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?
393            .parse::<u16>()
394            .context("failed to parse port as u16")?;
395        let username = self
396            .properties
397            .get("username")
398            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
399        let password = self
400            .properties
401            .get("password")
402            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
403        let database = self
404            .properties
405            .get("database.name")
406            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
407
408        // Get SSL mode configuration (default to Disabled if not specified)
409        let ssl_mode = self
410            .properties
411            .get("ssl.mode")
412            .and_then(|s| s.parse().ok())
413            .unwrap_or(SslMode::Preferred);
414
415        // Build MySQL connection pool with proper SSL configuration
416        let pool =
417            build_mysql_connection_pool(hostname, port, username, password, database, ssl_mode);
418        let mut conn = pool
419            .get_conn()
420            .await
421            .context("Failed to connect to MySQL")?;
422
423        // Query binlog files using SHOW BINARY LOGS
424        // Note: MySQL 8.0+ returns 3 columns: Log_name, File_size, Encrypted
425        let query_result: Vec<(String, u64)> = conn
426            .query_map(
427                "SHOW BINARY LOGS",
428                |(log_name, file_size, _encrypted): (String, u64, String)| (log_name, file_size),
429            )
430            .await
431            .context("Failed to execute SHOW BINARY LOGS")?;
432
433        drop(conn);
434        pool.disconnect().await.ok();
435
436        Ok(query_result)
437    }
438}
439
440impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
441    type CdcSourceType = Mysql;
442
443    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
444        // CDC source only supports single split
445        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
446            self.source_id.as_raw_id(),
447            None,
448            None,
449        )]
450    }
451}
452
453#[async_trait]
454impl CdcMonitor for DebeziumSplitEnumerator<Mysql> {
455    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
456        // For MySQL CDC, query the upstream MySQL binlog files and monitor them.
457        self.monitor_mysql_binlog_files().await?;
458        Ok(())
459    }
460}
461
462impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
463    type CdcSourceType = Postgres;
464
465    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
466        // CDC source only supports single split
467        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
468            self.source_id.as_raw_id(),
469            None,
470            None,
471        )]
472    }
473}
474
475#[async_trait]
476impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
477    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
478        // For PostgreSQL CDC, query the upstream Postgres confirmed flush lsn and monitor it.
479        self.monitor_postgres_confirmed_flush_lsn().await?;
480        Ok(())
481    }
482}
483
484impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
485    type CdcSourceType = Citus;
486
487    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
488        self.worker_node_addrs
489            .iter()
490            .enumerate()
491            .map(|(id, addr)| {
492                DebeziumCdcSplit::<Self::CdcSourceType>::new(
493                    id as u32,
494                    None,
495                    Some(addr.to_string()),
496                )
497            })
498            .collect_vec()
499    }
500}
501impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
502    type CdcSourceType = Mongodb;
503
504    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
505        // CDC source only supports single split
506        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
507            self.source_id.as_raw_id(),
508            None,
509            None,
510        )]
511    }
512}
513
514impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
515    type CdcSourceType = SqlServer;
516
517    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
518        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
519            self.source_id.as_raw_id(),
520            None,
521            None,
522        )]
523    }
524}