risingwave_connector/source/cdc/enumerator/
mod.rs1use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use prost::Message;
25use risingwave_common::global_jvm::Jvm;
26use risingwave_common::util::addr::HostAddr;
27use risingwave_jni_core::call_static_method;
28use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
29use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::PgLsn;
32
33use crate::connector_common::{SslMode, create_pg_client};
34use crate::error::ConnectorResult;
35use crate::source::cdc::{
36 CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
37 SqlServer, table_schema_exclude_additional_columns,
38};
39use crate::source::monitor::metrics::EnumeratorMetrics;
40use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
41
/// Property key holding a comma-separated list of server addresses.
/// `SplitEnumerator::new` in this file parses each entry with `HostAddr::from_str`.
42pub const DATABASE_SERVERS_KEY: &str = "database.servers";
43
/// Split enumerator for Debezium-based CDC sources, parameterized by the CDC source type `T`.
44#[derive(Debug)]
45pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
    /// Source id, copied from the enumerator context in `new`.
46 source_id: u32,
    /// Addresses parsed from the `database.servers` property (empty when the property is absent).
    /// Read by the `Citus` split listing in this file.
48 worker_node_addrs: Vec<HostAddr>,
    /// Enumerator metrics handle, cloned from the enumerator context in `new`.
49 metrics: Arc<EnumeratorMetrics>,
    /// The connector properties passed to `new`, shared behind an `Arc`.
50 properties: Arc<BTreeMap<String, String>>,
    /// Zero-sized marker tying the enumerator to its source type `T`.
