risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::Deref;
17use std::str::FromStr;
18
19use anyhow::{Context, anyhow};
20use async_trait::async_trait;
21use itertools::Itertools;
22use prost::Message;
23use risingwave_common::util::addr::HostAddr;
24use risingwave_jni_core::call_static_method;
25use risingwave_jni_core::jvm_runtime::{JVM, execute_with_jni_env};
26use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
27
28use crate::error::ConnectorResult;
29use crate::source::cdc::{
30    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
31    SqlServer, table_schema_exclude_additional_columns,
32};
33use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
34
/// Property key holding a comma-separated list of addresses.
///
/// The enumerator parses each entry as a `HostAddr`. The Citus enumerator turns each one
/// into its own split.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
36
37#[derive(Debug)]
38pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
39    /// The `source_id` in the catalog
40    source_id: u32,
41    worker_node_addrs: Vec<HostAddr>,
42    _phantom: PhantomData<T>,
43}
44
45#[async_trait]
46impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
47where
48    Self: ListCdcSplits<CdcSourceType = T>,
49{
50    type Properties = CdcProperties<T>;
51    type Split = DebeziumCdcSplit<T>;
52
53    async fn new(
54        props: CdcProperties<T>,
55        context: SourceEnumeratorContextRef,
56    ) -> ConnectorResult<Self> {
57        let server_addrs = props
58            .properties
59            .get(DATABASE_SERVERS_KEY)
60            .map(|s| {
61                s.split(',')
62                    .map(HostAddr::from_str)
63                    .collect::<Result<Vec<_>, _>>()
64            })
65            .transpose()?
66            .unwrap_or_default();
67
68        assert_eq!(
69            props.get_source_type_pb(),
70            SourceType::from(T::source_type())
71        );
72
73        let jvm = JVM.get_or_init()?;
74        let source_id = context.info.source_id;
75        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
76            execute_with_jni_env(jvm, |env| {
77                let validate_source_request = ValidateSourceRequest {
78                    source_id: source_id as u64,
79                    source_type: props.get_source_type_pb() as _,
80                    properties: props.properties,
81                    table_schema: Some(table_schema_exclude_additional_columns(
82                        &props.table_schema,
83                    )),
84                    is_source_job: props.is_cdc_source_job,
85                    is_backfill_table: props.is_backfill_table,
86                };
87
88                let validate_source_request_bytes =
89                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
90
91                let validate_source_response_bytes = call_static_method!(
92                    env,
93                    {com.risingwave.connector.source.JniSourceValidateHandler},
94                    {byte[] validate(byte[] validateSourceRequestBytes)},
95                    &validate_source_request_bytes
96                )?;
97
98                let validate_source_response: ValidateSourceResponse = Message::decode(
99                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
100                        .deref(),
101                )?;
102
103                if let Some(error) = validate_source_response.error {
104                    return Err(
105                        anyhow!(error.error_message).context("source cannot pass validation")
106                    );
107                }
108
109                Ok(())
110            })
111        })
112        .await
113        .context("failed to validate source")??;
114
115        tracing::debug!("validate cdc source properties success");
116        Ok(Self {
117            source_id,
118            worker_node_addrs: server_addrs,
119            _phantom: PhantomData,
120        })
121    }
122
123    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
124        Ok(self.list_cdc_splits())
125    }
126}
127
128pub trait ListCdcSplits {
129    type CdcSourceType: CdcSourceTypeTrait;
130    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
131}
132
133impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
134    type CdcSourceType = Mysql;
135
136    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
137        // CDC source only supports single split
138        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
139            self.source_id,
140            None,
141            None,
142        )]
143    }
144}
145
146impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
147    type CdcSourceType = Postgres;
148
149    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
150        // CDC source only supports single split
151        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
152            self.source_id,
153            None,
154            None,
155        )]
156    }
157}
158
159impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
160    type CdcSourceType = Citus;
161
162    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
163        self.worker_node_addrs
164            .iter()
165            .enumerate()
166            .map(|(id, addr)| {
167                DebeziumCdcSplit::<Self::CdcSourceType>::new(
168                    id as u32,
169                    None,
170                    Some(addr.to_string()),
171                )
172            })
173            .collect_vec()
174    }
175}
176impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
177    type CdcSourceType = Mongodb;
178
179    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
180        // CDC source only supports single split
181        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
182            self.source_id,
183            None,
184            None,
185        )]
186    }
187}
188
189impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
190    type CdcSourceType = SqlServer;
191
192    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
193        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
194            self.source_id,
195            None,
196            None,
197        )]
198    }
199}