risingwave_connector/source/kinesis/enumerator/
client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use async_trait::async_trait;
17use aws_sdk_kinesis::Client as kinesis_client;
18use aws_sdk_kinesis::types::Shard;
19use risingwave_common::bail;
20
21use crate::error::ConnectorResult as Result;
22use crate::source::kinesis::split::KinesisOffset;
23use crate::source::kinesis::*;
24use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
25
26pub struct KinesisSplitEnumerator {
27    stream_name: String,
28    client: kinesis_client,
29}
30
31impl KinesisSplitEnumerator {}
32
33#[async_trait]
34impl SplitEnumerator for KinesisSplitEnumerator {
35    type Properties = KinesisProperties;
36    type Split = KinesisSplit;
37
38    async fn new(
39        properties: KinesisProperties,
40        _context: SourceEnumeratorContextRef,
41    ) -> Result<Self> {
42        let client = properties.common.build_client().await?;
43        let stream_name = properties.common.stream_name.clone();
44        Ok(Self {
45            stream_name,
46            client,
47        })
48    }
49
50    async fn list_splits(&mut self) -> Result<Vec<KinesisSplit>> {
51        let mut next_token: Option<String> = None;
52        let mut shard_collect: Vec<Shard> = Vec::new();
53
54        loop {
55            let mut req = self.client.list_shards();
56            if let Some(token) = next_token.take() {
57                req = req.next_token(token);
58            } else {
59                req = req.stream_name(&self.stream_name);
60            }
61
62            let list_shard_output = match req.send().await {
63                Ok(output) => output,
64                Err(e) => {
65                    if let Some(e_inner) = e.as_service_error()
66                        && e_inner.is_expired_next_token_exception()
67                    {
68                        tracing::info!("Kinesis ListShard token expired, retrying...");
69                        next_token = None;
70                        continue;
71                    }
72                    return Err(anyhow!(e).context("failed to list kinesis shards").into());
73                }
74            };
75            match list_shard_output.shards {
76                Some(shard) => shard_collect.extend(shard),
77                None => bail!("no shards in stream {}", &self.stream_name),
78            }
79
80            match list_shard_output.next_token {
81                Some(token) => next_token = Some(token),
82                None => break,
83            }
84        }
85        Ok(shard_collect
86            .into_iter()
87            .map(|x| KinesisSplit {
88                shard_id: x.shard_id().to_owned().into(),
89                // handle start with position in reader part
90                next_offset: KinesisOffset::None,
91                end_offset: KinesisOffset::None,
92            })
93            .collect())
94    }
95}
96
#[cfg(test)]
mod tests {
    use aws_sdk_kinesis::config::Region;

    use super::*;

    /// Manual smoke test against a live stream named `kinesis_debug` in `cn-northwest-1`.
    /// It relies on the ambient AWS configuration loaded by `aws_config::from_env()`, which
    /// is why it is `#[ignore]`d. It expects the stream to expose exactly four shards.
    #[tokio::test]
    #[ignore]
    async fn test_kinesis_split_enumerator() -> Result<()> {
        let sdk_config = aws_config::from_env()
            .region(Region::new("cn-northwest-1"))
            .load()
            .await;

        let mut enumerator = KinesisSplitEnumerator {
            stream_name: "kinesis_debug".to_owned(),
            client: aws_sdk_kinesis::Client::new(&sdk_config),
        };

        let splits = enumerator.list_splits().await?;
        println!("{:#?}", splits);
        assert_eq!(splits.len(), 4);
        Ok(())
    }
}