// risingwave_connector/source/cdc/enumerator/mod.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use prost::Message;
25use risingwave_common::global_jvm::Jvm;
26use risingwave_common::id::SourceId;
27use risingwave_common::util::addr::HostAddr;
28use risingwave_jni_core::call_static_method;
29use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
30use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
31use thiserror_ext::AsReport;
32use tokio_postgres::types::PgLsn;
33
34use crate::connector_common::{SslMode, create_pg_client};
35use crate::error::ConnectorResult;
36use crate::source::cdc::{
37    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
38    SqlServer, table_schema_exclude_additional_columns,
39};
40use crate::source::monitor::metrics::EnumeratorMetrics;
41use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
42
43pub const DATABASE_SERVERS_KEY: &str = "database.servers";
44
45#[derive(Debug)]
46pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
47    /// The `source_id` in the catalog
48    source_id: SourceId,
49    worker_node_addrs: Vec<HostAddr>,
50    metrics: Arc<EnumeratorMetrics>,
51    /// Properties specified in the WITH clause by user for database connection
52    properties: Arc<BTreeMap<String, String>>,
53    _phantom: PhantomData<T>,
54}
55
56#[async_trait]
57impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
58where
59    Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
60{
61    type Properties = CdcProperties<T>;
62    type Split = DebeziumCdcSplit<T>;
63
64    async fn new(
65        props: CdcProperties<T>,
66        context: SourceEnumeratorContextRef,
67    ) -> ConnectorResult<Self> {
68        let server_addrs = props
69            .properties
70            .get(DATABASE_SERVERS_KEY)
71            .map(|s| {
72                s.split(',')
73                    .map(HostAddr::from_str)
74                    .collect::<Result<Vec<_>, _>>()
75            })
76            .transpose()?
77            .unwrap_or_default();
78
79        assert_eq!(
80            props.get_source_type_pb(),
81            SourceType::from(T::source_type())
82        );
83
84        let jvm = Jvm::get_or_init()?;
85        let source_id = context.info.source_id;
86
87        // Extract fields before moving props
88        let source_type_pb = props.get_source_type_pb();
89
90        // Create Arc once and share it
91        let properties_arc = Arc::new(props.properties);
92        let properties_arc_for_validation = properties_arc.clone();
93        let table_schema_for_validation = props.table_schema;
94
95        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
96            execute_with_jni_env(jvm, |env| {
97                let validate_source_request = ValidateSourceRequest {
98                    source_id: source_id.as_raw_id() as u64,
99                    source_type: source_type_pb as _,
100                    properties: (*properties_arc_for_validation).clone(),
101                    table_schema: Some(table_schema_exclude_additional_columns(
102                        &table_schema_for_validation,
103                    )),
104                    is_source_job: props.is_cdc_source_job,
105                    is_backfill_table: props.is_backfill_table,
106                };
107
108                let validate_source_request_bytes =
109                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
110
111                let validate_source_response_bytes = call_static_method!(
112                    env,
113                    {com.risingwave.connector.source.JniSourceValidateHandler},
114                    {byte[] validate(byte[] validateSourceRequestBytes)},
115                    &validate_source_request_bytes
116                )?;
117
118                let validate_source_response: ValidateSourceResponse = Message::decode(
119                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
120                        .deref(),
121                )?;
122
123                if let Some(error) = validate_source_response.error {
124                    return Err(
125                        anyhow!(error.error_message).context("source cannot pass validation")
126                    );
127                }
128
129                Ok(())
130            })
131        })
132        .await
133        .context("failed to validate source")??;
134
135        tracing::debug!("validate cdc source properties success");
136        Ok(Self {
137            source_id,
138            worker_node_addrs: server_addrs,
139            metrics: context.metrics.clone(),
140            properties: properties_arc,
141            _phantom: PhantomData,
142        })
143    }
144
145    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
146        Ok(self.list_cdc_splits())
147    }
148
149    async fn on_tick(&mut self) -> ConnectorResult<()> {
150        self.monitor_cdc().await
151    }
152}
153
154impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
155    async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
156        // Query confirmed flush LSN and update metrics
157        match self.query_confirmed_flush_lsn().await {
158            Ok(Some((lsn, slot_name))) => {
159                // Update metrics
160                self.metrics
161                    .pg_cdc_confirmed_flush_lsn
162                    .with_guarded_label_values(&[
163                        &self.source_id.to_string(),
164                        &slot_name.to_owned(),
165                    ])
166                    .set(lsn as i64);
167                tracing::debug!(
168                    "Updated confirm_flush_lsn for source {} slot {}: {}",
169                    self.source_id,
170                    slot_name,
171                    lsn
172                );
173            }
174            Ok(None) => {
175                tracing::warn!("No confirmed_flush_lsn found for source {}", self.source_id);
176            }
177            Err(e) => {
178                tracing::error!(
179                    "Failed to query confirmed_flush_lsn for source {}: {}",
180                    self.source_id,
181                    e.as_report()
182                );
183            }
184        }
185        Ok(())
186    }
187
188    /// Query confirmed flush LSN from PostgreSQL, return the slot name and the confirmed flush LSN.
189    async fn query_confirmed_flush_lsn(&self) -> ConnectorResult<Option<(u64, &str)>> {
190        // Extract connection parameters from CDC properties
191        let hostname = self
192            .properties
193            .get("hostname")
194            .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
195        let port = self
196            .properties
197            .get("port")
198            .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
199        let user = self
200            .properties
201            .get("username")
202            .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
203        let password = self
204            .properties
205            .get("password")
206            .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
207        let database = self
208            .properties
209            .get("database.name")
210            .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
211
212        let ssl_mode = SslMode::Preferred;
213        let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
214
215        let slot_name = self
216            .properties
217            .get("slot.name")
218            .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
219
220        // Create PostgreSQL client
221        let client = create_pg_client(
222            user,
223            password,
224            hostname,
225            port,
226            database,
227            &ssl_mode,
228            &ssl_root_cert,
229        )
230        .await
231        .context("Failed to create PostgreSQL client")?;
232
233        let query = "SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = $1";
234        let row = client
235            .query_opt(query, &[&slot_name])
236            .await
237            .context("PostgreSQL query confirmed flush lsn error")?;
238
239        match row {
240            Some(row) => {
241                let confirm_flush_lsn: Option<PgLsn> = row.get(0);
242                if let Some(lsn) = confirm_flush_lsn {
243                    Ok(Some((lsn.into(), slot_name.as_str())))
244                } else {
245                    Ok(None)
246                }
247            }
248            None => {
249                tracing::warn!("No replication slot found with name: {}", slot_name);
250                Ok(None)
251            }
252        }
253    }
254}
255
256pub trait ListCdcSplits {
257    type CdcSourceType: CdcSourceTypeTrait;
258    /// Generates a single split for shared source.
259    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
260}
261
262/// Trait for CDC-specific monitoring behavior
263#[async_trait]
264pub trait CdcMonitor {
265    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
266}
267
268#[async_trait]
269impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
270    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
271        Ok(())
272    }
273}
274
275impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
276    type CdcSourceType = Mysql;
277
278    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
279        // CDC source only supports single split
280        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
281            self.source_id.as_raw_id(),
282            None,
283            None,
284        )]
285    }
286}
287
288impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
289    type CdcSourceType = Postgres;
290
291    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
292        // CDC source only supports single split
293        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
294            self.source_id.as_raw_id(),
295            None,
296            None,
297        )]
298    }
299}
300
301#[async_trait]
302impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
303    async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
304        // For PostgreSQL CDC, query the upstream Postgres confirmed flush lsn and monitor it.
305        self.monitor_postgres_confirmed_flush_lsn().await?;
306        Ok(())
307    }
308}
309
310impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
311    type CdcSourceType = Citus;
312
313    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
314        self.worker_node_addrs
315            .iter()
316            .enumerate()
317            .map(|(id, addr)| {
318                DebeziumCdcSplit::<Self::CdcSourceType>::new(
319                    id as u32,
320                    None,
321                    Some(addr.to_string()),
322                )
323            })
324            .collect_vec()
325    }
326}
327impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
328    type CdcSourceType = Mongodb;
329
330    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
331        // CDC source only supports single split
332        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
333            self.source_id.as_raw_id(),
334            None,
335            None,
336        )]
337    }
338}
339
340impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
341    type CdcSourceType = SqlServer;
342
343    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
344        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
345            self.source_id.as_raw_id(),
346            None,
347            None,
348        )]
349    }
350}