// risingwave_frontend/expr/table_function.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use anyhow::Context;
18use itertools::Itertools;
19use mysql_async::consts::ColumnType as MySqlColumnType;
20use mysql_async::prelude::*;
21use risingwave_common::array::arrow::IcebergArrowConvert;
22use risingwave_common::secret::LocalSecretManager;
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use risingwave_connector::connector_common::create_pg_client;
25use risingwave_connector::source::iceberg::{
26    FileScanBackend, extract_bucket_and_file_name, get_parquet_fields, list_data_directory,
27    new_azblob_operator, new_gcs_operator, new_s3_operator,
28};
29use risingwave_pb::expr::PbTableFunction;
30pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
31use tokio_postgres::types::Type as TokioPgType;
32
33use super::{Expr, ExprImpl, ExprRewriter, Literal, RwResult, infer_type};
34use crate::catalog::catalog_service::CatalogReadGuard;
35use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
36use crate::catalog::root_catalog::SchemaPath;
37use crate::error::ErrorCode::BindError;
38use crate::expr::reject_impure;
39use crate::utils::FRONTEND_RUNTIME;
40
/// Argument count when connection info is passed inline to
/// `postgres_query`/`mysql_query`:
/// (hostname, port, username, password, database_name, query).
const INLINE_ARG_LEN: usize = 6;
/// Argument count when referencing an existing CDC source instead:
/// (cdc_source_name, query).
const CDC_SOURCE_ARG_LEN: usize = 2;
43
/// A table function takes a row as input and returns a table. It is also known as Set-Returning
/// Function.
///
/// See also [`TableFunction`](risingwave_expr::table_function::TableFunction) trait in expr crate
/// and [`ProjectSetSelectItem`](risingwave_pb::expr::ProjectSetSelectItem).
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct TableFunction {
    /// Argument expressions of the table function call.
    pub args: Vec<ExprImpl>,
    /// Return type of the table function (a struct type for the
    /// multi-column functions constructed in this file).
    pub return_type: DataType,
    /// Which table function this is; `UserDefined` for UDTFs.
    pub function_type: TableFunctionType,
    /// Catalog of user defined table function.
    pub user_defined: Option<Arc<FunctionCatalog>>,
}
57
58impl TableFunction {
59    /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of
60    /// `inputs`.
61    pub fn new(func_type: TableFunctionType, mut args: Vec<ExprImpl>) -> RwResult<Self> {
62        let return_type = infer_type(func_type.into(), &mut args)?;
63        Ok(TableFunction {
64            args,
65            return_type,
66            function_type: func_type,
67            user_defined: None,
68        })
69    }
70
71    /// Create a user-defined `TableFunction`.
72    pub fn new_user_defined(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
73        let FunctionKind::Table = &catalog.kind else {
74            panic!("not a table function");
75        };
76        TableFunction {
77            args,
78            return_type: catalog.return_type.clone(),
79            function_type: TableFunctionType::UserDefined,
80            user_defined: Some(catalog),
81        }
82    }
83
84    /// A special table function which would be transformed into `LogicalFileScan` by `TableFunctionToFileScanRule` in the optimizer.
85    /// select * from `file_scan`('parquet', 's3', region, ak, sk, location)
86    pub fn new_file_scan(mut args: Vec<ExprImpl>) -> RwResult<Self> {
87        let return_type = {
88            // arguments:
89            // file format e.g. parquet
90            // storage type e.g. s3, gcs, azblob
91            // For s3: file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location_or_directory)
92            // For gcs: file_scan('parquet', 'gcs', credential, file_location_or_directory)
93            // For azblob: file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)
94            let mut eval_args: Vec<String> = vec![];
95            for arg in &args {
96                if arg.return_type() != DataType::Varchar {
97                    return Err(BindError(
98                        "file_scan function only accepts string arguments".to_owned(),
99                    )
100                    .into());
101                }
102                match arg.try_fold_const() {
103                    Some(Ok(value)) => {
104                        if value.is_none() {
105                            return Err(BindError(
106                                "file_scan function does not accept null arguments".to_owned(),
107                            )
108                            .into());
109                        }
110                        match value {
111                            Some(ScalarImpl::Utf8(s)) => {
112                                eval_args.push(s.to_string());
113                            }
114                            _ => {
115                                return Err(BindError(
116                                    "file_scan function only accepts string arguments".to_owned(),
117                                )
118                                .into());
119                            }
120                        }
121                    }
122                    Some(Err(err)) => {
123                        return Err(err);
124                    }
125                    None => {
126                        return Err(BindError(
127                            "file_scan function only accepts constant arguments".to_owned(),
128                        )
129                        .into());
130                    }
131                }
132            }
133
134            if (eval_args.len() != 4 && eval_args.len() != 6)
135                || (eval_args.len() == 4 && !"gcs".eq_ignore_ascii_case(&eval_args[1]))
136                || (eval_args.len() == 6
137                    && !"s3".eq_ignore_ascii_case(&eval_args[1])
138                    && !"azblob".eq_ignore_ascii_case(&eval_args[1]))
139            {
140                return Err(BindError(
141                "file_scan function supports three backends: s3, gcs, and azblob. Their formats are as follows: \n
142                    file_scan('parquet', 's3', s3_region, s3_access_key, s3_secret_key, file_location) \n
143                    file_scan('parquet', 'gcs', credential, service_account, file_location) \n
144                    file_scan('parquet', 'azblob', endpoint, account_name, account_key, file_location)"
145                        .to_owned(),
146                )
147                .into());
148            }
149            if !"parquet".eq_ignore_ascii_case(&eval_args[0]) {
150                return Err(BindError(
151                    "file_scan function only accepts 'parquet' as file format".to_owned(),
152                )
153                .into());
154            }
155
156            if !"s3".eq_ignore_ascii_case(&eval_args[1])
157                && !"gcs".eq_ignore_ascii_case(&eval_args[1])
158                && !"azblob".eq_ignore_ascii_case(&eval_args[1])
159            {
160                return Err(BindError(
161                    "file_scan function only accepts 's3', 'gcs' or 'azblob' as storage type"
162                        .to_owned(),
163                )
164                .into());
165            }
166
167            #[cfg(madsim)]
168            return Err(crate::error::ErrorCode::BindError(
169                "file_scan can't be used in the madsim mode".to_string(),
170            )
171            .into());
172
173            #[cfg(not(madsim))]
174            {
175                let (file_scan_backend, input_file_location) =
176                    if "s3".eq_ignore_ascii_case(&eval_args[1]) {
177                        (FileScanBackend::S3, eval_args[5].clone())
178                    } else if "gcs".eq_ignore_ascii_case(&eval_args[1]) {
179                        (FileScanBackend::Gcs, eval_args[3].clone())
180                    } else if "azblob".eq_ignore_ascii_case(&eval_args[1]) {
181                        (FileScanBackend::Azblob, eval_args[5].clone())
182                    } else {
183                        unreachable!();
184                    };
185                let op = match file_scan_backend {
186                    FileScanBackend::S3 => {
187                        let (bucket, _) = extract_bucket_and_file_name(
188                            &eval_args[5].clone(),
189                            &file_scan_backend,
190                        )?;
191
192                        let (s3_region, s3_endpoint) = match eval_args[2].starts_with("http") {
193                            true => ("us-east-1".to_owned(), eval_args[2].clone()), /* for minio, hard code region as not used but needed. */
194                            false => (
195                                eval_args[2].clone(),
196                                format!("https://{}.s3.{}.amazonaws.com", bucket, eval_args[2],),
197                            ),
198                        };
199                        new_s3_operator(
200                            s3_region,
201                            eval_args[3].clone(),
202                            eval_args[4].clone(),
203                            bucket,
204                            s3_endpoint,
205                        )?
206                    }
207                    FileScanBackend::Gcs => {
208                        let (bucket, _) =
209                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;
210
211                        new_gcs_operator(eval_args[2].clone(), bucket)?
212                    }
213                    FileScanBackend::Azblob => {
214                        let (bucket, _) =
215                            extract_bucket_and_file_name(&input_file_location, &file_scan_backend)?;
216
217                        new_azblob_operator(
218                            eval_args[2].clone(),
219                            eval_args[3].clone(),
220                            eval_args[4].clone(),
221                            bucket,
222                        )?
223                    }
224                };
225                let files = if input_file_location.ends_with('/') {
226                    let files = tokio::task::block_in_place(|| {
227                        FRONTEND_RUNTIME.block_on(async {
228                            let files = list_data_directory(
229                                op.clone(),
230                                input_file_location.clone(),
231                                &file_scan_backend,
232                            )
233                            .await?;
234
235                            Ok::<Vec<String>, anyhow::Error>(files)
236                        })
237                    })?;
238                    if files.is_empty() {
239                        return Err(BindError(
240                            "file_scan function only accepts non-empty directory".to_owned(),
241                        )
242                        .into());
243                    }
244
245                    Some(files)
246                } else {
247                    None
248                };
249                let schema = tokio::task::block_in_place(|| {
250                    FRONTEND_RUNTIME.block_on(async {
251                        let location = match files.as_ref() {
252                            Some(files) => files[0].clone(),
253                            None => input_file_location.clone(),
254                        };
255                        let (_, file_name) =
256                            extract_bucket_and_file_name(&location, &file_scan_backend)?;
257
258                        let fields = get_parquet_fields(op, file_name).await?;
259
260                        let mut rw_types = vec![];
261                        for field in &fields {
262                            rw_types.push((
263                                field.name().clone(),
264                                IcebergArrowConvert.type_from_field(field)?,
265                            ));
266                        }
267
268                        Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
269                            StructType::new(rw_types),
270                        ))
271                    })
272                })?;
273
274                if let Some(files) = files {
275                    // if the file location is a directory, we need to remove the last argument and add all files in the directory as arguments
276                    match file_scan_backend {
277                        FileScanBackend::S3 => args.remove(5),
278                        FileScanBackend::Gcs => args.remove(3),
279                        FileScanBackend::Azblob => args.remove(5),
280                    };
281                    for file in files {
282                        args.push(ExprImpl::Literal(Box::new(Literal::new(
283                            Some(ScalarImpl::Utf8(file.into())),
284                            DataType::Varchar,
285                        ))));
286                    }
287                }
288
289                schema
290            }
291        };
292
293        Ok(TableFunction {
294            args,
295            return_type,
296            function_type: TableFunctionType::FileScan,
297            user_defined: None,
298        })
299    }
300
    /// Normalize `postgres_query`/`mysql_query` arguments into the inline
    /// layout `[hostname, port, username, password, database_name, query]`,
    /// with `ssl.mode` and `ssl.root.cert` appended for `postgres-cdc`.
    ///
    /// Two input shapes are accepted:
    /// - `INLINE_ARG_LEN` (6): connection info given inline; each argument is
    ///   implicitly casted to varchar.
    /// - `CDC_SOURCE_ARG_LEN` (2): `(cdc_source_name, query)`; connection
    ///   info is pulled from the named CDC source's resolved properties.
    fn handle_postgres_or_mysql_query_args(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
        expect_connector_name: &str,
    ) -> RwResult<Vec<ExprImpl>> {
        let cast_args = match args.len() {
            INLINE_ARG_LEN => {
                let mut cast_args = Vec::with_capacity(INLINE_ARG_LEN);
                for arg in args {
                    let arg = arg.cast_implicit(&DataType::Varchar)?;
                    cast_args.push(arg);
                }
                cast_args
            }
            CDC_SOURCE_ARG_LEN => {
                // First argument must be a constant string naming the CDC source.
                let source_name = expr_impl_to_string_fn(&args[0])?;
                let source_catalog = catalog_reader
                    .get_source_by_name(db_name, schema_path, &source_name)?
                    .0;
                // The source's connector must match the calling TVF
                // (postgres-cdc for postgres_query, mysql-cdc for mysql_query).
                if !source_catalog
                    .connector_name()
                    .eq_ignore_ascii_case(expect_connector_name)
                {
                    return Err(BindError(format!("TVF function only accepts `mysql-cdc` and `postgres-cdc` source. Expected: {}, but got: {}", expect_connector_name, source_catalog.connector_name())).into());
                }

                // Resolve secret references into plain property values.
                let (props, secret_refs) = source_catalog.with_properties.clone().into_parts();
                let secret_resolved =
                    LocalSecretManager::global().fill_secrets(props, secret_refs)?;

                // NOTE(review): the indexing below panics if a property is
                // missing — assumes every CDC source carries hostname/port/
                // username/password/database.name; TODO confirm these are
                // mandatory source properties.
                let mut args_vec = vec![
                    ExprImpl::literal_varchar(secret_resolved["hostname"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["port"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["username"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["password"].clone()),
                    ExprImpl::literal_varchar(secret_resolved["database.name"].clone()),
                    args.get(1)
                        .unwrap()
                        .clone()
                        .cast_implicit(&DataType::Varchar)?,
                ];

                // Only postgres supports the optional SSL settings; they end
                // up at indices 6 and 7 of the returned vector.
                if expect_connector_name.eq_ignore_ascii_case("postgres-cdc") {
                    args_vec.push(ExprImpl::literal_varchar(
                        secret_resolved.get("ssl.mode").cloned().unwrap_or_default(),
                    ));
                    args_vec.push(ExprImpl::literal_varchar(
                        secret_resolved
                            .get("ssl.root.cert")
                            .cloned()
                            .unwrap_or_default(),
                    ));
                }

                args_vec
            }
            _ => {
                return Err(BindError("postgres_query function and mysql_query function accept either 2 arguments: (cdc_source_name varchar, query varchar) or 6 arguments: (hostname varchar, port varchar, username varchar, password varchar, database_name varchar, query varchar)".to_owned()).into());
            }
        };

        Ok(cast_args)
    }
366
    /// Bind `postgres_query(...)`: connects to the Postgres server at bind
    /// time, prepares (but does not execute) the query, and maps its result
    /// columns to RisingWave types to form the return struct type.
    pub fn new_postgres_query(
        catalog_reader: &CatalogReadGuard,
        db_name: &str,
        schema_path: SchemaPath<'_>,
        args: Vec<ExprImpl>,
    ) -> RwResult<Self> {
        // Normalize to [hostname, port, username, password, database, query]
        // (+ optional ssl.mode / ssl.root.cert when resolved from a CDC source).
        let args = Self::handle_postgres_or_mysql_query_args(
            catalog_reader,
            db_name,
            schema_path,
            args,
            "postgres-cdc",
        )?;
        // All arguments must fold to constant strings.
        let evaled_args = args
            .iter()
            .map(expr_impl_to_string_fn)
            .collect::<RwResult<Vec<_>>>()?;

        #[cfg(madsim)]
        {
            return Err(crate::error::ErrorCode::BindError(
                "postgres_query can't be used in the madsim mode".to_string(),
            )
            .into());
        }

        #[cfg(not(madsim))]
        {
            // Run the async connection setup on the frontend runtime,
            // blocking the current (tokio) worker thread in place.
            let schema = tokio::task::block_in_place(|| {
                FRONTEND_RUNTIME.block_on(async {
                    // Indices 6/7 are only present for the CDC-source form;
                    // fall back to defaults for the inline 6-argument form.
                    let ssl_mode = evaled_args
                        .get(6)
                        .and_then(|s| s.parse().ok())
                        .unwrap_or_default();

                    let ssl_root_cert = evaled_args
                        .get(7)
                        .and_then(|s| if s.is_empty() { None } else { Some(s.clone()) });

                    // evaled_args layout: [0]=hostname, [1]=port, [2]=username,
                    // [3]=password, [4]=database, [5]=query.
                    let client = create_pg_client(
                        &evaled_args[2],
                        &evaled_args[3],
                        &evaled_args[0],
                        &evaled_args[1],
                        &evaled_args[4],
                        &ssl_mode,
                        &ssl_root_cert,
                        None,
                    )
                    .await?;

                    // Preparing the statement is enough to learn its result schema.
                    let statement = client.prepare(evaled_args[5].as_str()).await?;

                    let mut rw_types = vec![];
                    for column in statement.columns() {
                        let name = column.name().to_owned();
                        // Map Postgres wire types to RisingWave types; anything
                        // outside this list is rejected at bind time.
                        let data_type = match *column.type_() {
                            TokioPgType::BOOL => DataType::Boolean,
                            TokioPgType::INT2 => DataType::Int16,
                            TokioPgType::INT4 => DataType::Int32,
                            TokioPgType::INT8 => DataType::Int64,
                            TokioPgType::FLOAT4 => DataType::Float32,
                            TokioPgType::FLOAT8 => DataType::Float64,
                            TokioPgType::NUMERIC => DataType::Decimal,
                            TokioPgType::DATE => DataType::Date,
                            TokioPgType::TIME => DataType::Time,
                            TokioPgType::TIMESTAMP => DataType::Timestamp,
                            TokioPgType::TIMESTAMPTZ => DataType::Timestamptz,
                            TokioPgType::TEXT | TokioPgType::VARCHAR => DataType::Varchar,
                            TokioPgType::INTERVAL => DataType::Interval,
                            TokioPgType::JSONB => DataType::Jsonb,
                            TokioPgType::BYTEA => DataType::Bytea,
                            _ => {
                                return Err(crate::error::ErrorCode::BindError(format!(
                                    "unsupported column type: {}",
                                    column.type_()
                                ))
                                .into());
                            }
                        };
                        rw_types.push((name, data_type));
                    }
                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
                        StructType::new(rw_types),
                    ))
                })
            })?;

            Ok(TableFunction {
                args,
                return_type: schema,
                function_type: TableFunctionType::PostgresQuery,
                user_defined: None,
            })
        }
    }
463
464    pub fn new_mysql_query(
465        catalog_reader: &CatalogReadGuard,
466        db_name: &str,
467        schema_path: SchemaPath<'_>,
468        args: Vec<ExprImpl>,
469    ) -> RwResult<Self> {
470        let args = Self::handle_postgres_or_mysql_query_args(
471            catalog_reader,
472            db_name,
473            schema_path,
474            args,
475            "mysql-cdc",
476        )?;
477        let evaled_args = args
478            .iter()
479            .map(expr_impl_to_string_fn)
480            .collect::<RwResult<Vec<_>>>()?;
481
482        #[cfg(madsim)]
483        {
484            return Err(crate::error::ErrorCode::BindError(
485                "postgres_query can't be used in the madsim mode".to_string(),
486            )
487            .into());
488        }
489
490        #[cfg(not(madsim))]
491        {
492            let schema = tokio::task::block_in_place(|| {
493                FRONTEND_RUNTIME.block_on(async {
494                    let database_opts: mysql_async::Opts = {
495                        let port = evaled_args[1]
496                            .parse::<u16>()
497                            .context("failed to parse port")?;
498                        mysql_async::OptsBuilder::default()
499                            .ip_or_hostname(evaled_args[0].clone())
500                            .tcp_port(port)
501                            .user(Some(evaled_args[2].clone()))
502                            .pass(Some(evaled_args[3].clone()))
503                            .db_name(Some(evaled_args[4].clone()))
504                            .into()
505                    };
506
507                    let pool = mysql_async::Pool::new(database_opts);
508                    let mut conn = pool
509                        .get_conn()
510                        .await
511                        .context("failed to connect to mysql in binder")?;
512
513                    let query = evaled_args[5].clone();
514                    let statement = conn
515                        .prep(query)
516                        .await
517                        .context("failed to prepare mysql_query in binder")?;
518
519                    let mut rw_types = vec![];
520                    #[allow(clippy::never_loop)]
521                    for column in statement.columns() {
522                        let name = column.name_str().to_string();
523                        let data_type = match column.column_type() {
524                            // Boolean types
525                            MySqlColumnType::MYSQL_TYPE_BIT if column.column_length() == 1 => {
526                                DataType::Boolean
527                            }
528
529                            // Numeric types
530                            // NOTE(kwannoel): Although `bool/boolean` is a synonym of TINY(1) in MySQL,
531                            // we treat it as Int16 here. It is better to be straightforward in our conversion.
532                            MySqlColumnType::MYSQL_TYPE_TINY => DataType::Int16,
533                            MySqlColumnType::MYSQL_TYPE_SHORT => DataType::Int16,
534                            MySqlColumnType::MYSQL_TYPE_INT24 => DataType::Int32,
535                            MySqlColumnType::MYSQL_TYPE_LONG => DataType::Int32,
536                            MySqlColumnType::MYSQL_TYPE_LONGLONG => DataType::Int64,
537                            MySqlColumnType::MYSQL_TYPE_FLOAT => DataType::Float32,
538                            MySqlColumnType::MYSQL_TYPE_DOUBLE => DataType::Float64,
539                            MySqlColumnType::MYSQL_TYPE_NEWDECIMAL => DataType::Decimal,
540                            MySqlColumnType::MYSQL_TYPE_DECIMAL => DataType::Decimal,
541
542                            // Date time types
543                            MySqlColumnType::MYSQL_TYPE_YEAR => DataType::Int32,
544                            MySqlColumnType::MYSQL_TYPE_DATE => DataType::Date,
545                            MySqlColumnType::MYSQL_TYPE_NEWDATE => DataType::Date,
546                            MySqlColumnType::MYSQL_TYPE_TIME => DataType::Time,
547                            MySqlColumnType::MYSQL_TYPE_TIME2 => DataType::Time,
548                            MySqlColumnType::MYSQL_TYPE_DATETIME => DataType::Timestamp,
549                            MySqlColumnType::MYSQL_TYPE_DATETIME2 => DataType::Timestamp,
550                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP => DataType::Timestamptz,
551                            MySqlColumnType::MYSQL_TYPE_TIMESTAMP2 => DataType::Timestamptz,
552
553                            // String types
554                            MySqlColumnType::MYSQL_TYPE_VARCHAR => DataType::Varchar,
555                            // mysql_async does not have explicit `varbinary` and `binary` types,
556                            // we need to check the `ColumnFlags` to distinguish them.
557                            MySqlColumnType::MYSQL_TYPE_STRING
558                            | MySqlColumnType::MYSQL_TYPE_VAR_STRING => {
559                                if column
560                                    .flags()
561                                    .contains(mysql_common::constants::ColumnFlags::BINARY_FLAG)
562                                {
563                                    DataType::Bytea
564                                } else {
565                                    DataType::Varchar
566                                }
567                            }
568
569                            // JSON types
570                            MySqlColumnType::MYSQL_TYPE_JSON => DataType::Jsonb,
571
572                            // Binary types
573                            MySqlColumnType::MYSQL_TYPE_BIT
574                            | MySqlColumnType::MYSQL_TYPE_BLOB
575                            | MySqlColumnType::MYSQL_TYPE_TINY_BLOB
576                            | MySqlColumnType::MYSQL_TYPE_MEDIUM_BLOB
577                            | MySqlColumnType::MYSQL_TYPE_LONG_BLOB => DataType::Bytea,
578
579                            MySqlColumnType::MYSQL_TYPE_UNKNOWN
580                            | MySqlColumnType::MYSQL_TYPE_TYPED_ARRAY
581                            | MySqlColumnType::MYSQL_TYPE_ENUM
582                            | MySqlColumnType::MYSQL_TYPE_SET
583                            | MySqlColumnType::MYSQL_TYPE_GEOMETRY
584                            | MySqlColumnType::MYSQL_TYPE_VECTOR
585                            | MySqlColumnType::MYSQL_TYPE_NULL => {
586                                return Err(crate::error::ErrorCode::BindError(format!(
587                                    "unsupported column type: {:?}",
588                                    column.column_type()
589                                ))
590                                .into());
591                            }
592                        };
593                        rw_types.push((name, data_type));
594                    }
595                    Ok::<risingwave_common::types::DataType, anyhow::Error>(DataType::Struct(
596                        StructType::new(rw_types),
597                    ))
598                })
599            })?;
600
601            Ok(TableFunction {
602                args,
603                return_type: schema,
604                function_type: TableFunctionType::MysqlQuery,
605                user_defined: None,
606            })
607        }
608    }
609
610    /// This is a highly specific _internal_ table function meant to scan and aggregate
611    /// `backfill_table_id`, `row_count` for all MVs which are still being created.
612    pub fn new_internal_backfill_progress() -> Self {
613        TableFunction {
614            args: vec![],
615            return_type: DataType::Struct(StructType::new(vec![
616                ("job_id".to_owned(), DataType::Int32),
617                ("fragment_id".to_owned(), DataType::Int32),
618                ("backfill_state_table_id".to_owned(), DataType::Int32),
619                ("current_row_count".to_owned(), DataType::Int64),
620                ("min_epoch".to_owned(), DataType::Int64),
621            ])),
622            function_type: TableFunctionType::InternalBackfillProgress,
623            user_defined: None,
624        }
625    }
626
627    pub fn new_internal_source_backfill_progress() -> Self {
628        TableFunction {
629            args: vec![],
630            return_type: DataType::Struct(StructType::new(vec![
631                ("job_id".to_owned(), DataType::Int32),
632                ("fragment_id".to_owned(), DataType::Int32),
633                ("backfill_state_table_id".to_owned(), DataType::Int32),
634                ("partition_id".to_owned(), DataType::Varchar),
635                ("backfill_progress".to_owned(), DataType::Jsonb),
636            ])),
637            function_type: TableFunctionType::InternalSourceBackfillProgress,
638            user_defined: None,
639        }
640    }
641
642    pub fn new_internal_get_channel_delta_stats(args: Vec<ExprImpl>) -> Self {
643        Self {
644            args,
645            return_type: DataType::Struct(StructType::new(vec![
646                ("upstream_fragment_id".to_owned(), DataType::Int32),
647                ("downstream_fragment_id".to_owned(), DataType::Int32),
648                ("backpressure_rate".to_owned(), DataType::Float64),
649                ("recv_throughput".to_owned(), DataType::Float64),
650                ("send_throughput".to_owned(), DataType::Float64),
651            ])),
652            function_type: TableFunctionType::InternalGetChannelDeltaStats,
653            user_defined: None,
654        }
655    }
656
657    pub fn to_protobuf(&self) -> PbTableFunction {
658        PbTableFunction {
659            function_type: self.function_type as i32,
660            args: self.args.iter().map(|c| c.to_expr_proto()).collect_vec(),
661            return_type: Some(self.return_type.to_protobuf()),
662            udf: self.user_defined.as_ref().map(|c| c.as_ref().into()),
663        }
664    }
665
666    /// Serialize the table function. Returns an error if this will result in an impure table
667    /// function on a retract stream, which may lead to inconsistent results.
668    pub fn to_protobuf_checked_pure(&self, retract: bool) -> crate::error::Result<PbTableFunction> {
669        if retract {
670            reject_impure(self.clone(), "table function")?;
671        }
672
673        let args = self
674            .args
675            .iter()
676            .map(|arg| arg.to_expr_proto_checked_pure(retract, "table function argument"))
677            .collect::<crate::error::Result<Vec<_>>>()?;
678
679        Ok(PbTableFunction {
680            function_type: self.function_type as i32,
681            args,
682            return_type: Some(self.return_type.to_protobuf()),
683            udf: self.user_defined.as_ref().map(|c| c.as_ref().into()),
684        })
685    }
686
687    /// Get the name of the table function.
688    pub fn name(&self) -> String {
689        match self.function_type {
690            TableFunctionType::UserDefined => self.user_defined.as_ref().unwrap().name.clone(),
691            t => t.as_str_name().to_lowercase(),
692        }
693    }
694
695    pub fn rewrite(self, rewriter: &mut impl ExprRewriter) -> Self {
696        Self {
697            args: self
698                .args
699                .into_iter()
700                .map(|e| rewriter.rewrite_expr(e))
701                .collect(),
702            ..self
703        }
704    }
705}
706
707impl std::fmt::Debug for TableFunction {
708    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
709        if f.alternate() {
710            f.debug_struct("FunctionCall")
711                .field("function_type", &self.function_type)
712                .field("return_type", &self.return_type)
713                .field("args", &self.args)
714                .finish()
715        } else {
716            let func_name = format!("{:?}", self.function_type);
717            let mut builder = f.debug_tuple(&func_name);
718            self.args.iter().for_each(|child| {
719                builder.field(child);
720            });
721            builder.finish()
722        }
723    }
724}
725
726impl Expr for TableFunction {
727    fn return_type(&self) -> DataType {
728        self.return_type.clone()
729    }
730
731    fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
732        Err("Table function should not be converted to ExprNode".to_owned())
733    }
734}
735
736fn expr_impl_to_string_fn(arg: &ExprImpl) -> RwResult<String> {
737    match arg.try_fold_const() {
738        Some(Ok(value)) => {
739            let Some(scalar) = value else {
740                return Err(BindError(
741                    "postgres_query function and mysql_query function do not accept null arguments"
742                        .to_owned(),
743                )
744                .into());
745            };
746            Ok(scalar.into_utf8().to_string())
747        }
748        Some(Err(err)) => Err(err),
749        None => Err(BindError(
750            "postgres_query function and mysql_query function only accept constant arguments"
751                .to_owned(),
752        )
753        .into()),
754    }
755}