risingwave_connector/source/cdc/enumerator/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::marker::PhantomData;
16use std::ops::Deref;
17use std::str::FromStr;
18
19use anyhow::{Context, anyhow};
20use async_trait::async_trait;
21use itertools::Itertools;
22use prost::Message;
23use risingwave_common::global_jvm::JVM;
24use risingwave_common::util::addr::HostAddr;
25use risingwave_jni_core::call_static_method;
26use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
27use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};
28
29use crate::error::ConnectorResult;
30use crate::source::cdc::{
31    CdcProperties, CdcSourceTypeTrait, Citus, DebeziumCdcSplit, Mongodb, Mysql, Postgres,
32    SqlServer, table_schema_exclude_additional_columns,
33};
34use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
35
/// Key in the CDC source properties that may hold a comma-separated list of
/// server addresses. When the key is present, `DebeziumSplitEnumerator::new`
/// parses every entry of its value as a `HostAddr`.
pub const DATABASE_SERVERS_KEY: &str = "database.servers";
37
38#[derive(Debug)]
39pub struct DebeziumSplitEnumerator<T: CdcSourceTypeTrait> {
40    /// The `source_id` in the catalog
41    source_id: u32,
42    worker_node_addrs: Vec<HostAddr>,
43    _phantom: PhantomData<T>,
44}
45
46#[async_trait]
47impl<T: CdcSourceTypeTrait> SplitEnumerator for DebeziumSplitEnumerator<T>
48where
49    Self: ListCdcSplits<CdcSourceType = T>,
50{
51    type Properties = CdcProperties<T>;
52    type Split = DebeziumCdcSplit<T>;
53
54    async fn new(
55        props: CdcProperties<T>,
56        context: SourceEnumeratorContextRef,
57    ) -> ConnectorResult<Self> {
58        let server_addrs = props
59            .properties
60            .get(DATABASE_SERVERS_KEY)
61            .map(|s| {
62                s.split(',')
63                    .map(HostAddr::from_str)
64                    .collect::<Result<Vec<_>, _>>()
65            })
66            .transpose()?
67            .unwrap_or_default();
68
69        assert_eq!(
70            props.get_source_type_pb(),
71            SourceType::from(T::source_type())
72        );
73
74        let jvm = JVM.get_or_init()?;
75        let source_id = context.info.source_id;
76        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
77            execute_with_jni_env(jvm, |env| {
78                let validate_source_request = ValidateSourceRequest {
79                    source_id: source_id as u64,
80                    source_type: props.get_source_type_pb() as _,
81                    properties: props.properties,
82                    table_schema: Some(table_schema_exclude_additional_columns(
83                        &props.table_schema,
84                    )),
85                    is_source_job: props.is_cdc_source_job,
86                    is_backfill_table: props.is_backfill_table,
87                };
88
89                let validate_source_request_bytes =
90                    env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;
91
92                let validate_source_response_bytes = call_static_method!(
93                    env,
94                    {com.risingwave.connector.source.JniSourceValidateHandler},
95                    {byte[] validate(byte[] validateSourceRequestBytes)},
96                    &validate_source_request_bytes
97                )?;
98
99                let validate_source_response: ValidateSourceResponse = Message::decode(
100                    risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
101                        .deref(),
102                )?;
103
104                if let Some(error) = validate_source_response.error {
105                    return Err(
106                        anyhow!(error.error_message).context("source cannot pass validation")
107                    );
108                }
109
110                Ok(())
111            })
112        })
113        .await
114        .context("failed to validate source")??;
115
116        tracing::debug!("validate cdc source properties success");
117        Ok(Self {
118            source_id,
119            worker_node_addrs: server_addrs,
120            _phantom: PhantomData,
121        })
122    }
123
124    async fn list_splits(&mut self) -> ConnectorResult<Vec<DebeziumCdcSplit<T>>> {
125        Ok(self.list_cdc_splits())
126    }
127}
128
129pub trait ListCdcSplits {
130    type CdcSourceType: CdcSourceTypeTrait;
131    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>>;
132}
133
134impl ListCdcSplits for DebeziumSplitEnumerator<Mysql> {
135    type CdcSourceType = Mysql;
136
137    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
138        // CDC source only supports single split
139        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
140            self.source_id,
141            None,
142            None,
143        )]
144    }
145}
146
147impl ListCdcSplits for DebeziumSplitEnumerator<Postgres> {
148    type CdcSourceType = Postgres;
149
150    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
151        // CDC source only supports single split
152        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
153            self.source_id,
154            None,
155            None,
156        )]
157    }
158}
159
160impl ListCdcSplits for DebeziumSplitEnumerator<Citus> {
161    type CdcSourceType = Citus;
162
163    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
164        self.worker_node_addrs
165            .iter()
166            .enumerate()
167            .map(|(id, addr)| {
168                DebeziumCdcSplit::<Self::CdcSourceType>::new(
169                    id as u32,
170                    None,
171                    Some(addr.to_string()),
172                )
173            })
174            .collect_vec()
175    }
176}
177impl ListCdcSplits for DebeziumSplitEnumerator<Mongodb> {
178    type CdcSourceType = Mongodb;
179
180    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
181        // CDC source only supports single split
182        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
183            self.source_id,
184            None,
185            None,
186        )]
187    }
188}
189
190impl ListCdcSplits for DebeziumSplitEnumerator<SqlServer> {
191    type CdcSourceType = SqlServer;
192
193    fn list_cdc_splits(&mut self) -> Vec<DebeziumCdcSplit<Self::CdcSourceType>> {
194        vec![DebeziumCdcSplit::<Self::CdcSourceType>::new(
195            self.source_id,
196            None,
197            None,
198        )]
199    }
200}