use std::sync::Arc;

use anyhow::Context;
use itertools::Itertools;
use mysql_async::consts::ColumnType as MySqlColumnType;
use mysql_async::prelude::*;
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::secret::LocalSecretManager;
use risingwave_common::types::{DataType, ScalarImpl, StructType};
use risingwave_connector::source::iceberg::{
    FileScanBackend, extract_bucket_and_file_name, get_parquet_fields, list_data_directory,
    new_azblob_operator, new_gcs_operator, new_s3_operator,
};
use risingwave_pb::expr::PbTableFunction;
pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as TokioPgType;

use super::{ErrorCode, Expr, ExprImpl, ExprRewriter, Literal, RwResult, infer_type};
use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
use crate::catalog::root_catalog::SchemaPath;
use crate::error::ErrorCode::BindError;
use crate::utils::FRONTEND_RUNTIME;

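// `postgres_query` / `mysql_query` accept either the 6-argument inline connection form or the
// 2-argument `(cdc_source_name, query)` form.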
const INLINE_ARG_LEN: usize = 6;
const CDC_SOURCE_ARG_LEN: usize = 2;

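/// A table function call bound in the frontend, e.g. `generate_series`, `file_scan`,
/// `postgres_query`, or a user-defined table function. Unlike a scalar function, it returns a
/// set of rows.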
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct TableFunction {
    pub args: Vec<ExprImpl>,
    pub return_type: DataType,
    pub function_type: TableFunctionType,
    pub user_defined: Option<Arc<FunctionCatalog>>,
}

impl TableFunction {
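    /// Creates a built-in table function, inferring its return type from the arguments.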
    pub fn new(func_type: TableFunctionType, mut args: Vec<ExprImpl>) -> RwResult<Self> {
        let return_type = infer_type(func_type.into(), &mut args)?;
        Ok(TableFunction {
            args,
            return_type,
            function_type: func_type,
            user_defined: None,
        })
    }

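    /// Creates a user-defined table function from its catalog entry.
    /// Panics if the catalog entry does not describe a table function.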
    pub fn new_user_defined(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
        let FunctionKind::Table = &catalog.kind else {
            panic!("not a table function");
        };
        TableFunction {
            args,
            return_type: catalog.return_type.clone(),
            function_type: TableFunctionType::UserDefined,
            user_defined: Some(catalog),
        }
    }

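    /// Creates a `file_scan` table function. All arguments must be constant strings, in one of
    /// the following forms (the file location may be a single file or a directory ending with `/`):
    ///
    /// - `file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location)`
    /// - `file_scan('parquet', 'gcs', credential, file_location)`
    /// - `file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)`
    ///
    /// For s3, the region argument may instead be an `http(s)` endpoint. The return type is a
    /// struct derived from the Parquet schema of the (first) file.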
    pub fn new_file_scan(mut args: Vec<ExprImpl>) -> RwResult<Self> {
        let return_type = {
            let mut eval_args: Vec<String> = vec![];
            for arg in &args {
                if arg.return_type() != DataType::Varchar {
                    return Err(BindError(
                        "file_scan function only accepts string arguments".to_owned(),
                    )
                    .into());
                }
                match arg.try_fold_const() {
                    Some(Ok(value)) => {
                        if value.is_none() {
                            return Err(BindError(
                                "file_scan function does not accept null arguments".to_owned(),
                            )
                            .into());
                        }
                        match value {
                            Some(ScalarImpl::Utf8(s)) => {
                                eval_args.push(s.to_string());
                            }
                            _ => {
                                return Err(BindError(
                                    "file_scan function only accepts string arguments".to_owned(),
                                )
                                .into());
                            }
                        }
                    }
                    Some(Err(err)) => {
                        return Err(err);
                    }
                    None => {
                        return Err(BindError(
                            "file_scan function only accepts constant arguments".to_owned(),
                        )
                        .into());
                    }
                }
            }

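            // Each backend expects a fixed argument count: 4 for gcs, 6 for s3 and azblob.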
            if (eval_args.len() != 4 && eval_args.len() != 6)
                || (eval_args.len() == 4 && !"gcs".eq_ignore_ascii_case(&eval_args[1]))
                || (eval_args.len() == 6
                    && !"s3".eq_ignore_ascii_case(&eval_args[1])
                    && !"azblob".eq_ignore_ascii_case(&eval_args[1]))
            {
                return Err(BindError(
                    "file_scan function supports three backends: s3, gcs, and azblob. Their formats are as follows: \n
                    file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location) \n
                    file_scan('parquet', 'gcs', credential, file_location) \n
                    file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)"
                        .to_owned(),
                )
                .into());
            }
            if !"parquet".eq_ignore_ascii_case(&eval_args[0]) {
                return Err(BindError(
                    "file_scan function only accepts 'parquet' as file format".to_owned(),
                )
                .into());
            }

            if !"s3".eq_ignore_ascii_case(&eval_args[1])
                && !"gcs".eq_ignore_ascii_case(&eval_args[1])
                && !"azblob".eq_ignore_ascii_case(&eval_args[1])
            {
                return Err(BindError(
                    "file_scan function only accepts 's3', 'gcs' or 'azblob' as storage type"
                        .to_owned(),
                )
                .into());
            }

            #[cfg(madsim)]
            return Err(crate::error::ErrorCode::BindError(
                "file_scan can't be used in the madsim mode".to_string(),
            )
            .into());

            #[cfg(not(madsim))]
            {
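                // Pick the backend and the argument that holds the file location
                // (index 5 for s3 and azblob, index 3 for gcs).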
                let (file_scan_backend, input_file_location) =
                    if "s3".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::S3, eval_args[5].clone())
                    } else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::Gcs, eval_args[3].clone())
                    } else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::Azblob, eval_args[5].clone())
                    } else {
                        unreachable!();
                    };
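                // Build an object-store operator for the chosen backend.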
                let op = match file_scan_backend {
                    FileScanBackend::S3 => {
                        let (bucket, _) = extract_bucket_and_file_name(
                            &eval_args[5].clone(),
                            &file_scan_backend,
                        )?;

                        let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
                            true => ("us-east-1".to_owned(), eval_args[2].clone()),
                            false => (
                                eval_args[2].clone(),
                                format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
                            ),
                        };
                        new_s3_operator(
                            s3_region.clone(),
                            eval_args[3].clone(),
                            eval_args[4].clone(),
                            bucket.clone(),
                            s3_endpoint.clone(),
                        )?
                    }
                    FileScanBackend::Gcs => {
                        let (bucket, _) =
                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;

                        new_gcs_operator(eval_args[2].clone(), bucket.clone())?
                    }
                    FileScanBackend::Azblob => {
                        let (bucket, _) =
                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;

                        new_azblob_operator(
                            eval_args[2].clone(),
                            eval_args[3].clone(),
                            eval_args[4].clone(),
                            bucket.clone(),
                        )?
                    }
                };
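                // If the location is a directory (ends with '/'), expand it to the list of
                // files it contains; an empty directory is rejected.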
                let files = if input_file_location.ends_with('/') {
                    let files = tokio::task::block_in_place(|| {
                        FRONTEND_RUNTIME.block_on(async {
                            let files = list_data_directory(
                                op.clone(),
                                input_file_location.clone(),
                                &file_scan_backend,
                            )
                            .await?;

                            Ok::<Vec<String>, anyhow::Error>(files)
                        })
                    })?;
                    if files.is_empty() {
                        return Err(BindError(
                            "file_scan function only accepts non-empty directory".to_owned(),
                        )
                        .into());
                    }

                    Some(files)
                } else {
                    None
                };
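                // Derive the return schema from the Parquet metadata of the first matched file.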
                let schema = tokio::task::block_in_place(|| {
                    FRONTEND_RUNTIME.block_on(async {
                        let location = match files.as_ref() {
                            Some(files) => files[0].clone(),
                            None => input_file_location.clone(),
                        };
                        let (_, file_name) =
                            extract_bucket_and_file_name(&location, &file_scan_backend)?;

                        let fields = get_parquet_fields(op, file_name).await?;

                        let mut rw_types = vec![];
                        for field in &fields {
                            rw_types.push((
                                field.name().to_string(),
                                IcebergArrowConvert.type_from_field(field)?,
                            ));
                        }

                        Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                            StructType::new(rw_types),
                        ))
                    })
                })?;

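                // Replace the directory argument with one literal argument per discovered file.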
                if let Some(files) = files {
                    match file_scan_backend {
                        FileScanBackend::S3 => args.remove(5),
                        FileScanBackend::Gcs => args.remove(3),
                        FileScanBackend::Azblob => args.remove(5),
                    };
                    for file in files {
                        args.push(ExprImpl::Literal(Box::new(Literal::new(
                            Some(ScalarImpl::Utf8(file.into())),
                            DataType::Varchar,
                        ))));
                    }
                }

                schema
            }
        };

        Ok(TableFunction {
            args,
            return_type,
            function_type: TableFunctionType::FileScan,
            user_defined: None,
        })
    }

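    /// Normalizes the arguments of `postgres_query` / `mysql_query` into the 6-argument inline
    /// form `(hostname, port, username, password, database_name, query)`. With the 2-argument
    /// form, connection info is taken from the named CDC source (secrets resolved), whose
    /// connector must match `expect_connector_name`.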
    fn handle_postgres_or_mysql_query_args(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
        expect_connector_name: &str,
    ) -> RwResult<Vec<ExprImpl>> {
        let cast_args = match args.len() {
            INLINE_ARG_LEN => {
                let mut cast_args = Vec::with_capacity(INLINE_ARG_LEN);
                for arg in args {
                    let arg = arg.cast_implicit(DataType::Varchar)?;
                    cast_args.push(arg);
                }
                cast_args
            }
            CDC_SOURCE_ARG_LEN => {
                let source_name = expr_impl_to_string_fn(&args[0])?;
                let source_catalog = catalog_reader
                    .get_source_by_name(db_name, schema_path, &source_name)?
                    .0;
                if !source_catalog
                    .connector_name()
                    .eq_ignore_ascii_case(expect_connector_name)
                {
                    return Err(BindError(format!(
                        "this table-valued function only accepts a `mysql-cdc` or `postgres-cdc` source. Expected connector: {}, but got: {}",
                        expect_connector_name,
                        source_catalog.connector_name()
                    ))
                    .into());
                }

                let (props, secret_refs) = source_catalog.with_properties.clone().into_parts();
                let secret_resolved =
                    LocalSecretManager::global().fill_secrets(props, secret_refs)?;

                vec![
                    ExprImpl::literal_varchar(secret_resolved["hostname"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["port"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["username"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["password"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["database.name"].clone()),
                    args.get(1)
                        .unwrap()
                        .clone()
                        .cast_implicit(DataType::Varchar)?,
                ]
            }
            _ => {
                return Err(BindError(
                    "postgres_query and mysql_query accept either 2 arguments (cdc_source_name varchar, query varchar) or 6 arguments (hostname varchar, port varchar, username varchar, password varchar, database_name varchar, query varchar)"
                        .to_owned(),
                )
                .into());
            }
        };

        Ok(cast_args)
    }

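    /// Creates a `postgres_query` table function. The query is prepared against the target
    /// PostgreSQL database at bind time to derive the return type from its result columns.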
    pub fn new_postgres_query(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
    ) -> RwResult<Self> {
        let args = Self::handle_postgres_or_mysql_query_args(
            catalog_reader,
            db_name,
            schema_path,
            args,
            "postgres-cdc",
        )?;
        let evaled_args = args
            .iter()
            .map(expr_impl_to_string_fn)
            .collect::<RwResult<Vec<_>>>()?;

        #[cfg(madsim)]
        {
            return Err(crate::error::ErrorCode::BindError(
                "postgres_query can't be used in the madsim mode".to_string(),
            )
            .into());
        }

        #[cfg(not(madsim))]
        {
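            // Connect to PostgreSQL at bind time and prepare the query so the server reports
            // the result column types without executing it.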
            let schema = tokio::task::block_in_place(|| {
                FRONTEND_RUNTIME.block_on(async {
                    let mut conf = tokio_postgres::Config::new();
                    let (client, connection) = conf
                        .host(&evaled_args[0])
                        .port(evaled_args[1].parse().map_err(|_| {
                            ErrorCode::InvalidParameterValue(format!(
                                "port number: {}",
                                evaled_args[1]
                            ))
                        })?)
                        .user(&evaled_args[2])
                        .password(evaled_args[3].clone())
                        .dbname(&evaled_args[4])
                        .connect(tokio_postgres::NoTls)
                        .await?;

                    // The connection must be driven in the background for the client to work.
                    tokio::spawn(async move {
                        if let Err(e) = connection.await {
                            tracing::error!(
                                "postgres_query: connection error: {:?}",
                                e.as_report()
                            );
                        }
                    });

                    let statement = client.prepare(evaled_args[5].as_str()).await?;

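                    // Map each result column's PostgreSQL type to a RisingWave data type.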
                    let mut rw_types = vec![];
                    for column in statement.columns() {
                        let name = column.name().to_owned();
                        let data_type = match *column.type_() {
                            TokioPgType::BOOL => DataType::Boolean,
                            TokioPgType::INT2 => DataType::Int16,
                            TokioPgType::INT4 => DataType::Int32,
                            TokioPgType::INT8 => DataType::Int64,
                            TokioPgType::FLOAT4 => DataType::Float32,
                            TokioPgType::FLOAT8 => DataType::Float64,
                            TokioPgType::NUMERIC => DataType::Decimal,
                            TokioPgType::DATE => DataType::Date,
                            TokioPgType::TIME => DataType::Time,
                            TokioPgType::TIMESTAMP => DataType::Timestamp,
                            TokioPgType::TIMESTAMPTZ => DataType::Timestamptz,
                            TokioPgType::TEXT | TokioPgType::VARCHAR => DataType::Varchar,
                            TokioPgType::INTERVAL => DataType::Interval,
                            TokioPgType::JSONB => DataType::Jsonb,
                            TokioPgType::BYTEA => DataType::Bytea,
                            _ => {
                                return Err(crate::error::ErrorCode::BindError(format!(
                                    "unsupported column type: {}",
                                    column.type_()
                                ))
                                .into());
                            }
                        };
                        rw_types.push((name, data_type));
                    }
                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                        StructType::new(rw_types),
                    ))
                })
            })?;

            Ok(TableFunction {
                args,
                return_type: schema,
                function_type: TableFunctionType::PostgresQuery,
                user_defined: None,
            })
        }
    }

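    /// Creates a `mysql_query` table function. The query is prepared against the target MySQL
    /// database at bind time to derive the return type from its result columns.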
    pub fn new_mysql_query(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
    ) -> RwResult<Self> {
        let args = Self::handle_postgres_or_mysql_query_args(
            catalog_reader,
            db_name,
            schema_path,
            args,
            "mysql-cdc",
        )?;
        let evaled_args = args
            .iter()
            .map(expr_impl_to_string_fn)
            .collect::<RwResult<Vec<_>>>()?;

        #[cfg(madsim)]
        {
            return Err(crate::error::ErrorCode::BindError(
                "mysql_query can't be used in the madsim mode".to_string(),
            )
            .into());
        }

        #[cfg(not(madsim))]
        {
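            // Connect to MySQL at bind time and prepare the query so the server reports the
            // result column types without executing it.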
            let schema = tokio::task::block_in_place(|| {
                FRONTEND_RUNTIME.block_on(async {
                    let database_opts: mysql_async::Opts = {
                        let port = evaled_args[1]
                            .parse::<u16>()
                            .context("failed to parse port")?;
                        mysql_async::OptsBuilder::default()
                            .ip_or_hostname(evaled_args[0].clone())
                            .tcp_port(port)
                            .user(Some(evaled_args[2].clone()))
                            .pass(Some(evaled_args[3].clone()))
                            .db_name(Some(evaled_args[4].clone()))
                            .into()
                    };

                    let pool = mysql_async::Pool::new(database_opts);
                    let mut conn = pool
                        .get_conn()
                        .await
                        .context("failed to connect to mysql in binder")?;

                    let query = evaled_args[5].clone();
                    let statement = conn
                        .prep(query)
                        .await
                        .context("failed to prepare mysql_query in binder")?;

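                    // Map each result column's MySQL type to a RisingWave data type.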
                    let mut rw_types = vec![];
                    #[allow(clippy::never_loop)]
                    for column in statement.columns() {
                        let name = column.name_str().to_string();
                        let data_type = match column.column_type() {
                            MySqlColumnType::MYSQL_TYPE_BIT if column.column_length() == 1 => {
                                DataType::Boolean
                            }

                            MySqlColumnType::MYSQL_TYPE_TINY => DataType::Int16,
                            MySqlColumnType::MYSQL_TYPE_SHORT => DataType::Int16,
                            MySqlColumnType::MYSQL_TYPE_INT24 => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_LONG => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_LONGLONG => DataType::Int64,
                            MySqlColumnType::MYSQL_TYPE_FLOAT => DataType::Float32,
                            MySqlColumnType::MYSQL_TYPE_DOUBLE => DataType::Float64,
                            MySqlColumnType::MYSQL_TYPE_NEWDECIMAL => DataType::Decimal,
                            MySqlColumnType::MYSQL_TYPE_DECIMAL => DataType::Decimal,

                            MySqlColumnType::MYSQL_TYPE_YEAR => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_DATE => DataType::Date,
                            MySqlColumnType::MYSQL_TYPE_NEWDATE => DataType::Date,
                            MySqlColumnType::MYSQL_TYPE_TIME => DataType::Time,
                            MySqlColumnType::MYSQL_TYPE_TIME2 => DataType::Time,
                            MySqlColumnType::MYSQL_TYPE_DATETIME => DataType::Timestamp,
                            MySqlColumnType::MYSQL_TYPE_DATETIME2 => DataType::Timestamp,
                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP => DataType::Timestamptz,
                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP2 => DataType::Timestamptz,

                            MySqlColumnType::MYSQL_TYPE_VARCHAR
                            | MySqlColumnType::MYSQL_TYPE_STRING
                            | MySqlColumnType::MYSQL_TYPE_VAR_STRING => DataType::Varchar,

                            MySqlColumnType::MYSQL_TYPE_JSON => DataType::Jsonb,

                            MySqlColumnType::MYSQL_TYPE_BIT
                            | MySqlColumnType::MYSQL_TYPE_BLOB
                            | MySqlColumnType::MYSQL_TYPE_TINY_BLOB
                            | MySqlColumnType::MYSQL_TYPE_MEDIUM_BLOB
                            | MySqlColumnType::MYSQL_TYPE_LONG_BLOB => DataType::Bytea,

                            MySqlColumnType::MYSQL_TYPE_UNKNOWN
                            | MySqlColumnType::MYSQL_TYPE_TYPED_ARRAY
                            | MySqlColumnType::MYSQL_TYPE_ENUM
                            | MySqlColumnType::MYSQL_TYPE_SET
                            | MySqlColumnType::MYSQL_TYPE_GEOMETRY
                            | MySqlColumnType::MYSQL_TYPE_NULL => {
                                return Err(crate::error::ErrorCode::BindError(format!(
                                    "unsupported column type: {:?}",
                                    column.column_type()
                                ))
                                .into());
                            }
                        };
                        rw_types.push((name, data_type));
                    }
                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                        StructType::new(rw_types),
                    ))
                })
            })?;

            Ok(TableFunction {
                args,
                return_type: schema,
                function_type: TableFunctionType::MysqlQuery,
                user_defined: None,
            })
        }
    }

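    /// Creates the `internal_backfill_progress` table function, which exposes backfill progress
    /// (job, fragment, state table, current row count, and min epoch) of streaming jobs.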
    pub fn new_internal_backfill_progress() -> Self {
        TableFunction {
            args: vec![],
            return_type: DataType::Struct(StructType::new(vec![
                ("job_id".to_owned(), DataType::Int32),
                ("fragment_id".to_owned(), DataType::Int32),
                ("backfill_state_table_id".to_owned(), DataType::Int32),
                ("current_row_count".to_owned(), DataType::Int64),
                ("min_epoch".to_owned(), DataType::Int64),
            ])),
            function_type: TableFunctionType::InternalBackfillProgress,
            user_defined: None,
        }
    }

    pub fn to_protobuf(&self) -> PbTableFunction {
        PbTableFunction {
            function_type: self.function_type as i32,
            args: self.args.iter().map(|c| c.to_expr_proto()).collect_vec(),
            return_type: Some(self.return_type.to_protobuf()),
            udf: self.user_defined.as_ref().map(|c| c.as_ref().into()),
        }
    }

    pub fn name(&self) -> String {
        match self.function_type {
            TableFunctionType::UserDefined => self.user_defined.as_ref().unwrap().name.clone(),
            t => t.as_str_name().to_lowercase(),
        }
    }

    pub fn rewrite(self, rewriter: &mut impl ExprRewriter) -> Self {
        Self {
            args: self
                .args
                .into_iter()
                .map(|e| rewriter.rewrite_expr(e))
                .collect(),
            ..self
        }
    }
}

impl std::fmt::Debug for TableFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if f.alternate() {
            f.debug_struct("TableFunction")
                .field("function_type", &self.function_type)
                .field("return_type", &self.return_type)
                .field("args", &self.args)
                .finish()
        } else {
            let func_name = format!("{:?}", self.function_type);
            let mut builder = f.debug_tuple(&func_name);
            self.args.iter().for_each(|child| {
                builder.field(child);
            });
            builder.finish()
        }
    }
}

impl Expr for TableFunction {
    fn return_type(&self) -> DataType {
        self.return_type.clone()
    }

    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
        unreachable!("Table function should not be converted to ExprNode")
    }
}

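/// Folds a constant argument of `postgres_query` / `mysql_query` into a string, rejecting null
/// and non-constant expressions.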
fn expr_impl_to_string_fn(arg: &ExprImpl) -> RwResult<String> {
    match arg.try_fold_const() {
        Some(Ok(value)) => {
            let Some(scalar) = value else {
                return Err(BindError(
                    "postgres_query function and mysql_query function do not accept null arguments"
                        .to_owned(),
                )
                .into());
            };
            Ok(scalar.into_utf8().to_string())
        }
        Some(Err(err)) => Err(err),
        None => Err(BindError(
            "postgres_query function and mysql_query function only accept constant arguments"
                .to_owned(),
        )
        .into()),
    }
}