Skip to main content

risingwave_frontend/handler/
util.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::config::FrontendConfig;
32use risingwave_common::id::ObjectId;
33use risingwave_common::row::Row as _;
34use risingwave_common::types::{
35    DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
36};
37use risingwave_common::util::epoch::Epoch;
38use risingwave_common::util::iter_util::ZipEqFast;
39use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
40use risingwave_connector::sink::file_sink::fs::FS_SINK;
41use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
42use risingwave_connector::source::{BATCH_POSIX_FS_CONNECTOR, KAFKA_CONNECTOR, POSIX_FS_CONNECTOR};
43use risingwave_pb::catalog::connection_params::PbConnectionType;
44use risingwave_sqlparser::ast::{
45    CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
46    TableFactor, TableWithJoins,
47};
48use thiserror_ext::AsReport;
49use tokio::select;
50use tokio::time::{Duration, sleep};
51
52use crate::catalog::root_catalog::SchemaPath;
53use crate::error::ErrorCode::ProtocolError;
54use crate::error::{ErrorCode, Result as RwResult, RwError};
55use crate::session::{SessionImpl, current};
56use crate::{Binder, HashSet, TableCatalog};
57
58pub fn ensure_local_fs_connector_allowed(session: &SessionImpl, connector: &str) -> RwResult<()> {
59    let is_local_fs_connector = is_local_fs_connector(connector);
60
61    if !is_local_fs_connector
62        || is_local_fs_connector_enabled(session.env().frontend_config(), connector)
63    {
64        return Ok(());
65    }
66
67    Err(RwError::from(ProtocolError(format!(
68        "local filesystem connector '{}' is disabled. Set `frontend.unsafe_enable_local_fs_connector = true` in `risingwave.toml` to enable it.",
69        connector
70    ))))
71}
72
73fn is_local_fs_connector_enabled(frontend_config: &FrontendConfig, connector: &str) -> bool {
74    !is_local_fs_connector(connector) || frontend_config.unsafe_enable_local_fs_connector
75}
76
77fn is_local_fs_connector(connector: &str) -> bool {
78    connector.eq_ignore_ascii_case(POSIX_FS_CONNECTOR)
79        || connector.eq_ignore_ascii_case(BATCH_POSIX_FS_CONNECTOR)
80        || connector.eq_ignore_ascii_case(FS_SINK)
81}
82
pin_project! {
    /// Stream adapter that converts a stream of [`DataChunk`] into a stream of row sets,
    /// formatted according to the stored column types, formats and session data.
    ///
    /// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(chunk,
    /// format)))`. A dedicated struct is used because [`super::PgResponseStream`] needs a
    /// nameable type, and the type of a closure cannot be named.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Upstream source of chunks; marked `#[pin]` so it can be polled through the projection.
        #[pin]
        chunk_stream: VS,
        // Data type of every column in the chunks; must match the chunk dimension.
        column_types: Vec<DataType>,
        // Requested output format(s) for the columns, handed to `FormatIterator` per row.
        pub formats: Vec<Format>,
        // Session settings captured when the adapter was built.
        session_data: StaticSessionData,
    }
}
101
/// Static session data frozen at the time of the creation of the stream.
pub struct StaticSessionData {
    /// Time zone name used when rendering `Timestamptz` values in text format.
    pub timezone: String,
}
106
107impl<VS> DataChunkToRowSetAdapter<VS>
108where
109    VS: Stream<Item = Result<DataChunk, BoxedError>>,
110{
111    pub fn new(
112        chunk_stream: VS,
113        column_types: Vec<DataType>,
114        formats: Vec<Format>,
115        session: Arc<SessionImpl>,
116    ) -> Self {
117        let session_data = StaticSessionData {
118            timezone: session.config().timezone(),
119        };
120        Self {
121            chunk_stream,
122            column_types,
123            formats,
124            session_data,
125        }
126    }
127}
128
129impl<VS> Stream for DataChunkToRowSetAdapter<VS>
130where
131    VS: Stream<Item = Result<DataChunk, BoxedError>>,
132{
133    type Item = RowSetResult;
134
135    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136        let mut this = self.project();
137        match this.chunk_stream.as_mut().poll_next(cx) {
138            Poll::Pending => Poll::Pending,
139            Poll::Ready(chunk) => match chunk {
140                Some(chunk_result) => match chunk_result {
141                    Ok(chunk) => Poll::Ready(Some(
142                        to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
143                            .map_err(|err| err.into()),
144                    )),
145                    Err(err) => Poll::Ready(Some(Err(err))),
146                },
147                None => Poll::Ready(None),
148            },
149        }
150    }
151}
152
153/// Format scalars according to postgres convention.
154pub fn pg_value_format(
155    data_type: &DataType,
156    d: ScalarRefImpl<'_>,
157    format: Format,
158    session_data: &StaticSessionData,
159) -> RwResult<Bytes> {
160    // format == false means TEXT format
161    // format == true means BINARY format
162    match format {
163        Format::Text => {
164            if *data_type == DataType::Timestamptz {
165                Ok(timestamptz_to_string_with_session_data(d, session_data))
166            } else {
167                Ok(d.text_format(data_type).into())
168            }
169        }
170        Format::Binary => Ok(d
171            .binary_format(data_type)
172            .context("failed to format binary value")?),
173    }
174}
175
176fn timestamptz_to_string_with_session_data(
177    d: ScalarRefImpl<'_>,
178    session_data: &StaticSessionData,
179) -> Bytes {
180    let tz = d.into_timestamptz();
181    let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
182    let instant_local = tz.to_datetime_in_zone(time_zone);
183    let mut result_string = BytesMut::new();
184    write_date_time_tz(instant_local, &mut result_string).unwrap();
185    result_string.into()
186}
187
188fn to_pg_rows(
189    column_types: &[DataType],
190    chunk: DataChunk,
191    formats: &[Format],
192    session_data: &StaticSessionData,
193) -> RwResult<Vec<Row>> {
194    assert_eq!(chunk.dimension(), column_types.len());
195    if cfg!(debug_assertions) {
196        let chunk_data_types = chunk.data_types();
197        for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
198            debug_assert!(
199                ty1.equals_datatype(ty2),
200                "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
201            )
202        }
203    }
204
205    chunk
206        .rows()
207        .map(|r| {
208            let format_iter = FormatIterator::new(formats, chunk.dimension())
209                .map_err(ErrorCode::InternalError)?;
210            let row = r
211                .iter()
212                .zip_eq_fast(column_types)
213                .zip_eq_fast(format_iter)
214                .map(|((data, t), format)| match data {
215                    Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
216                    None => Ok(None),
217                })
218                .try_collect()?;
219            Ok(Row::new(row))
220        })
221        .try_collect()
222}
223
224/// Convert from [`Field`] to [`PgFieldDescriptor`].
225pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
226    PgFieldDescriptor::new(
227        f.name.clone(),
228        f.data_type().to_oid(),
229        f.data_type().type_len(),
230    )
231}
232
233#[easy_ext::ext(SourceSchemaCompatExt)]
234impl CompatibleFormatEncode {
235    /// Convert `self` to [`FormatEncodeOptions`] and warn the user if the syntax is deprecated.
236    pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
237        match self {
238            CompatibleFormatEncode::RowFormat(inner) => {
239                // TODO: should be warning
240                current::notice_to_user(
241                    "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
242                );
243                inner.into_format_encode_v2()
244            }
245            CompatibleFormatEncode::V2(inner) => inner,
246        }
247    }
248}
249
250pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
251    let table_factor = TableFactor::Table {
252        name: from_name,
253        alias: None,
254        as_of: None,
255    };
256    let from = vec![TableWithJoins {
257        relation: table_factor,
258        joins: vec![],
259    }];
260    let select = Select {
261        from,
262        projection: vec![SelectItem::Wildcard(None)],
263        ..Default::default()
264    };
265    let body = SetExpr::Select(Box::new(select));
266    Query {
267        with: None,
268        body,
269        order_by: vec![],
270        limit: None,
271        offset: None,
272        fetch: None,
273    }
274}
275
276pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
277    Epoch::from_unix_millis(unix_millis).0
278}
279
280pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
281    Epoch::from(logstore_u64).as_unix_millis()
282}
283
284pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
285    let seconds = (Interval::from_str(interval)
286        .map_err(|err| {
287            ErrorCode::InternalError(format!(
288                "Convert interval to u64 error, please check format, error: {:?}",
289                err.to_report_string()
290            ))
291        })?
292        .epoch_in_micros()
293        / 1000000) as u64;
294    Ok(seconds)
295}
296
297pub fn ensure_connection_type_allowed(
298    connection_type: PbConnectionType,
299    allowed_types: &HashSet<PbConnectionType>,
300) -> RwResult<()> {
301    if !allowed_types.contains(&connection_type) {
302        return Err(RwError::from(ProtocolError(format!(
303            "connection type {:?} is not allowed, allowed types: {:?}",
304            connection_type, allowed_types
305        ))));
306    }
307    Ok(())
308}
309
310fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
311    match connection_type {
312        PbConnectionType::Kafka => KAFKA_CONNECTOR,
313        PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
314        PbConnectionType::Elasticsearch => ES_SINK,
315        _ => unreachable!(),
316    }
317}
318
319pub fn check_connector_match_connection_type(
320    connector: &str,
321    connection_type: &PbConnectionType,
322) -> RwResult<()> {
323    if !connector.eq(connection_type_to_connector(connection_type)) {
324        return Err(RwError::from(ProtocolError(format!(
325            "connector {} and connection type {:?} are not compatible",
326            connector, connection_type
327        ))));
328    }
329    Ok(())
330}
331
332pub fn get_table_catalog_by_table_name(
333    session: &SessionImpl,
334    table_name: &ObjectName,
335) -> RwResult<(Arc<TableCatalog>, String)> {
336    let db_name = &session.database();
337    let (schema_name, real_table_name) =
338        Binder::resolve_schema_qualified_name(db_name, table_name)?;
339    let search_path = session.config().search_path();
340    let user_name = &session.user_name();
341
342    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
343    let reader = session.env().catalog_reader().read_guard();
344    match reader.get_created_table_by_name(db_name, schema_path, &real_table_name) {
345        Ok((table, schema_name)) => Ok((table.clone(), schema_name.to_owned())),
346        Err(err) => {
347            if let Some(table) = session
348                .staging_catalog_manager()
349                .get_table(&real_table_name)
350            {
351                // During CREATE TABLE (iceberg engine), the table is only in staging, but we
352                // still need a stable schema name for internal sink planning.
353                let schema_name = reader
354                    .get_schema_by_id(table.database_id, table.schema_id)
355                    .map(|schema| schema.name.clone())?;
356                Ok((Arc::new(table.clone()), schema_name))
357            } else {
358                Err(err.into())
359            }
360        }
361    }
362}
363
364pub fn reject_internal_table_dependency(
365    table: &TableCatalog,
366    statement_name: &str,
367) -> RwResult<()> {
368    if table.is_internal_table() {
369        return Err(RwError::from(ErrorCode::InvalidInputSyntax(format!(
370            "{statement_name} does not support internal table \"{}\"",
371            table.name()
372        ))));
373    }
374    Ok(())
375}
376
377pub fn reject_internal_table_dependencies<'a, I>(
378    session: &SessionImpl,
379    dependent_relations: I,
380    statement_name: &str,
381) -> RwResult<()>
382where
383    I: IntoIterator<Item = &'a ObjectId>,
384{
385    let catalog_reader = session.env().catalog_reader().read_guard();
386    for object_id in dependent_relations {
387        if let Ok(table) = catalog_reader.get_any_table_by_id(object_id.as_table_id()) {
388            reject_internal_table_dependency(table.as_ref(), statement_name)?;
389        }
390    }
391    Ok(())
392}
393
394/// Execute an async operation with a notification if it takes too long.
395/// This is useful for operations that might be delayed due to high barrier latency.
396///
397/// The notification timeout duration is controlled by the `slow_ddl_notification_secs` session variable.
398///
399/// # Arguments
400/// * `operation_fut` - The async operation to execute
401/// * `session` - The session to send notifications to
402/// * `operation_name` - The name of the operation for the notification message (e.g., "DROP TABLE")
403///
404/// # Example
405/// ```ignore
406/// execute_with_long_running_notification(
407///     catalog_writer.drop_table(source_id, table_id, cascade),
408///     &session,
409///     "DROP TABLE",
410/// ).await?;
411/// ```
412#[derive(Clone, Copy)]
413pub enum LongRunningNotificationAction {
414    SuggestRecover,
415    DiagnoseBarrierLatency,
416    MonitorBackfillJob,
417}
418
419impl LongRunningNotificationAction {
420    fn build_message(self, operation_name: &str, notify_timeout_secs: u32) -> String {
421        match self {
422            LongRunningNotificationAction::SuggestRecover => format!(
423                "{} has taken more than {} secs, likely due to high barrier latency.\n\
424                You may trigger cluster recovery to let {} take effect immediately.\n\
425                Run RECOVER in a separate session to trigger recovery.\n\
426                See: https://docs.risingwave.com/sql/commands/sql-recover#recover",
427                operation_name, notify_timeout_secs, operation_name
428            ),
429            LongRunningNotificationAction::DiagnoseBarrierLatency => format!(
430                "{} has taken more than {} secs, likely due to high barrier latency.\n\
431                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring for steps to diagnose high barrier latency.",
432                operation_name, notify_timeout_secs
433            ),
434            LongRunningNotificationAction::MonitorBackfillJob => format!(
435                "{} has taken more than {} secs, barrier latency might be high. Please check barrier latency metrics to confirm.\n\
436                You can also run SHOW JOBS to track the progress of the job.\n\
437                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring and https://docs.risingwave.com/sql/commands/sql-show-jobs",
438                operation_name, notify_timeout_secs
439            ),
440        }
441    }
442}
443
444pub async fn execute_with_long_running_notification<F, T>(
445    operation_fut: F,
446    session: &SessionImpl,
447    operation_name: &str,
448    action: LongRunningNotificationAction,
449) -> RwResult<T>
450where
451    F: std::future::Future<Output = RwResult<T>>,
452{
453    let notify_timeout_secs = session.config().slow_ddl_notification_secs();
454
455    // If timeout is 0, disable notifications and just execute the operation
456    if notify_timeout_secs == 0 {
457        return operation_fut.await;
458    }
459
460    let notify_fut = sleep(Duration::from_secs(notify_timeout_secs as u64));
461    tokio::pin!(operation_fut);
462
463    select! {
464        _ = notify_fut => {
465            session.notice_to_user(action.build_message(operation_name, notify_timeout_secs));
466            operation_fut.await
467        }
468        result = &mut operation_fut => {
469            result
470        }
471    }
472}
473
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    /// Session data with the time zone fixed to UTC, shared by the formatting tests.
    fn utc_session() -> StaticSessionData {
        StaticSessionData {
            timezone: "UTC".into(),
        }
    }

    /// Column types passed to `to_pg_rows` alongside the `i I f T` pretty chunks.
    fn sample_column_types() -> [DataType; 4] {
        [
            DataType::Int32,
            DataType::Int64,
            DataType::Float32,
            DataType::Varchar,
        ]
    }

    /// Copies the values out of each row so they can be compared with plain vectors.
    fn into_values(rows: Vec<pgwire::types::Row>) -> Vec<Vec<Option<Bytes>>> {
        rows.into_iter()
            .map(|row| row.values().iter().cloned().collect())
            .collect()
    }

    /// Encodes `value` with `postgres_types` to obtain the expected binary bytes.
    fn binary<V: ToSql>(value: V) -> Bytes {
        let mut buf = BytesMut::new();
        value.to_sql(&Type::ANY, &mut buf).unwrap();
        buf.freeze()
    }

    #[test]
    fn test_to_pg_field() {
        let descriptor = to_pg_field(&Field::with_name(DataType::Int32, "v1"));
        assert_eq!(descriptor.get_name(), "v1");
        assert_eq!(descriptor.get_type_oid(), DataType::Int32.to_oid());
    }

    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
             2 . .    .
             3 7 7.01 vvv
             4 . .    .  ",
        );
        // An empty format list is passed, and the values come back in text form.
        let rows = to_pg_rows(&sample_column_types(), chunk, &[], &utc_session());

        let text = |s: &'static str| Some(Bytes::from(s));
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![text("1"), text("6"), text("6.01"), text("aaa")],
            vec![text("2"), None, None, None],
            vec![text("3"), text("7"), text("7.01"), text("vvv")],
            vec![text("4"), None, None, None],
        ];

        assert_eq!(into_values(rows.unwrap()), expected);
    }

    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
            ",
        );
        let formats = [Format::Binary, Format::Binary, Format::Binary, Format::Text];
        let rows = to_pg_rows(&sample_column_types(), chunk, &formats, &utc_session());

        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(binary(1_i32)),
            Some(binary(6_i64)),
            Some(binary(6.01_f32)),
            Some("aaa".into()),
        ]];

        assert_eq!(into_values(rows.unwrap()), expected);
    }

    #[test]
    fn test_value_format() {
        use DataType as T;
        use ScalarRefImpl as S;
        let session = utc_session();

        // (data type, value, expected text rendering)
        let cases = [
            (T::Float32, S::Float32(1_f32.into()), "1"),
            (T::Float32, S::Float32(f32::NAN.into()), "NaN"),
            (T::Float64, S::Float64(f64::NAN.into()), "NaN"),
            (T::Float32, S::Float32(f32::INFINITY.into()), "Infinity"),
            (T::Float32, S::Float32(f32::NEG_INFINITY.into()), "-Infinity"),
            (T::Float64, S::Float64(f64::INFINITY.into()), "Infinity"),
            (T::Float64, S::Float64(f64::NEG_INFINITY.into()), "-Infinity"),
            (T::Boolean, S::Bool(true), "t"),
            (T::Boolean, S::Bool(false), "f"),
            (
                T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                "1969-12-31 23:59:59.999999+00:00",
            ),
        ];

        for (ty, datum, expected) in cases {
            let actual = pg_value_format(&ty, datum, Format::Text, &session).unwrap();
            assert_eq!(&actual, expected);
        }
    }

    #[test]
    fn test_local_fs_connector_config_gate() {
        let local_fs_connectors = [POSIX_FS_CONNECTOR, BATCH_POSIX_FS_CONNECTOR, FS_SINK];

        // Default config: the test expects the gate to follow `debug_assertions`.
        let default_config = FrontendConfig::default();
        let default_enabled = cfg!(debug_assertions);
        for connector in local_fs_connectors {
            assert_eq!(
                is_local_fs_connector_enabled(&default_config, connector),
                default_enabled
            );
        }
        // A connector that is not local-fs is enabled regardless of the switch.
        assert!(is_local_fs_connector_enabled(
            &default_config,
            KAFKA_CONNECTOR
        ));

        let disabled_config = FrontendConfig {
            unsafe_enable_local_fs_connector: false,
            ..Default::default()
        };
        for connector in local_fs_connectors {
            assert!(!is_local_fs_connector_enabled(&disabled_config, connector));
        }

        let enabled_config = FrontendConfig {
            unsafe_enable_local_fs_connector: true,
            ..Default::default()
        };
        for connector in local_fs_connectors {
            assert!(is_local_fs_connector_enabled(&enabled_config, connector));
        }
    }
}