52 _phantom: PhantomData<T>,
53}
54
55#[async_trait]
56impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
57where
58 Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
59{
60 type Properties = CdcProperties<T>;
61 type Split = DebeziumCdcSplit<T>;
62
63 async fn new(
64 props: CdcProperties<T>,
65 context: SourceEnumeratorContextRef,
66 ) -> ConnectorResult<Self> {
67 let server_addrs = props
68 .properties
69 .get(DATABASE_SERVERS_KEY)
70 .map(|s| {
71 s.split(',')
72 .map(HostAddr::from_str)
73 .collect::<Result<Vec<_>, _>>()
74 })
75 .transpose()?
76 .unwrap_or_default();
77
78 assert_eq!(
79 props.get_source_type_pb(),
80 SourceType::from(T::source_type())
81 );
82
83 let jvm = Jvm::get_or_init()?;
84 let source_id = context.info.source_id;
85
86 let source_type_pb = props.get_source_type_pb();
88
89 let properties_arc = Arc::new(props.properties);
91 let properties_arc_for_validation = properties_arc.clone();
92 let table_schema_for_validation = props.table_schema;
93
94 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
95 execute_with_jni_env(jvm, |env| {
96 let validate_source_request = ValidateSourceRequest {
97 source_id: source_id as u64,
98 source_type: source_type_pb as _,
99 properties: (*properties_arc_for_validation).clone(),
100 table_schema: Some(table_schema_exclude_additional_columns(
101 &table_schema_for_validation,
102 )),
103 is_source_job: props.is_cdc_source_job,
104 is_backfill_table: props.is_backfill_table,
105 };
106
107 let validate_source_request_bytes =
108 env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
109
110 let validate_source_response_bytes = call_static_method!(
111 env,
112 {com.risingwave.connector.source.JniSourceValidateHandler},
113 {byte[] validate(byte[] validateSourceRequestBytes)},
114 &validate_source_request_bytes
115 )?;
116
117 let validate_source_response: ValidateSourceResponse = Message::decode(
118 risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
119 .deref(),
120 )?;
121
122 if let Some(error) = validate_source_response.error {
123 return Err(
124 anyhow!(error.error_message).context("source cannot pass validation")
125 );
126 }
127
128 Ok(())
129 })
130 })
131 .await
132 .context("failed to validate source")??;
133
134 tracing::debug!("validate cdc source properties success");
135 Ok(Self {
136 source_id,
137 worker_node_addrs: server_addrs,
138 metrics: context.metrics.clone(),
139 properties: properties_arc,
140 _phantom: PhantomData,
141 })
142 }
143
144 async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
145 Ok(self.list_cdc_splits())
146 }
147
148 async fn on_tick(&mut self) -> ConnectorResult<()> {
149 self.monitor_cdc().await
150 }
151}
152
153impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
154 async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
155 match self.query_confirmed_flush_lsn().await {
157 Ok(Some((lsn, slot_name))) => {
158 self.metrics
160 .pg_cdc_confirmed_flush_lsn
161 .with_guarded_label_values(&[
162 &self.source_id.to_string(),
163 &slot_name.to_owned(),
164 ])
165 .set(lsn as i64);
166 tracing::debug!(
167 "Updated confirm_flush_lsn for source {} slot {}: {}",
168 self.source_id,
169 slot_name,
170 lsn
171 );
172 }
173 Ok(None) => {
174 tracing::warn!("No confirmed_flush_lsn found for source {}", self.source_id);
175 }
176 Err(e) => {
177 tracing::error!(
178 "Failed to query confirmed_flush_lsn for source {}: {}",
179 self.source_id,
180 e.as_report()
181 );
182 }
183 }
184 Ok(())
185 }
186
187 async fn query_confirmed_flush_lsn(&self) -> ConnectorResult<Option<(u64, &str)>> {
189 let hostname = self
191 .properties
192 .get("hostname")
193 .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
194 let port = self
195 .properties
196 .get("port")
197 .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
198 let user = self
199 .properties
200 .get("username")
201 .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
202 let password = self
203 .properties
204 .get("password")
205 .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
206 let database = self
207 .properties
208 .get("database.name")
209 .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
210
211 let ssl_mode = SslMode::Preferred;
212 let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
213
214 let slot_name = self
215 .properties
216 .get("slot.name")
217 .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
218
219 let client = create_pg_client(
221 user,
222 password,
223 hostname,
224 port,
225 database,
226 &ssl_mode,
227 &ssl_root_cert,
228 )
229 .await
230 .context("Failed to create PostgreSQL client")?;
231
232 let query = "SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = $1";
233 let row = client
234 .query_opt(query, &[&slot_name])
235 .await
236 .context("PostgreSQL query confirmed flush lsn error")?;
237
238 match row {
239 Some(row) => {
240 let confirm_flush_lsn: Option<PgLsn> = row.get(0);
241 if let Some(lsn) = confirm_flush_lsn {
242 Ok(Some((lsn.into(), slot_name.as_str())))
243 } else {
244 Ok(None)
245 }
246 }
247 None => {
248 tracing::warn!("No replication slot found with name: {}", slot_name);
249 Ok(None)
250 }
251 }
252 }
253}
254
/// Source-type-specific listing of CDC splits.
255pub trait ListCdcSplits {
    /// The CDC source type the produced splits belong to.
256 type CdcSourceType: CdcSourceTypeTrait;
    /// Returns the splits for this enumerator.
257 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
259}
260
/// Hook for source-type-specific monitoring; `SplitEnumerator::on_tick` in this file calls it.
261#[async_trait]
263pub trait CdcMonitor {
    /// Runs one monitoring pass.
264 async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
265}
266
/// Fallback `CdcMonitor` for source types without a dedicated implementation.
267#[async_trait]
268impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
    /// Does nothing and returns `Ok(())`. Declared `default` so a more specific impl
    /// (such as the one for `Postgres` in this file) can replace it.
269 default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
270 Ok(())
271 }
272}
273
274impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
275 type CdcSourceType = Mysql;
276
277 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
278 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
280 self.source_id,
281 None,
282 None,
283 )]
284 }
285}
286
287impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
288 type CdcSourceType = Postgres;
289
290 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
291 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
293 self.source_id,
294 None,
295 None,
296 )]
297 }
298}
299
300#[async_trait]
301impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
302 async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
303 self.monitor_postgres_confirmed_flush_lsn().await?;
305 Ok(())
306 }
307}
308
309impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
310 type CdcSourceType = Citus;
311
312 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
313 self.worker_node_addrs
314 .iter()
315 .enumerate()
316 .map(|(id, addr)| {
317 DebeziumCdcSplit::<Self::CdcSourceType>::new(
318 id as u32,
319 None,
320 Some(addr.to_string()),
321 )
322 })
323 .collect_vec()
324 }
325}
326impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
327 type CdcSourceType = Mongodb;
328
329 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
330 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
332 self.source_id,
333 None,
334 None,
335 )]
336 }
337}
338
339impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
340 type CdcSourceType = SqlServer;
341
342 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
343 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
344 self.source_id,
345 None,
346 None,
347 )]
348 }
349}