// risingwave_connector/source/filesystem/s3/enumerator.rs
use anyhow::Context;
use async_trait::async_trait;
use aws_sdk_s3::client::Client;
use crate::aws_utils::{default_conn_config, s3_client};
use crate::connector_common::AwsAuthProps;
use crate::source::filesystem::file_common::FsSplit;
use crate::source::filesystem::s3::S3Properties;
use crate::source::{FsListInner, SourceEnumeratorContextRef, SplitEnumerator};
/// Returns the literal leading part of `glob`, i.e. everything before the
/// first unescaped glob metacharacter.
///
/// The scan stops (exclusively) at the first unescaped `*`, `?`, `[` or `{`.
/// A backslash escapes the character that follows it: the backslash itself is
/// dropped and the escaped character is copied verbatim, even if it would
/// otherwise be a metacharacter. A trailing lone backslash is dropped.
///
/// `?` must terminate the prefix because `glob::Pattern` treats it as
/// "any single character"; keeping it would produce a prefix that no
/// matching object name actually starts with.
///
/// # Examples
///
/// * `"a/**"` -> `"a/"`
/// * `"a/b?c"` -> `"a/b"`
/// * `r"a/\[ab]"` -> `"a/[ab]"`
pub fn get_prefix(glob: &str) -> String {
    let mut prefix = String::with_capacity(glob.len());
    // True when the previous character was an unescaped backslash.
    let mut escaped = false;

    for c in glob.chars() {
        if escaped {
            // Escaped characters are always literal, metacharacter or not.
            prefix.push(c);
            escaped = false;
            continue;
        }
        match c {
            // First real wildcard: the literal prefix ends here.
            '*' | '?' | '[' | '{' => break,
            // Start of an escape sequence; the backslash itself is not emitted.
            '\\' => escaped = true,
            _ => prefix.push(c),
        }
    }

    prefix
}
/// Split enumerator for an S3 bucket.
///
/// Holds the S3 client together with the optional listing prefix and glob
/// matcher that are both derived from the configured `match_pattern`.
#[derive(Debug, Clone)]
pub struct S3SplitEnumerator {
/// Name of the bucket, taken from the source properties.
pub(crate) bucket_name: String,
/// Literal prefix computed from the match pattern by `get_prefix`;
/// `None` when no match pattern is configured.
pub(crate) prefix: Option<String>,
/// Compiled glob for the match pattern; `None` when no match pattern is
/// configured.
pub(crate) matcher: Option<glob::Pattern>,
/// S3 client built from the AWS auth properties.
pub(crate) client: Client,
/// Starts as `None`.
// NOTE(review): presumably advanced by the paged listing code
// (`get_next_page`), which is not visible here — confirm.
pub(crate) next_continuation_token: Option<String>,
}
#[async_trait]
impl SplitEnumerator for S3SplitEnumerator {
    type Properties = S3Properties;
    type Split = FsSplit;

    /// Creates an enumerator from the S3 source properties.
    ///
    /// Builds the AWS SDK configuration and an S3 client, then, if a
    /// `match_pattern` is configured, compiles it into a glob matcher and
    /// derives its literal prefix.
    ///
    /// # Errors
    ///
    /// Fails when the AWS configuration cannot be built or when
    /// `match_pattern` is not a valid glob.
    async fn new(
        properties: Self::Properties,
        _context: SourceEnumeratorContextRef,
    ) -> crate::error::ConnectorResult<Self> {
        let auth_props = AwsAuthProps::from(&properties);
        let sdk_config = auth_props.build_config().await?;
        let client = s3_client(&sdk_config, Some(default_conn_config()));

        let common = properties.common;

        // Both values exist only when a pattern was supplied.
        let prefix = common
            .match_pattern
            .as_ref()
            .map(|pattern| get_prefix(pattern));
        let matcher = common
            .match_pattern
            .as_ref()
            .map(|pattern| {
                glob::Pattern::new(pattern)
                    .with_context(|| format!("Invalid match_pattern: {}", pattern))
            })
            .transpose()?;

        Ok(S3SplitEnumerator {
            bucket_name: common.bucket_name,
            prefix,
            matcher,
            client,
            next_continuation_token: None,
        })
    }

    /// Collects the splits of every page into a single vector.
    ///
    /// Keeps requesting pages through `get_next_page` until it reports that
    /// the listing has finished; any page error aborts the whole call.
    async fn list_splits(&mut self) -> crate::error::ConnectorResult<Vec<Self::Split>> {
        let mut splits = Vec::new();
        let mut finished = false;
        // `finished` starts false, so at least one page is always requested.
        while !finished {
            let (page, is_last) = self.get_next_page::<FsSplit>().await?;
            splits.extend(page);
            finished = is_last;
        }
        Ok(splits)
    }
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;
    use crate::source::filesystem::file_common::CompressionFormat;
    use crate::source::filesystem::s3::S3PropertiesCommon;
    use crate::source::SourceEnumeratorContext;

    /// Checks prefix extraction for plain, wildcard and escaped patterns.
    #[test]
    fn test_get_prefix() {
        // (glob, expected literal prefix)
        let cases = [
            ("a/", "a/"),
            ("a/**", "a/"),
            ("[ab]*", ""),
            ("a/{a,b}*", "a/"),
            (r"a/\{a,b}", "a/{a,b}"),
            (r"a/\[ab]", "a/[ab]"),
        ];
        for &(glob, expected) in cases.iter() {
            assert_eq!(get_prefix(glob), expected);
        }
    }

    /// Lists a real bucket; ignored by default.
    // NOTE(review): presumably needs AWS credentials and network access — confirm.
    #[tokio::test]
    #[ignore]
    async fn test_s3_split_enumerator() {
        let common = S3PropertiesCommon {
            region_name: "ap-southeast-1".to_owned(),
            bucket_name: "mingchao-s3-source".to_owned(),
            match_pattern: Some("happy[0-9].csv".to_owned()),
            access: None,
            secret: None,
            endpoint_url: None,
            compression_format: CompressionFormat::None,
        };

        let mut enumerator =
            S3SplitEnumerator::new(common.into(), SourceEnumeratorContext::dummy().into())
                .await
                .unwrap();

        let names = enumerator
            .list_splits()
            .await
            .unwrap()
            .into_iter()
            .map(|split| split.name)
            .collect_vec();

        assert_eq!(names.len(), 2);
        assert!(names.contains(&"happy1.csv".to_owned()));
        assert!(names.contains(&"happy2.csv".to_owned()));
    }
}