risingwave_frontend/optimizer/rule/table_function_to_file_scan_rule.rs

use itertools::Itertools;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_connector::source::iceberg::{FileScanBackend, extract_bucket_and_file_name};

use super::prelude::{PlanRef, *};
use crate::expr::{Expr, TableFunctionType};
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::{LogicalFileScan, LogicalTableFunction};

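/// Rewrites a `file_scan` table function into a `LogicalFileScan` plan node so Parquet files on
/// S3, GCS, or Azure Blob can be scanned directly.
///
/// The argument layout below is inferred from how `eval_args` is consumed in `apply`; treat it
/// as a sketch rather than authoritative documentation:
///
/// - s3:     ('parquet', 's3', region_or_endpoint, access_key, secret_key, file_location, ...)
/// - gcs:    ('parquet', 'gcs', credential, file_location, ...)
/// - azblob: ('parquet', 'azblob', endpoint, account_name, account_key, file_location, ...)
///
/// Illustrative SQL (bucket, path, and credentials are placeholders):
///
/// ```sql
/// SELECT * FROM file_scan(
///     'parquet', 's3', 'us-east-1', 'my_access_key', 'my_secret_key',
///     's3://my-bucket/path/to/file.parquet'
/// );
/// ```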
pub struct TableFunctionToFileScanRule {}

impl Rule<Logical> for TableFunctionToFileScanRule {
    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
        let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
        if logical_table_function.table_function.function_type != TableFunctionType::FileScan {
            return None;
        }
        assert!(!logical_table_function.with_ordinality);
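        // The `file_scan` table function returns a struct; each struct field becomes a column of
        // the resulting file scan's schema.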
        let table_function_return_type = logical_table_function.table_function().return_type();

        if let DataType::Struct(st) = table_function_return_type.clone() {
            let fields = st
                .iter()
                .map(|(name, data_type)| Field::with_name(data_type.clone(), name.to_owned()))
                .collect_vec();

            let schema = Schema::new(fields);

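            // Every argument to `file_scan` must be a constant varchar expression; fold each one
            // to a plain string before inspecting it.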
            let mut eval_args = vec![];
            for arg in &logical_table_function.table_function().args {
                assert_eq!(arg.return_type(), DataType::Varchar);
                let value = arg.try_fold_const().unwrap().unwrap();
                match value {
                    Some(ScalarImpl::Utf8(s)) => {
                        eval_args.push(s.to_string());
                    }
                    _ => {
                        unreachable!("must be a varchar")
                    }
                }
            }
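            // The first argument is the file format and the second the storage backend; only
            // Parquet on s3 / gcs / azblob is supported here.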
57 assert!("parquet".eq_ignore_ascii_case(&eval_args[0]));
58 assert!(
59 ("s3".eq_ignore_ascii_case(&eval_args[1]))
60 || "gcs".eq_ignore_ascii_case(&eval_args[1])
61 || "azblob".eq_ignore_ascii_case(&eval_args[1])
62 );

            if "s3".eq_ignore_ascii_case(&eval_args[1]) {
                let s3_access_key = eval_args[3].clone();
                let s3_secret_key = eval_args[4].clone();
                let file_location = eval_args[5..].iter().cloned().collect_vec();

                let (bucket, _) =
                    extract_bucket_and_file_name(&file_location[0], &FileScanBackend::S3).ok()?;
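                // If the third argument looks like a URL, treat it as a custom endpoint and fall
                // back to a default region; otherwise treat it as the region and derive the
                // standard AWS endpoint for the bucket.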
                let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
                    true => ("us-east-1".to_owned(), eval_args[2].clone()),
                    false => (
                        eval_args[2].clone(),
                        format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
                    ),
                };
                Some(
                    LogicalFileScan::new_s3_logical_file_scan(
                        logical_table_function.ctx(),
                        schema,
                        "parquet".to_owned(),
                        "s3".to_owned(),
                        s3_region,
                        s3_access_key,
                        s3_secret_key,
                        file_location,
                        s3_endpoint,
                    )
                    .into(),
                )
92 } else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
                let credential = eval_args[2].clone();
                let file_location = eval_args[3..].iter().cloned().collect_vec();

                Some(
                    LogicalFileScan::new_gcs_logical_file_scan(
                        logical_table_function.ctx(),
                        schema,
                        "parquet".to_owned(),
                        "gcs".to_owned(),
                        credential,
                        file_location,
                    )
                    .into(),
                )
107 } else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
                let endpoint = eval_args[2].clone();
                let account_name = eval_args[3].clone();
                let account_key = eval_args[4].clone();
                let file_location = eval_args[5..].iter().cloned().collect_vec();

                Some(
                    LogicalFileScan::new_azblob_logical_file_scan(
                        logical_table_function.ctx(),
                        schema,
                        "parquet".to_owned(),
                        "azblob".to_owned(),
                        account_name,
                        account_key,
                        endpoint,
                        file_location,
                    )
                    .into(),
                )
            } else {
                unreachable!()
            }
        } else {
            unreachable!("TableFunction return type should be struct")
        }
    }
}

impl TableFunctionToFileScanRule {
    pub fn create() -> BoxedRule {
        Box::new(TableFunctionToFileScanRule {})
    }
}