risingwave_connector/source/kafka/enumerator/client.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, LazyLock, Weak};
use std::time::Duration;

use anyhow::{anyhow, Context};
use async_trait::async_trait;
use moka::future::Cache as MokaCache;
use moka::ops::compute::Op;
use prometheus::core::{AtomicI64, GenericGauge};
use rdkafka::consumer::{BaseConsumer, Consumer};
use rdkafka::error::KafkaResult;
use rdkafka::{ClientConfig, Offset, TopicPartitionList};
use risingwave_common::bail;
use risingwave_common::metrics::LabelGuardedMetric;

use crate::error::{ConnectorError, ConnectorResult};
use crate::source::base::SplitEnumerator;
use crate::source::kafka::split::KafkaSplit;
use crate::source::kafka::{
    KafkaConnection, KafkaContextCommon, KafkaProperties, RwConsumerContext, KAFKA_ISOLATION_LEVEL,
};
use crate::source::SourceEnumeratorContextRef;

type KafkaClientType = BaseConsumer<RwConsumerContext>;

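/// Enumerator clients shared across sources, keyed by their [`KafkaConnection`]
/// properties. Values are `Weak` handles, so a client is dropped once no
/// enumerator holds a strong `Arc` to it and is rebuilt on the next request.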
pub static SHARED_KAFKA_CLIENT: LazyLock<MokaCache<KafkaConnection, Weak<KafkaClientType>>> =
    LazyLock::new(|| moka::future::Cache::builder().build());

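/// Start/stop position for the enumerator, resolved from `scan_startup_mode`
/// and `time_offset` in the source properties.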
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum KafkaEnumeratorOffset {
    Earliest,
    Latest,
    Timestamp(i64),
    None,
}

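/// Split enumerator for Kafka: produces one [`KafkaSplit`] per partition of the
/// configured topic.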
pub struct KafkaSplitEnumerator {
    context: SourceEnumeratorContextRef,
    broker_address: String,
    topic: String,
    client: Arc<KafkaClientType>,
    start_offset: KafkaEnumeratorOffset,

    // may be used in the future for batch processing
    stop_offset: KafkaEnumeratorOffset,

    sync_call_timeout: Duration,
    high_watermark_metrics: HashMap<i32, LabelGuardedMetric<GenericGauge<AtomicI64>, 2>>,
}

impl KafkaSplitEnumerator {}

#[async_trait]
impl SplitEnumerator for KafkaSplitEnumerator {
    type Properties = KafkaProperties;
    type Split = KafkaSplit;

    async fn new(
        properties: KafkaProperties,
        context: SourceEnumeratorContextRef,
    ) -> ConnectorResult<KafkaSplitEnumerator> {
        let mut config = rdkafka::ClientConfig::new();
        let common_props = &properties.common;

        let broker_address = properties.connection.brokers.clone();
        let broker_rewrite_map = properties.privatelink_common.broker_rewrite_map.clone();
        let topic = common_props.topic.clone();
        config.set("bootstrap.servers", &broker_address);
        config.set("isolation.level", KAFKA_ISOLATION_LEVEL);
        properties.connection.set_security_properties(&mut config);
        properties.set_client(&mut config);
        let mut scan_start_offset = match properties
            .scan_startup_mode
            .as_ref()
            .map(|s| s.to_lowercase())
            .as_deref()
        {
            Some("earliest") => KafkaEnumeratorOffset::Earliest,
            Some("latest") => KafkaEnumeratorOffset::Latest,
            None => KafkaEnumeratorOffset::Earliest,
            _ => bail!(
                "property `scan_startup_mode` only supports `earliest` and `latest`, or leaving it empty"
            ),
        };

        if let Some(s) = &properties.time_offset {
            let time_offset = s.parse::<i64>().map_err(|e| anyhow!(e))?;
            scan_start_offset = KafkaEnumeratorOffset::Timestamp(time_offset)
        }

        async fn build_kafka_client(
            config: &ClientConfig,
            properties: &KafkaProperties,
            rewrite_map: Option<BTreeMap<String, String>>,
        ) -> ConnectorResult<KafkaClientType> {
            let ctx_common = KafkaContextCommon::new(
                rewrite_map,
                None,
                None,
                properties.aws_auth_props.clone(),
                properties.connection.is_aws_msk_iam(),
            )
            .await?;
            let client_ctx = RwConsumerContext::new(ctx_common);
            let client: BaseConsumer<RwConsumerContext> =
                config.create_with_context(client_ctx).await?;

            // Note that before any SASL/OAUTHBEARER broker connection can succeed the application must call
            // rd_kafka_oauthbearer_set_token() once – either directly or, more typically, by invoking either
            // rd_kafka_poll(), rd_kafka_consumer_poll(), rd_kafka_queue_poll(), etc, in order to cause retrieval
            // of an initial token to occur.
            // https://docs.confluent.io/platform/current/clients/librdkafka/html/rdkafka_8h.html#a988395722598f63396d7a1bedb22adaf
            if properties.connection.is_aws_msk_iam() {
                #[cfg(not(madsim))]
                client.poll(Duration::from_secs(10)); // note: this is a blocking call
                #[cfg(madsim)]
                client.poll(Duration::from_secs(10)).await;
            }
            Ok(client)
        }

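        // Get or create the shared client for this connection: reuse the cached
        // client if its `Weak` handle is still alive, otherwise build a new one
        // and put a downgraded reference back into the cache.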
        let mut client_arc: Option<Arc<KafkaClientType>> = None;
        SHARED_KAFKA_CLIENT
            .entry_by_ref(&properties.connection)
            .and_try_compute_with::<_, _, ConnectorError>(|maybe_entry| async {
                if let Some(entry) = maybe_entry {
                    let entry_value = entry.into_value();
                    if let Some(client) = entry_value.upgrade() {
                        // a cached client is still alive; reuse it
                        tracing::info!("reuse existing kafka client for {}", broker_address);
                        client_arc = Some(client);
                        return Ok(Op::Nop);
                    }
                }
                let new_client_arc = Arc::new(
                    build_kafka_client(&config, &properties, broker_rewrite_map.clone()).await?,
                );
                tracing::info!("build new kafka client for {}", broker_address);
                client_arc = Some(new_client_arc.clone());
                Ok(Op::Put(Arc::downgrade(&new_client_arc)))
            })
            .await?;

        Ok(Self {
            context,
            broker_address,
            topic,
            client: client_arc.unwrap(),
            start_offset: scan_start_offset,
            stop_offset: KafkaEnumeratorOffset::None,
            sync_call_timeout: properties.common.sync_call_timeout,
            high_watermark_metrics: HashMap::new(),
        })
    }

    async fn list_splits(&mut self) -> ConnectorResult<Vec<KafkaSplit>> {
        let topic_partitions = self.fetch_topic_partition().await.with_context(|| {
            format!(
                "failed to fetch metadata from kafka ({})",
                self.broker_address
            )
        })?;

        let watermarks = self.get_watermarks(topic_partitions.as_ref()).await?;
        let mut start_offsets = self
            .fetch_start_offset(topic_partitions.as_ref(), &watermarks)
            .await?;

        let mut stop_offsets = self
            .fetch_stop_offset(topic_partitions.as_ref(), &watermarks)
            .await?;

        let ret: Vec<_> = topic_partitions
            .into_iter()
            .map(|partition| KafkaSplit {
                topic: self.topic.clone(),
                partition,
                start_offset: start_offsets.remove(&partition).unwrap(),
                stop_offset: stop_offsets.remove(&partition).unwrap(),
            })
            .collect();

        Ok(ret)
    }
}

impl KafkaSplitEnumerator {
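    /// Fetches the low/high watermarks of every given partition and reports the
    /// high watermark to the per-partition gauge.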
    async fn get_watermarks(
        &mut self,
        partitions: &[i32],
    ) -> KafkaResult<HashMap<i32, (i64, i64)>> {
        let mut map = HashMap::new();
        for partition in partitions {
            let (low, high) = self
                .client
                .fetch_watermarks(self.topic.as_str(), *partition, self.sync_call_timeout)
                .await?;
            self.report_high_watermark(*partition, high);
            map.insert(*partition, (low, high));
        }
        tracing::debug!("fetch kafka watermarks: {map:?}");
        Ok(map)
    }

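    /// Lists splits for batch ingestion. Optional start/stop timestamps (in
    /// milliseconds) are resolved to offsets and clamped to each partition's
    /// low/high watermarks; without timestamps the watermarks themselves are used.
    ///
    /// A minimal usage sketch (`enumerator`, `start_ms` and `stop_ms` are
    /// placeholders, not part of this module):
    ///
    /// ```ignore
    /// let splits = enumerator
    ///     .list_splits_batch(Some(start_ms), Some(stop_ms))
    ///     .await?;
    /// ```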
    pub async fn list_splits_batch(
        &mut self,
        expect_start_timestamp_millis: Option<i64>,
        expect_stop_timestamp_millis: Option<i64>,
    ) -> ConnectorResult<Vec<KafkaSplit>> {
        let topic_partitions = self.fetch_topic_partition().await.with_context(|| {
            format!(
                "failed to fetch metadata from kafka ({})",
                self.broker_address
            )
        })?;

        // Compute the start and end offsets for each partition. If a timestamp is
        // provided, resolve it to an offset and use the low/high watermarks to
        // narrow down the range; otherwise use the watermarks themselves as the
        // start and end offsets.
        let mut expect_start_offset = if let Some(ts) = expect_start_timestamp_millis {
            Some(
                self.fetch_offset_for_time(topic_partitions.as_ref(), ts)
                    .await?,
            )
        } else {
            None
        };

        let mut expect_stop_offset = if let Some(ts) = expect_stop_timestamp_millis {
            Some(
                self.fetch_offset_for_time(topic_partitions.as_ref(), ts)
                    .await?,
            )
        } else {
            None
        };

        // The watermarks here have nothing to do with watermarks in stream processing;
        // here a watermark means the smallest/largest offset available for reading.
        let mut watermarks = {
            let mut ret = HashMap::new();
            for partition in &topic_partitions {
                let (low, high) = self
                    .client
                    .fetch_watermarks(self.topic.as_str(), *partition, self.sync_call_timeout)
                    .await?;
                ret.insert(partition, (low - 1, high));
            }
            ret
        };

        Ok(topic_partitions
            .iter()
            .map(|partition| {
                let (low, high) = watermarks.remove(&partition).unwrap();
                let start_offset = {
                    let start = expect_start_offset
                        .as_mut()
                        .map(|m| m.remove(partition).flatten().map(|t| t - 1).unwrap_or(low))
                        .unwrap_or(low);
                    i64::max(start, low)
                };
                let stop_offset = {
                    let stop = expect_stop_offset
                        .as_mut()
                        .map(|m| m.remove(partition).unwrap_or(Some(high)))
                        .unwrap_or(Some(high))
                        .unwrap_or(high);
                    i64::min(stop, high)
                };

                if start_offset > stop_offset {
                    tracing::warn!(
                        "Skipping topic {} partition {}: requested start offset {} is greater than stop offset {}",
                        self.topic,
                        partition,
                        start_offset,
                        stop_offset
                    );
                }
                KafkaSplit {
                    topic: self.topic.clone(),
                    partition: *partition,
                    start_offset: Some(start_offset),
                    stop_offset: Some(stop_offset),
                }
            })
            .collect::<Vec<KafkaSplit>>())
    }

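    /// Resolves the stop offset of each partition from `self.stop_offset`: the
    /// high watermark for `Latest`, a timestamp lookup for `Timestamp(_)`, and
    /// `None` when no stop offset is configured.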
    async fn fetch_stop_offset(
        &self,
        partitions: &[i32],
        watermarks: &HashMap<i32, (i64, i64)>,
    ) -> KafkaResult<HashMap<i32, Option<i64>>> {
        match self.stop_offset {
            KafkaEnumeratorOffset::Earliest => unreachable!(),
            KafkaEnumeratorOffset::Latest => {
                let mut map = HashMap::new();
                for partition in partitions {
                    let (_, high_watermark) = watermarks.get(partition).unwrap();
                    map.insert(*partition, Some(*high_watermark));
                }
                Ok(map)
            }
            KafkaEnumeratorOffset::Timestamp(time) => {
                self.fetch_offset_for_time(partitions, time).await
            }
            KafkaEnumeratorOffset::None => partitions
                .iter()
                .map(|partition| Ok((*partition, None)))
                .collect(),
        }
    }

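    /// Resolves the start offset of each partition from `self.start_offset`.
    /// Offsets follow the "last consumed offset" convention, hence the `- 1`
    /// applied to the watermark-based cases.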
    async fn fetch_start_offset(
        &self,
        partitions: &[i32],
        watermarks: &HashMap<i32, (i64, i64)>,
    ) -> KafkaResult<HashMap<i32, Option<i64>>> {
        match self.start_offset {
            KafkaEnumeratorOffset::Earliest | KafkaEnumeratorOffset::Latest => {
                let mut map = HashMap::new();
                for partition in partitions {
                    let (low_watermark, high_watermark) = watermarks.get(partition).unwrap();
                    let offset = match self.start_offset {
                        KafkaEnumeratorOffset::Earliest => low_watermark - 1,
                        KafkaEnumeratorOffset::Latest => high_watermark - 1,
                        _ => unreachable!(),
                    };
                    map.insert(*partition, Some(offset));
                }
                Ok(map)
            }
            KafkaEnumeratorOffset::Timestamp(time) => {
                self.fetch_offset_for_time(partitions, time).await
            }
            KafkaEnumeratorOffset::None => partitions
                .iter()
                .map(|partition| Ok((*partition, None)))
                .collect(),
        }
    }

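    /// Looks up, for each partition, the offset of the first message with a
    /// timestamp at or after `time` (milliseconds), falling back to the high
    /// watermark when the broker returns no offset. Broker-returned offsets are
    /// converted to the "last consumed offset" convention (`offset - 1`).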
    async fn fetch_offset_for_time(
        &self,
        partitions: &[i32],
        time: i64,
    ) -> KafkaResult<HashMap<i32, Option<i64>>> {
        let mut tpl = TopicPartitionList::new();

        for partition in partitions {
            tpl.add_partition_offset(self.topic.as_str(), *partition, Offset::Offset(time))?;
        }

        let offsets = self
            .client
            .offsets_for_times(tpl, self.sync_call_timeout)
            .await?;

        let mut result = HashMap::with_capacity(partitions.len());

        for elem in offsets.elements_for_topic(self.topic.as_str()) {
            match elem.offset() {
                Offset::Offset(offset) => {
                    // XXX(rc): currently in RW source, `offset` means the last consumed offset, so we need to subtract 1
                    result.insert(elem.partition(), Some(offset - 1));
                }
                _ => {
                    let (_, high_watermark) = self
                        .client
                        .fetch_watermarks(
                            self.topic.as_str(),
                            elem.partition(),
                            self.sync_call_timeout,
                        )
                        .await?;
                    result.insert(elem.partition(), Some(high_watermark));
                }
            }
        }

        Ok(result)
    }

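    /// Updates the per-partition high watermark gauge, creating the labeled
    /// metric on first use.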
    #[inline]
    fn report_high_watermark(&mut self, partition: i32, offset: i64) {
        let high_watermark_metrics =
            self.high_watermark_metrics
                .entry(partition)
                .or_insert_with(|| {
                    self.context
                        .metrics
                        .high_watermark
                        .with_guarded_label_values(&[
                            &self.context.info.source_id.to_string(),
                            &partition.to_string(),
                        ])
                });
        high_watermark_metrics.set(offset);
    }

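    /// Returns whether topic metadata can be fetched within the sync call timeout.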
    pub async fn check_reachability(&self) -> bool {
        self.client
            .fetch_metadata(Some(self.topic.as_str()), self.sync_call_timeout)
            .await
            .is_ok()
    }

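    /// Fetches the partition ids of the configured topic, failing if the topic
    /// does not exist or reports no partitions.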
    async fn fetch_topic_partition(&self) -> ConnectorResult<Vec<i32>> {
        // for now, we only support one topic
        let metadata = self
            .client
            .fetch_metadata(Some(self.topic.as_str()), self.sync_call_timeout)
            .await?;

        let topic_meta = match metadata.topics() {
            [meta] => meta,
            _ => bail!("topic {} not found", self.topic),
        };

        if topic_meta.partitions().is_empty() {
            bail!("topic {} not found", self.topic);
        }

        Ok(topic_meta
            .partitions()
            .iter()
            .map(|partition| partition.id())
            .collect())
    }
}