// risingwave_frontend/optimizer/rule/table_function_to_file_scan_rule.rs
use itertools::Itertools;
16use risingwave_common::catalog::{Field, Schema};
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_connector::source::iceberg::{FileScanBackend, extract_bucket_and_file_name};
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, TableFunctionType};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{LogicalFileScan, LogicalTableFunction};
25
/// Optimizer rule that rewrites a `LogicalTableFunction` whose function type is
/// `TableFunctionType::FileScan` into a `LogicalFileScan` plan node.
///
/// The rule carries no state.
pub struct TableFunctionToFileScanRule {}
28impl Rule for TableFunctionToFileScanRule {
29 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
30 let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
31 if logical_table_function.table_function.function_type != TableFunctionType::FileScan {
32 return None;
33 }
34 assert!(!logical_table_function.with_ordinality);
35 let table_function_return_type = logical_table_function.table_function().return_type();
36
37 if let DataType::Struct(st) = table_function_return_type.clone() {
38 let fields = st
39 .iter()
40 .map(|(name, data_type)| Field::with_name(data_type.clone(), name.to_owned()))
41 .collect_vec();
42
43 let schema = Schema::new(fields);
44
45 let mut eval_args = vec![];
46 for arg in &logical_table_function.table_function().args {
47 assert_eq!(arg.return_type(), DataType::Varchar);
48 let value = arg.try_fold_const().unwrap().unwrap();
49 match value {
50 Some(ScalarImpl::Utf8(s)) => {
51 eval_args.push(s.to_string());
52 }
53 _ => {
54 unreachable!("must be a varchar")
55 }
56 }
57 }
58 assert!("parquet".eq_ignore_ascii_case(&eval_args[0]));
59 assert!(
60 ("s3".eq_ignore_ascii_case(&eval_args[1]))
61 || "gcs".eq_ignore_ascii_case(&eval_args[1])
62 || "azblob".eq_ignore_ascii_case(&eval_args[1])
63 );
64
65 if "s3".eq_ignore_ascii_case(&eval_args[1]) {
66 let s3_access_key = eval_args[3].clone();
67 let s3_secret_key = eval_args[4].clone();
68 let file_location = eval_args[5..].iter().cloned().collect_vec();
69
70 let (bucket, _) =
71 extract_bucket_and_file_name(&file_location[0], &FileScanBackend::S3).ok()?;
72 let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
73 true => ("us-east-1".to_owned(), eval_args[2].clone()), false => (
75 eval_args[2].clone(),
76 format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
77 ),
78 };
79 Some(
80 LogicalFileScan::new_s3_logical_file_scan(
81 logical_table_function.ctx(),
82 schema,
83 "parquet".to_owned(),
84 "s3".to_owned(),
85 s3_region,
86 s3_access_key,
87 s3_secret_key,
88 file_location,
89 s3_endpoint,
90 )
91 .into(),
92 )
93 } else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
94 let credential = eval_args[2].clone();
95 let file_location = eval_args[3..].iter().cloned().collect_vec();
97 Some(
98 LogicalFileScan::new_gcs_logical_file_scan(
99 logical_table_function.ctx(),
100 schema,
101 "parquet".to_owned(),
102 "gcs".to_owned(),
103 credential,
104 file_location,
105 )
106 .into(),
107 )
108 } else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
109 let endpoint = eval_args[2].clone();
110 let account_name = eval_args[3].clone();
111 let account_key = eval_args[4].clone();
112 let file_location = eval_args[5..].iter().cloned().collect_vec();
114 Some(
115 LogicalFileScan::new_azblob_logical_file_scan(
116 logical_table_function.ctx(),
117 schema,
118 "parquet".to_owned(),
119 "azblob".to_owned(),
120 account_name,
121 account_key,
122 endpoint,
123 file_location,
124 )
125 .into(),
126 )
127 } else {
128 unreachable!()
129 }
130 } else {
131 unreachable!("TableFunction return type should be struct")
132 }
133 }
134}
135
136impl TableFunctionToFileScanRule {
137 pub fn create() -> BoxedRule {
138 Box::new(TableFunctionToFileScanRule {})
139 }
140}