// risingwave_connector/source/mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
/// Convenience re-exports giving every connector's split enumerator a uniform
/// `XxxSplitEnumerator` name, so dispatch code can refer to them consistently.
pub mod prelude {
    // import all split enumerators
    pub use crate::source::datagen::DatagenSplitEnumerator;
    pub use crate::source::filesystem::LegacyS3SplitEnumerator;
    pub use crate::source::filesystem::opendal_source::OpendalEnumerator;
    pub use crate::source::google_pubsub::PubsubSplitEnumerator as GooglePubsubSplitEnumerator;
    pub use crate::source::iceberg::IcebergSplitEnumerator;
    pub use crate::source::kafka::KafkaSplitEnumerator;
    pub use crate::source::kinesis::KinesisSplitEnumerator;
    pub use crate::source::mqtt::MqttSplitEnumerator;
    pub use crate::source::nats::NatsSplitEnumerator;
    pub use crate::source::nexmark::NexmarkSplitEnumerator;
    pub use crate::source::pulsar::PulsarSplitEnumerator;
    pub use crate::source::test_source::TestSourceSplitEnumerator as TestSplitEnumerator;
    // Filesystem-like sources share the generic `OpendalEnumerator`; each
    // backend gets a concrete alias matching the naming scheme above.
    pub type AzblobSplitEnumerator =
        OpendalEnumerator<crate::source::filesystem::opendal_source::OpendalAzblob>;
    pub type GcsSplitEnumerator =
        OpendalEnumerator<crate::source::filesystem::opendal_source::OpendalGcs>;
    pub type OpendalS3SplitEnumerator =
        OpendalEnumerator<crate::source::filesystem::opendal_source::OpendalS3>;
    pub type PosixFsSplitEnumerator =
        OpendalEnumerator<crate::source::filesystem::opendal_source::OpendalPosixFs>;
    pub use crate::source::cdc::enumerator::DebeziumSplitEnumerator;
    pub use crate::source::filesystem::opendal_source::BatchPosixFsEnumerator as BatchPosixFsSplitEnumerator;
    // CDC sources share the generic `DebeziumSplitEnumerator`, parameterized by
    // the upstream database kind.
    pub type CitusCdcSplitEnumerator = DebeziumSplitEnumerator<crate::source::cdc::Citus>;
    pub type MongodbCdcSplitEnumerator = DebeziumSplitEnumerator<crate::source::cdc::Mongodb>;
    pub type PostgresCdcSplitEnumerator = DebeziumSplitEnumerator<crate::source::cdc::Postgres>;
    pub type MysqlCdcSplitEnumerator = DebeziumSplitEnumerator<crate::source::cdc::Mysql>;
    pub type SqlServerCdcSplitEnumerator = DebeziumSplitEnumerator<crate::source::cdc::SqlServer>;
}
45
46pub mod base;
47pub mod batch;
48pub mod cdc;
49pub mod data_gen_util;
50pub mod datagen;
51pub mod filesystem;
52pub mod google_pubsub;
53pub mod kafka;
54pub mod kinesis;
55pub mod monitor;
56pub mod mqtt;
57pub mod nats;
58pub mod nexmark;
59pub mod pulsar;
60pub mod utils;
61
62mod util;
63use std::future::IntoFuture;
64use std::time::Duration;
65
66pub use base::{UPSTREAM_SOURCE_KEY, WEBHOOK_CONNECTOR, *};
67pub use batch::BatchSourceSplitImpl;
68pub(crate) use common::*;
69use google_cloud_pubsub::subscription::Subscription;
70pub use google_pubsub::GOOGLE_PUBSUB_CONNECTOR;
71pub use kafka::KAFKA_CONNECTOR;
72pub use kinesis::KINESIS_CONNECTOR;
73pub use mqtt::MQTT_CONNECTOR;
74pub use nats::NATS_CONNECTOR;
75use utils::feature_gated_source_mod;
76
77pub use self::adbc_snowflake::ADBC_SNOWFLAKE_CONNECTOR;
78mod common;
79pub mod iceberg;
80mod manager;
81pub mod reader;
82pub mod test_source;
83feature_gated_source_mod!(adbc_snowflake, "adbc_snowflake");
84
85use async_nats::jetstream::consumer::AckPolicy as JetStreamAckPolicy;
86use async_nats::jetstream::context::Context as JetStreamContext;
87pub use manager::{SourceColumnDesc, SourceColumnType};
88use risingwave_common::array::{Array, ArrayRef};
89use risingwave_common::row::OwnedRow;
90use risingwave_pb::id::SourceId;
91use thiserror_ext::AsReport;
92pub use util::fill_adaptive_split;
93
94pub use crate::source::filesystem::LEGACY_S3_CONNECTOR;
95pub use crate::source::filesystem::opendal_source::{
96    AZBLOB_CONNECTOR, BATCH_POSIX_FS_CONNECTOR, GCS_CONNECTOR, OPENDAL_S3_CONNECTOR,
97    POSIX_FS_CONNECTOR,
98};
99pub use crate::source::nexmark::NEXMARK_CONNECTOR;
100pub use crate::source::pulsar::PULSAR_CONNECTOR;
101use crate::source::pulsar::source::reader::PULSAR_ACK_CHANNEL;
102
103pub fn should_copy_to_format_encode_options(key: &str, connector: &str) -> bool {
104    const PREFIXES: &[&str] = &[
105        "schema.registry",
106        "schema.location",
107        "message",
108        "key.message",
109        "without_header",
110        "delimiter",
111        // AwsAuthProps
112        "region",
113        "endpoint_url",
114        "access_key",
115        "secret_key",
116        "session_token",
117        "arn",
118        "external_id",
119        "profile",
120    ];
121    PREFIXES.iter().any(|prefix| key.starts_with(prefix))
122        || (key == "endpoint" && !connector.eq_ignore_ascii_case(KINESIS_CONNECTOR))
123}
124
125/// Tasks executed by `WaitCheckpointWorker`
/// Tasks executed by `WaitCheckpointWorker`
///
/// Each variant carries the payload to confirm back to the upstream system once
/// a checkpoint completes; see `run_with_on_commit_success` for how each is
/// executed.
pub enum WaitCheckpointTask {
    /// Commit the newest CDC offset: `(split id, serialized offset string)`.
    /// `None` means there is no new offset to commit for this epoch.
    CommitCdcOffset(Option<(SplitId, String)>),
    /// Ack Google Pub/Sub messages; the arrays hold utf8 ack ids.
    AckPubsubMessage(Subscription, Vec<ArrayRef>),
    /// Ack NATS JetStream messages via their utf8 reply subjects, honoring the
    /// consumer's ack policy.
    AckNatsJetStream(JetStreamContext, Vec<ArrayRef>, JetStreamAckPolicy),
    /// Cumulatively ack Pulsar messages: `(ack channel id, bytea-encoded message ids)`.
    AckPulsarMessage(Vec<(String, ArrayRef)>),
}
132
impl WaitCheckpointTask {
    /// Create a fresh task for the next epoch, reusing expensive-to-create clients
    /// (e.g. `PubSub` `Subscription`, NATS `JetStreamContext`) from the current task.
    /// This avoids re-establishing gRPC/network connections on every checkpoint.
    pub fn reset_for_next_epoch(&self) -> Self {
        match self {
            // CDC keeps no client handle; the next epoch starts with no pending offset.
            WaitCheckpointTask::CommitCdcOffset(_) => WaitCheckpointTask::CommitCdcOffset(None),
            WaitCheckpointTask::AckPubsubMessage(subscription, _) => {
                WaitCheckpointTask::AckPubsubMessage(subscription.clone(), vec![])
            }
            WaitCheckpointTask::AckNatsJetStream(context, _, ack_policy) => {
                WaitCheckpointTask::AckNatsJetStream(context.clone(), vec![], *ack_policy)
            }
            WaitCheckpointTask::AckPulsarMessage(_) => WaitCheckpointTask::AckPulsarMessage(vec![]),
        }
    }

