risingwave_connector/
aws_utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::time::Duration;
17
18use anyhow::{Context, anyhow};
19use aws_config::timeout::TimeoutConfig;
20use aws_sdk_s3::{client as s3_client, config as s3_config};
21use url::Url;
22
23use crate::connector_common::AwsAuthProps;
24use crate::error::ConnectorResult;
25
/// Option names that the `From<HashMap<String, u64>>` conversion into `AwsCustomConfig` reads
/// from its input map.
const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = [
    // Stored as `AwsCustomConfig::retry_times`.
    "retry_times",
    // Connect timeout, given in seconds.
    "conn_timeout",
    // Read timeout, given in seconds.
    "read_timeout",
];
27
/// Returns the default connection options as a name-to-value map.
///
/// The map holds `retry_times = 3`, `conn_timeout = 3` and `read_timeout = 5`.
pub fn default_conn_config() -> HashMap<String, u64> {
    HashMap::from([
        ("retry_times".to_owned(), 3_u64),
        ("conn_timeout".to_owned(), 3_u64),
        ("read_timeout".to_owned(), 5_u64),
    ])
}
35
/// Retry and timeout settings applied when building an S3 client with custom options.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsCustomConfig {
    /// Read timeout passed to the client's timeout configuration.
    pub read_timeout: Duration,
    /// Connect timeout passed to the client's timeout configuration.
    pub conn_timeout: Duration,
    /// Value passed as the maximum number of attempts of the client's retry configuration.
    pub retry_times: u32,
}
42
43impl Default for AwsCustomConfig {
44    fn default() -> Self {
45        let map = default_conn_config();
46        AwsCustomConfig::from(map)
47    }
48}
49
50impl From<HashMap<String, u64>> for AwsCustomConfig {
51    fn from(input_config: HashMap<String, u64>) -> Self {
52        let mut config = AwsCustomConfig {
53            read_timeout: Duration::from_secs(3),
54            conn_timeout: Duration::from_secs(3),
55            retry_times: 0,
56        };
57        for key in AWS_CUSTOM_CONFIG_KEY {
58            let value = input_config.get(key);
59            if let Some(config_value) = value {
60                match key {
61                    "retry_times" => {
62                        config.retry_times = *config_value as u32;
63                    }
64                    "conn_timeout" => {
65                        config.conn_timeout = Duration::from_secs(*config_value);
66                    }
67                    "read_timeout" => {
68                        config.read_timeout = Duration::from_secs(*config_value);
69                    }
70                    _ => {
71                        unreachable!()
72                    }
73                }
74            } else {
75                continue;
76            }
77        }
78        config
79    }
80}
81
82pub fn s3_client(
83    sdk_config: &aws_types::SdkConfig,
84    config_pairs: Option<HashMap<String, u64>>,
85) -> aws_sdk_s3::Client {
86    let s3_config_obj = if let Some(config) = config_pairs {
87        let s3_config = AwsCustomConfig::from(config);
88        let retry_conf =
89            aws_config::retry::RetryConfig::standard().with_max_attempts(s3_config.retry_times);
90        let timeout_conf = TimeoutConfig::builder()
91            .connect_timeout(s3_config.conn_timeout)
92            .read_timeout(s3_config.read_timeout)
93            .build();
94
95        s3_config::Builder::from(&sdk_config.clone())
96            .retry_config(retry_conf)
97            .timeout_config(timeout_conf)
98            .force_path_style(true)
99            .build()
100    } else {
101        s3_config::Config::new(sdk_config)
102    };
103    s3_client::Client::from_conf(s3_config_obj)
104}
105
// TODO(Tao): We should probably never allow the use of S3 URIs.
107pub async fn load_file_descriptor_from_s3(
108    location: &Url,
109    config: &AwsAuthProps,
110) -> ConnectorResult<Vec<u8>> {
111    let bucket = location
112        .domain()
113        .with_context(|| format!("illegal file path {}", location))?;
114    let key = location
115        .path()
116        .strip_prefix('/')
117        .ok_or_else(|| anyhow!("s3 url {location} should have a '/' at the start of path."))?;
118    let sdk_config = config.build_config().await?;
119    let s3_client = s3_client(&sdk_config, Some(default_conn_config()));
120    let response = s3_client
121        .get_object()
122        .bucket(bucket.to_owned())
123        .key(key)
124        .send()
125        .await
126        .with_context(|| format!("failed to get file from s3 at `{}`", location))?;
127
128    let body = response
129        .body
130        .collect()
131        .await
132        .with_context(|| format!("failed to read file from s3 at `{}`", location))?;
133    Ok(body.into_bytes().to_vec())
134}