risingwave_connector/
aws_utils.rs1use std::collections::HashMap;
16use std::time::Duration;
17
18use anyhow::{Context, anyhow};
19use aws_config::timeout::TimeoutConfig;
20use aws_sdk_s3::{client as s3_client, config as s3_config};
21use url::Url;
22
23use crate::connector_common::AwsAuthProps;
24use crate::error::ConnectorResult;
25
/// Keys that are read from a `HashMap<String, u64>` settings map when it is converted into an
/// `AwsCustomConfig`; entries stored under any other key are not looked at by that conversion.
const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = ["retry_times", "conn_timeout", "read_timeout"];
27
/// Returns the stock connection settings as a key/value map.
///
/// The map always holds exactly three entries: `retry_times = 3`, `conn_timeout = 3` and
/// `read_timeout = 5`. The conversion into `AwsCustomConfig` interprets the two timeouts as
/// whole seconds.
pub fn default_conn_config() -> HashMap<String, u64> {
    [
        ("retry_times", 3_u64),
        ("conn_timeout", 3_u64),
        ("read_timeout", 5_u64),
    ]
    .into_iter()
    .map(|(name, value)| (name.to_owned(), value))
    .collect()
}
35
/// Retry and timeout settings that `s3_client` applies when it builds a customised S3 client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsCustomConfig {
    /// Passed to the SDK timeout builder as its read timeout.
    pub read_timeout: Duration,
    /// Passed to the SDK timeout builder as its connect timeout.
    pub conn_timeout: Duration,
    /// Passed to the SDK retry configuration as its maximum number of attempts.
    pub retry_times: u32,
}
42
43impl Default for AwsCustomConfig {
44 fn default() -> Self {
45 let map = default_conn_config();
46 AwsCustomConfig::from(map)
47 }
48}
49
50impl From<HashMap<String, u64>> for AwsCustomConfig {
51 fn from(input_config: HashMap<String, u64>) -> Self {
52 let mut config = AwsCustomConfig {
53 read_timeout: Duration::from_secs(3),
54 conn_timeout: Duration::from_secs(3),
55 retry_times: 0,
56 };
57 for key in AWS_CUSTOM_CONFIG_KEY {
58 let value = input_config.get(key);
59 if let Some(config_value) = value {
60 match key {
61 "retry_times" => {
62 config.retry_times = *config_value as u32;
63 }
64 "conn_timeout" => {
65 config.conn_timeout = Duration::from_secs(*config_value);
66 }
67 "read_timeout" => {
68 config.read_timeout = Duration::from_secs(*config_value);
69 }
70 _ => {
71 unreachable!()
72 }
73 }
74 } else {
75 continue;
76 }
77 }
78 config
79 }
80}
81
82pub fn s3_client(
83 sdk_config: &aws_types::SdkConfig,
84 config_pairs: Option<HashMap<String, u64>>,
85) -> aws_sdk_s3::Client {
86 let s3_config_obj = if let Some(config) = config_pairs {
87 let s3_config = AwsCustomConfig::from(config);
88 let retry_conf =
89 aws_config::retry::RetryConfig::standard().with_max_attempts(s3_config.retry_times);
90 let timeout_conf = TimeoutConfig::builder()
91 .connect_timeout(s3_config.conn_timeout)
92 .read_timeout(s3_config.read_timeout)
93 .build();
94
95 s3_config::Builder::from(&sdk_config.clone())
96 .retry_config(retry_conf)
97 .timeout_config(timeout_conf)
98 .force_path_style(true)
99 .build()
100 } else {
101 s3_config::Config::new(sdk_config)
102 };
103 s3_client::Client::from_conf(s3_config_obj)
104}
105
106pub async fn load_file_descriptor_from_s3(
108 location: &Url,
109 config: &AwsAuthProps,
110) -> ConnectorResult<Vec<u8>> {
111 let bucket = location
112 .domain()
113 .with_context(|| format!("illegal file path {}", location))?;
114 let key = location
115 .path()
116 .strip_prefix('/')
117 .ok_or_else(|| anyhow!("s3 url {location} should have a '/' at the start of path."))?;
118 let sdk_config = config.build_config().await?;
119 let s3_client = s3_client(&sdk_config, Some(default_conn_config()));
120 let response = s3_client
121 .get_object()
122 .bucket(bucket.to_owned())
123 .key(key)
124 .send()
125 .await
126 .with_context(|| format!("failed to get file from s3 at `{}`", location))?;
127
128 let body = response
129 .body
130 .collect()
131 .await
132 .with_context(|| format!("failed to read file from s3 at `{}`", location))?;
133 Ok(body.into_bytes().to_vec())
134}