    /// Execute the task with a no-op commit-success callback.
    ///
    /// Thin wrapper over [`Self::run_with_on_commit_success`]; see that method
    /// for the per-variant behavior.
    pub async fn run(self, source_id: SourceId, source_name: &str) {
        self.run_with_on_commit_success(source_id, source_name, |_source_id, _offset| {
            // Default implementation: no action on commit success
        })
        .await;
    }

    /// Execute the task, calling `on_commit_success(source_id, offset)` after a
    /// successful CDC offset commit (the callback is only reached by the
    /// `CommitCdcOffset` variant).
    ///
    /// This method never propagates errors: every failure path logs via
    /// `tracing::error!` (and, for pubsub / NATS, increments
    /// `connector_ack_failure_count`) and then returns.
    pub async fn run_with_on_commit_success<F>(
        self,
        source_id: SourceId,
        source_name: &str,
        mut on_commit_success: F,
    ) where
        F: FnMut(u64, &str),
    {
        use std::str::FromStr;
        let source_id_label = source_id.to_string();
        match self {
            WaitCheckpointTask::CommitCdcOffset(updated_offset) => {
                if let Some((split_id, offset)) = updated_offset {
                    // The split id string is parsed as a numeric id.
                    // NOTE(review): `unwrap` panics on a non-numeric split id —
                    // presumably guaranteed by CDC split construction; confirm.
                    let committed_source_id: u64 = u64::from_str(split_id.as_ref()).unwrap();
                    // notify cdc connector to commit offset
                    match cdc::jni_source::commit_cdc_offset(committed_source_id, offset.clone()) {
                        Ok(()) => {
                            // Execute callback after successful commit
                            on_commit_success(committed_source_id, &offset);
                        }
                        Err(e) => {
                            tracing::error!(
                                source_id = committed_source_id,
                                source_name,
                                error = %e.as_report(),
                                "source#{committed_source_id}: failed to commit cdc offset: {offset}.",
                            )
                        }
                    }
                }
            }
            WaitCheckpointTask::AckPulsarMessage(ack_array) => {
                // Pulsar acks cumulatively here: only the last entry (and the
                // last non-null message id within it) is sent to the reader.
                if let Some((ack_channel_id, to_cumulative_ack)) = ack_array.last() {
                    // NOTE(review): `unwrap` panics when the bytea array has no
                    // non-null element — presumably never produced empty; verify
                    // against the pulsar reader that fills this array.
                    let encode_message_id_data = to_cumulative_ack
                        .as_bytea()
                        .iter()
                        .last()
                        .flatten()
                        .map(|x| x.to_owned())
                        .unwrap();
                    // Hand the encoded id to the reader via its ack channel; a
                    // missing channel or closed receiver is silently ignored.
                    if let Some(ack_tx) = PULSAR_ACK_CHANNEL.get(ack_channel_id).await {
                        let _ = ack_tx.send(encode_message_id_data);
                    }
                }
            }
            WaitCheckpointTask::AckPubsubMessage(subscription, ack_id_arrs) => {
                const ACK_RPC_TIMEOUT: Duration = Duration::from_secs(30);
                // Ack one batch of ids; on error or timeout, log and bump the
                // failure metric (best-effort: unacked messages are redelivered).
                async fn ack(
                    subscription: &Subscription,
                    ack_ids: Vec<String>,
                    source_id_label: &str,
                    source_name: &str,
                ) {
                    tracing::trace!("acking pubsub messages {:?}", ack_ids);
                    match tokio::time::timeout(ACK_RPC_TIMEOUT, subscription.ack(ack_ids)).await {
                        Ok(Ok(())) => {}
                        Ok(Err(e)) => {
                            crate::source::monitor::GLOBAL_SOURCE_METRICS
                                .connector_ack_failure_count
                                .with_label_values(&[source_name, "pubsub", "error"])
                                .inc();
                            tracing::error!(
                                source_id = source_id_label,
                                source_name,
                                error = %e.as_report(),
                                "failed to ack pubsub messages",
                            )
                        }
                        Err(_) => {
                            crate::source::monitor::GLOBAL_SOURCE_METRICS
                                .connector_ack_failure_count
                                .with_label_values(&[source_name, "pubsub", "timeout"])
                                .inc();
                            tracing::error!(
                                source_id = source_id_label,
                                source_name,
                                "pubsub ack timed out after {ACK_RPC_TIMEOUT:?}",
                            )
                        }
                    }
                }
                // Flush acks in bounded batches to keep each RPC payload small.
                const MAX_ACK_BATCH_SIZE: usize = 1000;
                let mut ack_ids: Vec<String> = vec![];
                for arr in ack_id_arrs {
                    // Null entries in the utf8 array are skipped by `flatten()`.
                    for ack_id in arr.as_utf8().iter().flatten() {
                        ack_ids.push(ack_id.to_owned());
                        if ack_ids.len() >= MAX_ACK_BATCH_SIZE {
                            ack(
                                &subscription,
                                std::mem::take(&mut ack_ids),
                                &source_id_label,
                                source_name,
                            )
                            .await;
                        }
                    }
                }
                // Final flush of the (possibly partial) last batch.
                // NOTE(review): this also runs when `ack_ids` is empty —
                // presumably `subscription.ack` tolerates an empty list; verify.
                ack(&subscription, ack_ids, &source_id_label, source_name).await;
            }
            WaitCheckpointTask::AckNatsJetStream(
                ref context,
                reply_subjects_arrs,
                ref ack_policy,
            ) => {
                const ACK_RPC_TIMEOUT: Duration = Duration::from_secs(30);
                // Ack a single message by publishing "+ACK" to its reply subject
                // and awaiting the publish confirmation; failures are logged and
                // counted, never propagated.
                async fn ack(
                    context: &JetStreamContext,
                    reply_subject: String,
                    source_id_label: &str,
                    source_name: &str,
                ) {
                    let fut = async {
                        let ack_future = context
                            .publish(reply_subject.clone(), "+ACK".into())
                            .await
                            .map_err(|e| e.to_report_string())?;
                        // Await the server's publish acknowledgement as well.
                        ack_future
                            .into_future()
                            .await
                            .map_err(|e| e.to_report_string())?;
                        Ok::<(), String>(())
                    };
                    match tokio::time::timeout(ACK_RPC_TIMEOUT, fut).await {
                        Ok(Ok(())) => {}
                        Ok(Err(e)) => {
                            crate::source::monitor::GLOBAL_SOURCE_METRICS
                                .connector_ack_failure_count
                                .with_label_values(&[source_name, "nats_jetstream", "error"])
                                .inc();
                            tracing::error!(
                                source_id = source_id_label,
                                source_name,
                                error = %e,
                                subject = ?reply_subject,
                                "failed to ack NATS JetStream message",
                            );
                        }
                        Err(_) => {
                            crate::source::monitor::GLOBAL_SOURCE_METRICS
                                .connector_ack_failure_count
                                .with_label_values(&[source_name, "nats_jetstream", "timeout"])
                                .inc();
                            tracing::error!(
                                source_id = source_id_label,
                                source_name,
                                subject = ?reply_subject,
                                "NATS JetStream ack timed out after {ACK_RPC_TIMEOUT:?}",
                            );
                        }
                    }
                }

                // Collect all (non-null) reply subjects across the epoch's arrays.
                let reply_subjects = reply_subjects_arrs
                    .iter()
                    .flat_map(|arr| {
                        arr.as_utf8()
                            .iter()
                            .flatten()
                            .map(|s| s.to_owned())
                            .collect::<Vec<String>>()
                    })
                    .collect::<Vec<String>>();

                match ack_policy {
                    // No ack required by the consumer config.
                    JetStreamAckPolicy::None => (),
                    // Every message must be acked individually.
                    JetStreamAckPolicy::Explicit => {
                        for reply_subject in reply_subjects {
                            // Skip placeholder empty subjects.
                            if reply_subject.is_empty() {
                                continue;
                            }
                            ack(context, reply_subject, &source_id_label, source_name).await;
                        }
                    }
                    // Acking the last message acknowledges all before it, so one
                    // ack on the final subject suffices.
                    JetStreamAckPolicy::All => {
                        if let Some(reply_subject) = reply_subjects.last() {
                            ack(
                                context,
                                reply_subject.clone(),
                                &source_id_label,
                                source_name,
                            )
                            .await;
                        }
                    }
                }
            }
        }
    }
}
346
/// Key-range bounds of one CDC table snapshot split, generic over the bound
/// representation `T` (see the `CdcTableSnapshotSplit`/`CdcTableSnapshotSplitRaw`
/// aliases below).
///
/// The `T: Clone` bound previously on the struct definition is dropped: trait
/// bounds belong on the impls that need them, and `#[derive(Clone)]` already
/// emits its own `T: Clone` bound. Relaxing the bound is backward-compatible
/// for all existing users.
#[derive(Clone, Debug, PartialEq)]
pub struct CdcTableSnapshotSplitCommon<T> {
    // Identifier of this snapshot split.
    pub split_id: i64,
    // Lower bound of the split's key range (inclusive).
    pub left_bound_inclusive: T,
    // Upper bound of the split's key range (exclusive).
    pub right_bound_exclusive: T,
}
353
/// Snapshot split whose bounds are decoded rows.
pub type CdcTableSnapshotSplit = CdcTableSnapshotSplitCommon<OwnedRow>;
/// Snapshot split whose bounds are still raw bytes — presumably the encoded
/// form of the bound rows; confirm against the producer of these values.
pub type CdcTableSnapshotSplitRaw = CdcTableSnapshotSplitCommon<Vec<u8>>;
356
357#[inline]
358pub fn build_pulsar_ack_channel_id(source_id: SourceId, split_id: &SplitId) -> String {
359    format!("{}-{}", source_id, split_id)
360}