risingwave_connector/source/cdc/enumerator/
mod.rs1use std::collections::BTreeMap;
16use std::marker::PhantomData;
17use std::ops::Deref;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use anyhow::{Context, anyhow};
22use async_trait::async_trait;
23use itertools::Itertools;
24use prost::Message;
25use risingwave_common::global_jvm::Jvm;
26use risingwave_common::id::SourceId;
27use risingwave_common::util::addr::HostAddr;
28use risingwave_jni_core::call_static_method;
29use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
30use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
31use thiserror_ext::AsReport;
32use tokio_postgres::types::PgLsn;
33
34use crate::connector_common::{SslMode, create_pg_client};
35use crate::error::ConnectorResult;
36use crate::source::cdc::{
37 CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
38 SqlServer, table_schema_exclude_additional_columns,
39};
40use crate::source::monitor::metrics::EnumeratorMetrics;
41use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
42
/// Property key holding a comma-separated list of server addresses; each entry is
/// parsed as a [`HostAddr`] when the enumerator is created. The key is optional.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
44
/// Split enumerator shared by the Debezium-based CDC sources. The type parameter `T`
/// selects the concrete source type (e.g. MySQL, Postgres, Citus, MongoDB, SQL Server),
/// and trait impls specialized on `T` decide how splits are listed and monitored.
#[derive(Debug)]
pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
    /// Id of the source this enumerator serves; also used as the split id for
    /// single-split source types and as a metric label.
    source_id: SourceId,
    /// Addresses parsed from the `database.servers` property (empty when unset).
    /// In this file they are only read by the Citus split listing.
    worker_node_addrs: Vec<HostAddr>,
    /// Enumerator metrics (e.g. the Postgres confirmed-flush-LSN gauge).
    metrics: Arc<EnumeratorMetrics>,
    /// Raw connector properties, kept for monitoring queries such as the
    /// Postgres replication-slot lookup.
    properties: Arc<BTreeMap<String, String>>,
    _phantom: PhantomData<T>,
}
55
56#[async_trait]
57impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
58where
59 Self: ListCdcSplits<CdcSourceType = T> + CdcMonitor,
60{
61 type Properties = CdcProperties<T>;
62 type Split = DebeziumCdcSplit<T>;
63
64 async fn new(
65 props: CdcProperties<T>,
66 context: SourceEnumeratorContextRef,
67 ) -> ConnectorResult<Self> {
68 let server_addrs = props
69 .properties
70 .get(DATABASE_SERVERS_KEY)
71 .map(|s| {
72 s.split(',')
73 .map(HostAddr::from_str)
74 .collect::<Result<Vec<_>, _>>()
75 })
76 .transpose()?
77 .unwrap_or_default();
78
79 assert_eq!(
80 props.get_source_type_pb(),
81 SourceType::from(T::source_type())
82 );
83
84 let jvm = Jvm::get_or_init()?;
85 let source_id = context.info.source_id;
86
87 let source_type_pb = props.get_source_type_pb();
89
90 let properties_arc = Arc::new(props.properties);
92 let properties_arc_for_validation = properties_arc.clone();
93 let table_schema_for_validation = props.table_schema;
94
95 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
96 execute_with_jni_env(jvm, |env| {
97 let validate_source_request = ValidateSourceRequest {
98 source_id: source_id.as_raw_id() as u64,
99 source_type: source_type_pb as _,
100 properties: (*properties_arc_for_validation).clone(),
101 table_schema: Some(table_schema_exclude_additional_columns(
102 &table_schema_for_validation,
103 )),
104 is_source_job: props.is_cdc_source_job,
105 is_backfill_table: props.is_backfill_table,
106 };
107
108 let validate_source_request_bytes =
109 env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
110
111 let validate_source_response_bytes = call_static_method!(
112 env,
113 {com.risingwave.connector.source.JniSourceValidateHandler},
114 {byte[] validate(byte[] validateSourceRequestBytes)},
115 &validate_source_request_bytes
116 )?;
117
118 let validate_source_response: ValidateSourceResponse = Message::decode(
119 risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
120 .deref(),
121 )?;
122
123 if let Some(error) = validate_source_response.error {
124 return Err(
125 anyhow!(error.error_message).context("source cannot pass validation")
126 );
127 }
128
129 Ok(())
130 })
131 })
132 .await
133 .context("failed to validate source")??;
134
135 tracing::debug!("validate cdc source properties success");
136 Ok(Self {
137 source_id,
138 worker_node_addrs: server_addrs,
139 metrics: context.metrics.clone(),
140 properties: properties_arc,
141 _phantom: PhantomData,
142 })
143 }
144
145 async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
146 Ok(self.list_cdc_splits())
147 }
148
149 async fn on_tick(&mut self) -> ConnectorResult<()> {
150 self.monitor_cdc().await
151 }
152}
153
154impl<T: CdcSourceTypeTrait> DebeziumSplitEnumerator<T> {
155 async fn monitor_postgres_confirmed_flush_lsn(&mut self) -> ConnectorResult<()> {
156 match self.query_confirmed_flush_lsn().await {
158 Ok(Some((lsn, slot_name))) => {
159 self.metrics
161 .pg_cdc_confirmed_flush_lsn
162 .with_guarded_label_values(&[
163 &self.source_id.to_string(),
164 &slot_name.to_owned(),
165 ])
166 .set(lsn as i64);
167 tracing::debug!(
168 "Updated confirm_flush_lsn for source {} slot {}: {}",
169 self.source_id,
170 slot_name,
171 lsn
172 );
173 }
174 Ok(None) => {
175 tracing::warn!("No confirmed_flush_lsn found for source {}", self.source_id);
176 }
177 Err(e) => {
178 tracing::error!(
179 "Failed to query confirmed_flush_lsn for source {}: {}",
180 self.source_id,
181 e.as_report()
182 );
183 }
184 }
185 Ok(())
186 }
187
188 async fn query_confirmed_flush_lsn(&self) -> ConnectorResult<Option<(u64, &str)>> {
190 let hostname = self
192 .properties
193 .get("hostname")
194 .ok_or_else(|| anyhow::anyhow!("hostname not found in CDC properties"))?;
195 let port = self
196 .properties
197 .get("port")
198 .ok_or_else(|| anyhow::anyhow!("port not found in CDC properties"))?;
199 let user = self
200 .properties
201 .get("username")
202 .ok_or_else(|| anyhow::anyhow!("username not found in CDC properties"))?;
203 let password = self
204 .properties
205 .get("password")
206 .ok_or_else(|| anyhow::anyhow!("password not found in CDC properties"))?;
207 let database = self
208 .properties
209 .get("database.name")
210 .ok_or_else(|| anyhow::anyhow!("database.name not found in CDC properties"))?;
211
212 let ssl_mode = SslMode::Preferred;
213 let ssl_root_cert = self.properties.get("database.ssl.root.cert").cloned();
214
215 let slot_name = self
216 .properties
217 .get("slot.name")
218 .ok_or_else(|| anyhow::anyhow!("slot.name not found in CDC properties"))?;
219
220 let client = create_pg_client(
222 user,
223 password,
224 hostname,
225 port,
226 database,
227 &ssl_mode,
228 &ssl_root_cert,
229 )
230 .await
231 .context("Failed to create PostgreSQL client")?;
232
233 let query = "SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = $1";
234 let row = client
235 .query_opt(query, &[&slot_name])
236 .await
237 .context("PostgreSQL query confirmed flush lsn error")?;
238
239 match row {
240 Some(row) => {
241 let confirm_flush_lsn: Option<PgLsn> = row.get(0);
242 if let Some(lsn) = confirm_flush_lsn {
243 Ok(Some((lsn.into(), slot_name.as_str())))
244 } else {
245 Ok(None)
246 }
247 }
248 None => {
249 tracing::warn!("No replication slot found with name: {}", slot_name);
250 Ok(None)
251 }
252 }
253 }
254}
255
/// Source-type-specific split listing, used by `DebeziumSplitEnumerator::list_splits`.
pub trait ListCdcSplits {
    /// The CDC source type whose splits are produced.
    type CdcSourceType: CdcSourceTypeTrait;
    /// Returns the full set of splits for this source.
    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
}
261
/// Periodic monitoring hook, invoked from `DebeziumSplitEnumerator::on_tick`.
#[async_trait]
pub trait CdcMonitor {
    /// Performs one round of source-specific monitoring.
    async fn monitor_cdc(&mut self) -> ConnectorResult<()>;
}
267
/// Fallback monitor for every source type: does nothing.
#[async_trait]
impl<T: CdcSourceTypeTrait> CdcMonitor for DebeziumSplitEnumerator<T> {
    // `default` (specialization) lets a concrete source type, such as Postgres,
    // provide its own `monitor_cdc` in place of this no-op.
    default async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
        Ok(())
    }
}
274
275impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
276 type CdcSourceType = Mysql;
277
278 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
279 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
281 self.source_id.as_raw_id(),
282 None,
283 None,
284 )]
285 }
286}
287
288impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
289 type CdcSourceType = Postgres;
290
291 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
292 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
294 self.source_id.as_raw_id(),
295 None,
296 None,
297 )]
298 }
299}
300
301#[async_trait]
302impl CdcMonitor for DebeziumSplitEnumerator<Postgres> {
303 async fn monitor_cdc(&mut self) -> ConnectorResult<()> {
304 self.monitor_postgres_confirmed_flush_lsn().await?;
306 Ok(())
307 }
308}
309
310impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
311 type CdcSourceType = Citus;
312
313 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
314 self.worker_node_addrs
315 .iter()
316 .enumerate()
317 .map(|(id, addr)| {
318 DebeziumCdcSplit::<Self::CdcSourceType>::new(
319 id as u32,
320 None,
321 Some(addr.to_string()),
322 )
323 })
324 .collect_vec()
325 }
326}
327impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
328 type CdcSourceType = Mongodb;
329
330 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
331 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
333 self.source_id.as_raw_id(),
334 None,
335 None,
336 )]
337 }
338}
339
340impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
341 type CdcSourceType = SqlServer;
342
343 fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
344 vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
345 self.source_id.as_raw_id(),
346 None,
347 None,
348 )]
349 }
350}