risingwave_frontend/expr/table_function.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::Context;
use itertools::Itertools;
use mysql_async::consts::ColumnType as MySqlColumnType;
use mysql_async::prelude::*;
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::secret::LocalSecretManager;
use risingwave_common::types::{DataType, ScalarImpl, StructType};
use risingwave_connector::connector_common::create_pg_client;
use risingwave_connector::source::iceberg::{
    FileScanBackend, extract_bucket_and_file_name, get_parquet_fields, list_data_directory,
    new_azblob_operator, new_gcs_operator, new_s3_operator,
};
use risingwave_pb::expr::PbTableFunction;
pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
use tokio_postgres::types::Type as TokioPgType;

use super::{Expr, ExprImpl, ExprRewriter, Literal, RwResult, infer_type};
use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
use crate::catalog::root_catalog::SchemaPath;
use crate::error::ErrorCode::BindError;
use crate::expr::reject_impure;
use crate::utils::FRONTEND_RUNTIME;

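/// Argument count for the inline connection form of `postgres_query`/`mysql_query`:
/// (hostname, port, username, password, database_name, query).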
const INLINE_ARG_LEN: usize = 6;
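/// Argument count for the CDC-source form of `postgres_query`/`mysql_query`:
/// (cdc_source_name, query).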
const CDC_SOURCE_ARG_LEN: usize = 2;

/// A table function takes a row as input and returns a table. It is also known as Set-Returning
/// Function.
///
/// See also [`TableFunction`](risingwave_expr::table_function::TableFunction) trait in expr crate
/// and [`ProjectSetSelectItem`](risingwave_pb::expr::ProjectSetSelectItem).
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct TableFunction {
    pub args: Vec<ExprImpl>,
    pub return_type: DataType,
    pub function_type: TableFunctionType,
    /// Catalog of user defined table function.
    pub user_defined: Option<Arc<FunctionCatalog>>,
}

impl TableFunction {
    /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of
    /// `inputs`.
    pub fn new(func_type: TableFunctionType, mut args: Vec<ExprImpl>) -> RwResult<Self> {
        let return_type = infer_type(func_type.into(), &mut args)?;
        Ok(TableFunction {
            args,
            return_type,
            function_type: func_type,
            user_defined: None,
        })
    }

    /// Create a user-defined `TableFunction`.
    pub fn new_user_defined(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
        let FunctionKind::Table = &catalog.kind else {
            panic!("not a table function");
        };
        TableFunction {
            args,
            return_type: catalog.return_type.clone(),
            function_type: TableFunctionType::UserDefined,
            user_defined: Some(catalog),
        }
    }

    /// A special table function that is transformed into `LogicalFileScan` by
    /// `TableFunctionToFileScanRule` in the optimizer, e.g.
    /// `select * from file_scan('parquet', 's3', region, ak, sk, location)`.
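    ///
    /// Illustrative argument layouts per backend (placeholder values; see the checks below):
    /// - `file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location_or_directory)`
    /// - `file_scan('parquet', 'gcs', credential, file_location_or_directory)`
    /// - `file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)`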
    pub fn new_file_scan(mut args: Vec<ExprImpl>) -> RwResult<Self> {
        let return_type = {
            // arguments:
            // file format e.g. parquet
            // storage type e.g. s3, gcs, azblob
            // For s3: file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location_or_directory)
            // For gcs: file_scan('parquet', 'gcs', credential, file_location_or_directory)
            // For azblob: file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)
            let mut eval_args: Vec<String> = vec![];
            for arg in &args {
                if arg.return_type() != DataType::Varchar {
                    return Err(BindError(
                        "file_scan function only accepts string arguments".to_owned(),
                    )
                    .into());
                }
                match arg.try_fold_const() {
                    Some(Ok(value)) => {
                        if value.is_none() {
                            return Err(BindError(
                                "file_scan function does not accept null arguments".to_owned(),
                            )
                            .into());
                        }
                        match value {
                            Some(ScalarImpl::Utf8(s)) => {
                                eval_args.push(s.to_string());
                            }
                            _ => {
                                return Err(BindError(
                                    "file_scan function only accepts string arguments".to_owned(),
                                )
                                .into());
                            }
                        }
                    }
                    Some(Err(err)) => {
                        return Err(err);
                    }
                    None => {
                        return Err(BindError(
                            "file_scan function only accepts constant arguments".to_owned(),
                        )
                        .into());
                    }
                }
            }

            if (eval_args.len() != 4 && eval_args.len() != 6)
                || (eval_args.len() == 4 && !"gcs".eq_ignore_ascii_case(&eval_args[1]))
                || (eval_args.len() == 6
                    && !"s3".eq_ignore_ascii_case(&eval_args[1])
                    && !"azblob".eq_ignore_ascii_case(&eval_args[1]))
            {
                return Err(BindError(
                "file_scan function supports three backends: s3, gcs, and azblob. Their formats are as follows: \n
                    file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location) \n
                    file_scan('parquet', 'gcs', credential, file_location) \n
                    file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)"
                        .to_owned(),
                )
                .into());
            }
            if !"parquet".eq_ignore_ascii_case(&eval_args[0]) {
                return Err(BindError(
                    "file_scan function only accepts 'parquet' as file format".to_owned(),
                )
                .into());
            }

            if !"s3".eq_ignore_ascii_case(&eval_args[1])
                && !"gcs".eq_ignore_ascii_case(&eval_args[1])
                && !"azblob".eq_ignore_ascii_case(&eval_args[1])
            {
                return Err(BindError(
                    "file_scan function only accepts 's3', 'gcs' or 'azblob' as storage type"
                        .to_owned(),
                )
                .into());
            }

            #[cfg(madsim)]
            return Err(crate::error::ErrorCode::BindError(
                "file_scan can't be used in the madsim mode".to_string(),
            )
            .into());

            #[cfg(not(madsim))]
            {
                let (file_scan_backend, input_file_location) =
                    if "s3".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::S3, eval_args[5].clone())
                    } else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::Gcs, eval_args[3].clone())
                    } else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
                        (FileScanBackend::Azblob, eval_args[5].clone())
                    } else {
                        unreachable!();
                    };
                let op = match file_scan_backend {
                    FileScanBackend::S3 => {
                        let (bucket, _) = extract_bucket_and_file_name(
                            &eval_args[5].clone(),
                            &file_scan_backend,
                        )?;

                        let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
                            true => ("us-east-1".to_owned(), eval_args[2].clone()), /* for MinIO: hard-code the region; it is not used but must be provided. */
                            false => (
                                eval_args[2].clone(),
                                format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
                            ),
                        };
                        new_s3_operator(
                            s3_region,
                            eval_args[3].clone(),
                            eval_args[4].clone(),
                            bucket,
                            s3_endpoint,
                        )?
                    }
                    FileScanBackend::Gcs => {
                        let (bucket, _) =
                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;

                        new_gcs_operator(eval_args[2].clone(), bucket)?
                    }
                    FileScanBackend::Azblob => {
                        let (bucket, _) =
                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;

                        new_azblob_operator(
                            eval_args[2].clone(),
                            eval_args[3].clone(),
                            eval_args[4].clone(),
                            bucket,
                        )?
                    }
                };
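                // A trailing '/' marks the location as a directory: list it and scan every file inside.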
                let files = if input_file_location.ends_with('/') {
                    let files = tokio::task::block_in_place(|| {
                        FRONTEND_RUNTIME.block_on(async {
                            let files = list_data_directory(
                                op.clone(),
                                input_file_location.clone(),
                                &file_scan_backend,
                            )
                            .await?;

                            Ok::<Vec<String>, anyhow::Error>(files)
                        })
                    })?;
                    if files.is_empty() {
                        return Err(BindError(
                            "file_scan function only accepts non-empty directory".to_owned(),
                        )
                        .into());
                    }

                    Some(files)
                } else {
                    None
                };
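                // Infer the RisingWave schema from the Parquet schema of the single file
                // (or of the first file when a directory was given).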
                let schema = tokio::task::block_in_place(|| {
                    FRONTEND_RUNTIME.block_on(async {
                        let location = match files.as_ref() {
                            Some(files) => files[0].clone(),
                            None => input_file_location.clone(),
                        };
                        let (_, file_name) =
                            extract_bucket_and_file_name(&location, &file_scan_backend)?;

                        let fields = get_parquet_fields(op, file_name).await?;

                        let mut rw_types = vec![];
                        for field in &fields {
                            rw_types.push((
                                field.name().clone(),
                                IcebergArrowConvert.type_from_field(field)?,
                            ));
                        }

                        Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                            StructType::new(rw_types),
                        ))
                    })
                })?;

                if let Some(files) = files {
                    // if the file location is a directory, we need to remove the last argument and add all files in the directory as arguments
                    match file_scan_backend {
                        FileScanBackend::S3 => args.remove(5),
                        FileScanBackend::Gcs => args.remove(3),
                        FileScanBackend::Azblob => args.remove(5),
                    };
                    for file in files {
                        args.push(ExprImpl::Literal(Box::new(Literal::new(
                            Some(ScalarImpl::Utf8(file.into())),
                            DataType::Varchar,
                        ))));
                    }
                }

                schema
            }
        };

        Ok(TableFunction {
            args,
            return_type,
            function_type: TableFunctionType::FileScan,
            user_defined: None,
        })
    }

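    /// Normalize the arguments of `postgres_query`/`mysql_query` into literal varchar expressions.
    ///
    /// The 6-argument inline form is cast to varchar as-is. The 2-argument form
    /// `(cdc_source_name, query)` looks up the named CDC source, fills in its secrets, and expands
    /// it to `(hostname, port, username, password, database_name, query)`; for `postgres-cdc`,
    /// `ssl.mode` and `ssl.root.cert` are appended as well.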
    fn handle_postgres_or_mysql_query_args(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
        expect_connector_name: &str,
    ) -> RwResult<Vec<ExprImpl>> {
        let cast_args = match args.len() {
            INLINE_ARG_LEN => {
                let mut cast_args = Vec::with_capacity(INLINE_ARG_LEN);
                for arg in args {
                    let arg = arg.cast_implicit(&DataType::Varchar)?;
                    cast_args.push(arg);
                }
                cast_args
            }
            CDC_SOURCE_ARG_LEN => {
                let source_name = expr_impl_to_string_fn(&args[0])?;
                let source_catalog = catalog_reader
                    .get_source_by_name(db_name, schema_path, &source_name)?
                    .0;
                if !source_catalog
                    .connector_name()
                    .eq_ignore_ascii_case(expect_connector_name)
                {
                    return Err(BindError(format!("this TVF only accepts a `mysql-cdc` or `postgres-cdc` source. Expected: {}, but got: {}", expect_connector_name, source_catalog.connector_name())).into());
                }

                let (props, secret_refs) = source_catalog.with_properties.clone().into_parts();
                let secret_resolved =
                    LocalSecretManager::global().fill_secrets(props, secret_refs)?;

                let mut args_vec = vec![
                    ExprImpl::literal_varchar(secret_resolved["hostname"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["port"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["username"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["password"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["database.name"].clone()),
                    args.get(1)
                        .unwrap()
                        .clone()
                        .cast_implicit(&DataType::Varchar)?,
                ];

                if expect_connector_name.eq_ignore_ascii_case("postgres-cdc") {
                    args_vec.push(ExprImpl::literal_varchar(
                        secret_resolved.get("ssl.mode").cloned().unwrap_or_default(),
                    ));
                    args_vec.push(ExprImpl::literal_varchar(
                        secret_resolved
                            .get("ssl.root.cert")
                            .cloned()
                            .unwrap_or_default(),
                    ));
                }

                args_vec
            }
            _ => {
                return Err(BindError("postgres_query function and mysql_query function accept either 2 arguments: (cdc_source_name varchar, query varchar) or 6 arguments: (hostname varchar, port varchar, username varchar, password varchar, database_name varchar, query varchar)".to_owned()).into());
            }
        };

        Ok(cast_args)
    }

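    /// Bind the `postgres_query` table function and infer its schema by preparing the query
    /// against the upstream PostgreSQL database. Illustrative usage (placeholder names):
    /// `SELECT * FROM postgres_query('pg_cdc_source', 'SELECT * FROM t')` or
    /// `SELECT * FROM postgres_query('localhost', '5432', 'user', 'password', 'db', 'SELECT * FROM t')`.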
    pub fn new_postgres_query(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
    ) -> RwResult<Self> {
        let args = Self::handle_postgres_or_mysql_query_args(
            catalog_reader,
            db_name,
            schema_path,
            args,
            "postgres-cdc",
        )?;
        let evaled_args = args
            .iter()
            .map(expr_impl_to_string_fn)
            .collect::<RwResult<Vec<_>>>()?;

        #[cfg(madsim)]
        {
            return Err(crate::error::ErrorCode::BindError(
                "postgres_query can't be used in the madsim mode".to_string(),
            )
            .into());
        }

        #[cfg(not(madsim))]
        {
            let schema = tokio::task::block_in_place(|| {
                FRONTEND_RUNTIME.block_on(async {
                    let ssl_mode = evaled_args
                        .get(6)
                        .and_then(|s| s.parse().ok())
                        .unwrap_or_default();

                    let ssl_root_cert = evaled_args
                        .get(7)
                        .and_then(|s| if s.is_empty() { None } else { Some(s.clone()) });

                    let client = create_pg_client(
                        &evaled_args[2],
                        &evaled_args[3],
                        &evaled_args[0],
                        &evaled_args[1],
                        &evaled_args[4],
                        &ssl_mode,
                        &ssl_root_cert,
                    )
                    .await?;

                    let statement = client.prepare(evaled_args[5].as_str()).await?;

                    let mut rw_types = vec![];
                    for column in statement.columns() {
                        let name = column.name().to_owned();
                        let data_type = match *column.type_() {
                            TokioPgType::BOOL => DataType::Boolean,
                            TokioPgType::INT2 => DataType::Int16,
                            TokioPgType::INT4 => DataType::Int32,
                            TokioPgType::INT8 => DataType::Int64,
                            TokioPgType::FLOAT4 => DataType::Float32,
                            TokioPgType::FLOAT8 => DataType::Float64,
                            TokioPgType::NUMERIC => DataType::Decimal,
                            TokioPgType::DATE => DataType::Date,
                            TokioPgType::TIME => DataType::Time,
                            TokioPgType::TIMESTAMP => DataType::Timestamp,
                            TokioPgType::TIMESTAMPTZ => DataType::Timestamptz,
                            TokioPgType::TEXT | TokioPgType::VARCHAR => DataType::Varchar,
                            TokioPgType::INTERVAL => DataType::Interval,
                            TokioPgType::JSONB => DataType::Jsonb,
                            TokioPgType::BYTEA => DataType::Bytea,
                            _ => {
                                return Err(crate::error::ErrorCode::BindError(format!(
                                    "unsupported column type: {}",
                                    column.type_()
                                ))
                                .into());
                            }
                        };
                        rw_types.push((name, data_type));
                    }
                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                        StructType::new(rw_types),
                    ))
                })
            })?;

            Ok(TableFunction {
                args,
                return_type: schema,
                function_type: TableFunctionType::PostgresQuery,
                user_defined: None,
            })
        }
    }

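    /// Bind the `mysql_query` table function and infer its schema by preparing the query against
    /// the upstream MySQL database. Accepts the same two argument forms as `postgres_query`,
    /// e.g. (placeholder names) `SELECT * FROM mysql_query('mysql_cdc_source', 'SELECT * FROM t')`.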
    pub fn new_mysql_query(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
    ) -> RwResult<Self> {
        let args = Self::handle_postgres_or_mysql_query_args(
            catalog_reader,
            db_name,
            schema_path,
            args,
            "mysql-cdc",
        )?;
        let evaled_args = args
            .iter()
            .map(expr_impl_to_string_fn)
            .collect::<RwResult<Vec<_>>>()?;

        #[cfg(madsim)]
        {
            return Err(crate::error::ErrorCode::BindError(
                "mysql_query can't be used in the madsim mode".to_string(),
            )
            .into());
        }

        #[cfg(not(madsim))]
        {
            let schema = tokio::task::block_in_place(|| {
                FRONTEND_RUNTIME.block_on(async {
                    let database_opts: mysql_async::Opts = {
                        let port = evaled_args[1]
                            .parse::<u16>()
                            .context("failed to parse port")?;
                        mysql_async::OptsBuilder::default()
                            .ip_or_hostname(evaled_args[0].clone())
                            .tcp_port(port)
                            .user(Some(evaled_args[2].clone()))
                            .pass(Some(evaled_args[3].clone()))
                            .db_name(Some(evaled_args[4].clone()))
                            .into()
                    };

                    let pool = mysql_async::Pool::new(database_opts);
                    let mut conn = pool
                        .get_conn()
                        .await
                        .context("failed to connect to mysql in binder")?;

                    let query = evaled_args[5].clone();
                    let statement = conn
                        .prep(query)
                        .await
                        .context("failed to prepare mysql_query in binder")?;

                    let mut rw_types = vec![];
                    #[allow(clippy::never_loop)]
                    for column in statement.columns() {
                        let name = column.name_str().to_string();
                        let data_type = match column.column_type() {
                            // Boolean types
                            MySqlColumnType::MYSQL_TYPE_BIT if column.column_length() == 1 => {
                                DataType::Boolean
                            }

                            // Numeric types
                            // NOTE(kwannoel): Although `bool/boolean` is a synonym of TINY(1) in MySQL,
                            // we treat it as Int16 here. It is better to be straightforward in our conversion.
                            MySqlColumnType::MYSQL_TYPE_TINY => DataType::Int16,
                            MySqlColumnType::MYSQL_TYPE_SHORT => DataType::Int16,
                            MySqlColumnType::MYSQL_TYPE_INT24 => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_LONG => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_LONGLONG => DataType::Int64,
                            MySqlColumnType::MYSQL_TYPE_FLOAT => DataType::Float32,
                            MySqlColumnType::MYSQL_TYPE_DOUBLE => DataType::Float64,
                            MySqlColumnType::MYSQL_TYPE_NEWDECIMAL => DataType::Decimal,
                            MySqlColumnType::MYSQL_TYPE_DECIMAL => DataType::Decimal,

                            // Date time types
                            MySqlColumnType::MYSQL_TYPE_YEAR => DataType::Int32,
                            MySqlColumnType::MYSQL_TYPE_DATE => DataType::Date,
                            MySqlColumnType::MYSQL_TYPE_NEWDATE => DataType::Date,
                            MySqlColumnType::MYSQL_TYPE_TIME => DataType::Time,
                            MySqlColumnType::MYSQL_TYPE_TIME2 => DataType::Time,
                            MySqlColumnType::MYSQL_TYPE_DATETIME => DataType::Timestamp,
                            MySqlColumnType::MYSQL_TYPE_DATETIME2 => DataType::Timestamp,
                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP => DataType::Timestamptz,
                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP2 => DataType::Timestamptz,

                            // String types
                            MySqlColumnType::MYSQL_TYPE_VARCHAR => DataType::Varchar,
                            // mysql_async does not expose explicit `varbinary` and `binary` types,
                            // so we need to check the `ColumnFlags` to distinguish them.
                            MySqlColumnType::MYSQL_TYPE_STRING
                            | MySqlColumnType::MYSQL_TYPE_VAR_STRING => {
                                if column
                                    .flags()
                                    .contains(mysql_common::constants::ColumnFlags::BINARY_FLAG)
                                {
                                    DataType::Bytea
                                } else {
                                    DataType::Varchar
                                }
                            }

                            // JSON types
                            MySqlColumnType::MYSQL_TYPE_JSON => DataType::Jsonb,

                            // Binary types
                            MySqlColumnType::MYSQL_TYPE_BIT
                            | MySqlColumnType::MYSQL_TYPE_BLOB
                            | MySqlColumnType::MYSQL_TYPE_TINY_BLOB
                            | MySqlColumnType::MYSQL_TYPE_MEDIUM_BLOB
                            | MySqlColumnType::MYSQL_TYPE_LONG_BLOB => DataType::Bytea,

                            MySqlColumnType::MYSQL_TYPE_UNKNOWN
                            | MySqlColumnType::MYSQL_TYPE_TYPED_ARRAY
                            | MySqlColumnType::MYSQL_TYPE_ENUM
                            | MySqlColumnType::MYSQL_TYPE_SET
                            | MySqlColumnType::MYSQL_TYPE_GEOMETRY
                            | MySqlColumnType::MYSQL_TYPE_VECTOR
                            | MySqlColumnType::MYSQL_TYPE_NULL => {
                                return Err(crate::error::ErrorCode::BindError(format!(
                                    "unsupported column type: {:?}",
                                    column.column_type()
                                ))
                                .into());
                            }
                        };
                        rw_types.push((name, data_type));
                    }
                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                        StructType::new(rw_types),
                    ))
                })
            })?;

            Ok(TableFunction {
                args,
                return_type: schema,
                function_type: TableFunctionType::MysqlQuery,
                user_defined: None,
            })
        }
    }

    /// This is a highly specific _internal_ table function meant to scan and aggregate
    /// `backfill_table_id`, `row_count` for all MVs which are still being created.
    pub fn new_internal_backfill_progress() -> Self {
        TableFunction {
            args: vec![],
            return_type: DataType::Struct(StructType::new(vec![
                ("job_id".to_owned(), DataType::Int32),
                ("fragment_id".to_owned(), DataType::Int32),
                ("backfill_state_table_id".to_owned(), DataType::Int32),
                ("current_row_count".to_owned(), DataType::Int64),
                ("min_epoch".to_owned(), DataType::Int64),
            ])),
            function_type: TableFunctionType::InternalBackfillProgress,
            user_defined: None,
        }
    }

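    /// Like [`Self::new_internal_backfill_progress`], but for source backfill: reports
    /// per-fragment `backfill_progress` as JSONB for jobs that are still being created.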
    pub fn new_internal_source_backfill_progress() -> Self {
        TableFunction {
            args: vec![],
            return_type: DataType::Struct(StructType::new(vec![
                ("job_id".to_owned(), DataType::Int32),
                ("fragment_id".to_owned(), DataType::Int32),
                ("backfill_state_table_id".to_owned(), DataType::Int32),
                ("backfill_progress".to_owned(), DataType::Jsonb),
            ])),
            function_type: TableFunctionType::InternalSourceBackfillProgress,
            user_defined: None,
        }
    }

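    /// An internal table function exposing per-channel delta statistics (backpressure rate and
    /// receive/send throughput) between upstream and downstream fragments.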
    pub fn new_internal_get_channel_delta_stats(args: Vec<ExprImpl>) -> Self {
        Self {
            args,
            return_type: DataType::Struct(StructType::new(vec![
                ("upstream_fragment_id".to_owned(), DataType::Int32),
                ("downstream_fragment_id".to_owned(), DataType::Int32),
                ("backpressure_rate".to_owned(), DataType::Float64),
                ("recv_throughput".to_owned(), DataType::Float64),
                ("send_throughput".to_owned(), DataType::Float64),
            ])),
            function_type: TableFunctionType::InternalGetChannelDeltaStats,
            user_defined: None,
        }
    }

    pub fn to_protobuf(&self) -> PbTableFunction {
        PbTableFunction {
            function_type: self.function_type as i32,
            args: self.args.iter().map(|c| c.to_expr_proto()).collect_vec(),
            return_type: Some(self.return_type.to_protobuf()),
            udf: self.user_defined.as_ref().map(|c| c.as_ref().into()),
        }
    }

    /// Serialize the table function. Returns an error if this will result in an impure table
    /// function on a retract stream, which may lead to inconsistent results.
    pub fn to_protobuf_checked_pure(&self, retract: bool) -> crate::error::Result<PbTableFunction> {
        if retract {
            reject_impure(self.clone(), "table function")?;
        }

        let args = self
            .args
            .iter()
            .map(|arg| arg.to_expr_proto_checked_pure(retract, "table function argument"))
            .collect::<crate::error::Result<Vec<_>>>()?;

        Ok(PbTableFunction {
            function_type: self.function_type as i32,
            args,
            return_type: Some(self.return_type.to_protobuf()),
            udf: self.user_defined.as_ref().map(|c| c.as_ref().into()),
        })
    }

    /// Get the name of the table function.
    pub fn name(&self) -> String {
        match self.function_type {
            TableFunctionType::UserDefined => self.user_defined.as_ref().unwrap().name.clone(),
            t => t.as_str_name().to_lowercase(),
        }
    }

    pub fn rewrite(self, rewriter: &mut impl ExprRewriter) -> Self {
        Self {
            args: self
                .args
                .into_iter()
                .map(|e| rewriter.rewrite_expr(e))
                .collect(),
            ..self
        }
    }
}

impl std::fmt::Debug for TableFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if f.alternate() {
            f.debug_struct("TableFunction")
                .field("function_type", &self.function_type)
                .field("return_type", &self.return_type)
                .field("args", &self.args)
                .finish()
        } else {
            let func_name = format!("{:?}", self.function_type);
            let mut builder = f.debug_tuple(&func_name);
            self.args.iter().for_each(|child| {
                builder.field(child);
            });
            builder.finish()
        }
    }
}

impl Expr for TableFunction {
    fn return_type(&self) -> DataType {
        self.return_type.clone()
    }

    fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
        Err("Table function should not be converted to ExprNode".to_owned())
    }
}

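/// Fold a constant varchar argument of `postgres_query`/`mysql_query` into a `String`,
/// rejecting NULL and non-constant arguments.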
fn expr_impl_to_string_fn(arg: &ExprImpl) -> RwResult<String> {
    match arg.try_fold_const() {
        Some(Ok(value)) => {
            let Some(scalar) = value else {
                return Err(BindError(
                    "postgres_query function and mysql_query function do not accept null arguments"
                        .to_owned(),
                )
                .into());
            };
            Ok(scalar.into_utf8().to_string())
        }
        Some(Err(err)) => Err(err),
        None => Err(BindError(
            "postgres_query function and mysql_query function only accept constant arguments"
                .to_owned(),
        )
        .into()),
    }
}