risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use prost::Message;
25use risingwave_common::global_jvm::Jvm;
26use risingwave_common::util::addr::HostAddr;
27use risingwave_jni_core::call_static_method;
28use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
29use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::PgLsn;
32
33use crate::connector_common::{SslMode, create_pg_client};
34use crate::error::ConnectorResult;
35use crate::source::cdc::{
36    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
37    SqlServer, table_schema_exclude_additional_columns,
38};
39use crate::source::monitor::metrics::EnumeratorMetrics;
40use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
41
/// Property key whose value is a comma-separated list of upstream server addresses.
/// The enumerator parses each entry into a `HostAddr`.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
43
44#[derive(Debug)]
45pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
46    /// The `source_id` in the catalog
47    source_id: u32,
48    worker_node_addrs: Vec<HostAddr>,
49    metrics: Arc<EnumeratorMetrics>,
50    /// Properties specified in the WITH clause by user for database connection
51    properties: Arc<BTreeMap<String, String>>,
52    _phantom: PhantomData<T>,
53}
54
55#[async_trait]
56impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
57where
58    Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
59{
60    type Properties = CdcProperties<T>;
61    type Split = DebeziumCdcSplit<T>;
62
63    async fn new(
64        props: CdcProperties<T>,
65        context: SourceEnumeratorContextRef,
66    ) -> ConnectorResult<Self> {
67        let server_addrs = props
68            .properties
69            .get(DATABASE_SERVERS_KEY)
70            .map(|s| {
71                s.split(',')
72                    .map(HostAddr::from_str)
73                    .collect::<Result<Vec<_>, _>>()
74            })
75            .transpose()?
76            .unwrap_or_default();
77
78        assert_eq!(
79            props.get_source_type_pb(),
80            SourceType::from(T::source_type())
81        );
82
83        let jvm = Jvm::get_or_init()?;
84        let source_id = context.info.source_id;
85
86        // Extract fields before moving props
87        let source_type_pb = props.get_source_type_pb();
88
89        // Create Arc once and share it
90        let properties_arc = Arc::new(props.properties);
91        let properties_arc_for_validation = properties_arc.clone();
92        let table_schema_for_validation = props.table_schema;
93
94        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
95            execute_with_jni_env(jvm, |env| {
96                let validate_source_request = ValidateSourceRequest {
97                    source_id: source_id as u64,
98                    source_type: source_type_pb as _,
99                    properties: (*properties_arc_for_validation).clone(),
100                    table_schema: Some(table_schema_exclude_additional_columns(
101                        &table_schema_for_validation,
102                    )),
103                    is_source_job: props.is_cdc_source_job,
104                    is_backfill_table: props.is_backfill_table,
105                };
106
107                let validate_source_request_bytes =
108                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
109
110                let validate_source_response_bytes = call_static_method!(
111                    env,
112                    {com.risingwave.connector.source.JniSourceValidateHandler},
113                    {byte[] validate(byte[] validateSourceRequestBytes)},
114                    &validate_source_request_bytes
115                )?;
116
117                let validate_source_response: ValidateSourceResponse = Message::decode(
118                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
119                        .deref(),
120                )?;
121
122                if let Some(error) = validate_source_response.error {
123                    return Err(
124                        anyhow!(error.error_message).context("source cannot pass validation")
125                    );
126                }
127
128                Ok(())
129            })
130        })
131        .await
132        .context("failed to validate source")??;
133
134        tracing::debug!("validate cdc source properties success");
135        Ok(Self {
136            source_id,
137            worker_node_addrs: server_addrs,
138            metrics: context.metrics.clone(),
139            properties: properties_arc,
140            _phantom: PhantomData,
141        })
142    }
143
144    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
145        Ok(self.list_cdc_splits())
146    }
147
148    async fn on_tick(&mut self) -> ConnectorResult<()> {
149        self.monitor_cdc().await
150    }
151}
152
153impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
154    async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
155        // Query confirmed flush LSN and update metrics
156        match self.query_confirmed_flush_lsn().await {
157            Ok(Some((lsn, slot_name))) => {
158                // Update metrics
159                self.metrics
160                    .pg_cdc_confirmed_flush_lsn
161                    .with_guarded_label_values(&[
162                        &self.source_id.to_string(),
163                        &slot_name.to_owned(),
164                    ])
165                    .set(lsn as i64);
166                tracing::debug!(
167                    "Updated confirm_flush_lsn for source {} slot {}: {}",
168                    self.source_id,
169                    slot_name,
170                    lsn
171                );
172            }
173            Ok(None) => {
174                tracing::warn!("No confirmed_flush_lsn found for source {}", self.source_id);
175            }
176            Err(e) => {
177                tracing::error!(
178                    "Failed to query confirmed_flush_lsn for source {}: {}",
179                    self.source_id,
180                    e.as_report()
181                );
182            }
183        }
184        Ok(())
185    }
186
187    /// Query confirmed flush LSN from PostgreSQL, return the slot name and the confirmed flush LSN.
188    async fn query_confirmed_flush_lsn(&self) -> ConnectorResult<Option<(u64, &str)>> {
189        // Extract connection parameters from CDC properties
190        let hostname = self
191            .properties
192            .get("hostname")
193            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
194        let port = self
195            .properties
196            .get("port")
197            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
198        let user = self
199            .properties
200            .get("username")
201            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
202        let password = self
203            .properties
204            .get("password")
205            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
206        let database = self
207            .properties
208            .get("database.name")
209            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
210
211        let ssl_mode = SslMode::Preferred;
212        let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
213
214        let slot_name = self
215            .properties
216            .get("slot.name")
217            .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
218
219        // Create PostgreSQL client
220        let client = create_pg_client(
221            user,
222            password,
223            hostname,
224            port,
225            database,
226            &ssl_mode,
227            &ssl_root_cert,
228        )
229        .await
230        .context("Failed to create PostgreSQL client")?;
231
232        let query = "SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = $1";
233        let row = client
234            .query_opt(query, &[&slot_name])
235            .await
236            .context("PostgreSQL query confirmed flush lsn error")?;
237
238        match row {
239            Some(row) => {
240                let confirm_flush_lsn: Option<PgLsn> = row.get(0);
241                if let Some(lsn) = confirm_flush_lsn {
242                    Ok(Some((lsn.into(), slot_name.as_str())))
243                } else {
244                    Ok(None)
245                }
246            }
247            None => {
248                tracing::warn!("No replication slot found with name: {}", slot_name);
249                Ok(None)
250            }
251        }
252    }
253}
254
255pub trait ListCdcSplits {
256    type CdcSourceType: CdcSourceTypeTrait;
257    /// Generates a single split for shared source.
258    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
259}
260
261/// Trait for CDC-specific monitoring behavior
262#[async_trait]
263pub trait CdcMonitor {
264    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
265}
266
267#[async_trait]
268impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
269    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
270        Ok(())
271    }
272}
273
274impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
275    type CdcSourceType = Mysql;
276
277    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
278        // CDC source only supports single split
279        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
280            self.source_id,
281            None,
282            None,
283        )]
284    }
285}
286
287impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
288    type CdcSourceType = Postgres;
289
290    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
291        // CDC source only supports single split
292        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
293            self.source_id,
294            None,
295            None,
296        )]
297    }
298}
299
300#[async_trait]
301impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
302    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
303        // For PostgreSQL CDC, query the upstream Postgres confirmed flush lsn and monitor it.
304        self.monitor_postgres_confirmed_flush_lsn().await?;
305        Ok(())
306    }
307}
308
309impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
310    type CdcSourceType = Citus;
311
312    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
313        self.worker_node_addrs
314            .iter()
315            .enumerate()
316            .map(|(id, addr)| {
317                DebeziumCdcSplit::<Self::CdcSourceType>::new(
318                    id as u32,
319                    None,
320                    Some(addr.to_string()),
321                )
322            })
323            .collect_vec()
324    }
325}
326impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
327    type CdcSourceType = Mongodb;
328
329    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
330        // CDC source only supports single split
331        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
332            self.source_id,
333            None,
334            None,
335        )]
336    }
337}
338
339impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
340    type CdcSourceType = SqlServer;
341
342    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
343        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
344            self.source_id,
345            None,
346            None,
347        )]
348    }
349}