risingwave_frontend/optimizer/rule/
table_function_to_file_scan_rule.rsuse itertools::Itertools;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_connector::source::iceberg::{extract_bucket_and_file_name, FileScanBackend};
use super::{BoxedRule, Rule};
use crate::expr::{Expr, TableFunctionType};
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::{LogicalFileScan, LogicalTableFunction};
use crate::optimizer::PlanRef;
pub struct TableFunctionToFileScanRule {}
impl Rule for TableFunctionToFileScanRule {
fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
if logical_table_function.table_function.function_type != TableFunctionType::FileScan {
return None;
}
assert!(!logical_table_function.with_ordinality);
let table_function_return_type = logical_table_function.table_function().return_type();
if let DataType::Struct(st) = table_function_return_type.clone() {
let fields = st
.iter()
.map(|(name, data_type)| Field::with_name(data_type.clone(), name.to_owned()))
.collect_vec();
let schema = Schema::new(fields);
let mut eval_args = vec![];
for arg in &logical_table_function.table_function().args {
assert_eq!(arg.return_type(), DataType::Varchar);
let value = arg.try_fold_const().unwrap().unwrap();
match value {
Some(ScalarImpl::Utf8(s)) => {
eval_args.push(s.to_string());
}
_ => {
unreachable!("must be a varchar")
}
}
}
assert!("parquet".eq_ignore_ascii_case(&eval_args[0]));
assert!(
("s3".eq_ignore_ascii_case(&eval_args[1]))
|| "gcs".eq_ignore_ascii_case(&eval_args[1])
|| "azblob".eq_ignore_ascii_case(&eval_args[1])
);
if "s3".eq_ignore_ascii_case(&eval_args[1]) {
let s3_access_key = eval_args[3].clone();
let s3_secret_key = eval_args[4].clone();
let file_location = eval_args[5..].iter().cloned().collect_vec();
let (bucket, _) =
extract_bucket_and_file_name(&file_location[0], &FileScanBackend::S3).ok()?;
let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
true => ("us-east-1".to_owned(), eval_args[2].clone()), false => (
eval_args[2].clone(),
format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
),
};
Some(
LogicalFileScan::new_s3_logical_file_scan(
logical_table_function.ctx(),
schema,
"parquet".to_owned(),
"s3".to_owned(),
s3_region,
s3_access_key,
s3_secret_key,
file_location,
s3_endpoint,
)
.into(),
)
} else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
let credential = eval_args[2].clone();
let file_location = eval_args[3..].iter().cloned().collect_vec();
Some(
LogicalFileScan::new_gcs_logical_file_scan(
logical_table_function.ctx(),
schema,
"parquet".to_owned(),
"gcs".to_owned(),
credential,
file_location,
)
.into(),
)
} else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
let endpoint = eval_args[2].clone();
let account_name = eval_args[3].clone();
let account_key = eval_args[4].clone();
let file_location = eval_args[5..].iter().cloned().collect_vec();
Some(
LogicalFileScan::new_azblob_logical_file_scan(
logical_table_function.ctx(),
schema,
"parquet".to_owned(),
"azblob".to_owned(),
account_name,
account_key,
endpoint,
file_location,
)
.into(),
)
} else {
unreachable!()
}
} else {
unreachable!("TableFunction return type should be struct")
}
}
}
impl TableFunctionToFileScanRule {
pub fn create() -> BoxedRule {
Box::new(TableFunctionToFileScanRule {})
}
}