risingwave_frontend/handler/
util.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::id::ObjectId;
32use risingwave_common::row::Row as _;
33use risingwave_common::types::{
34    DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
35};
36use risingwave_common::util::epoch::Epoch;
37use risingwave_common::util::iter_util::ZipEqFast;
38use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
39use risingwave_connector::source::KAFKA_CONNECTOR;
40use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
41use risingwave_pb::catalog::connection_params::PbConnectionType;
42use risingwave_sqlparser::ast::{
43    CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
44    TableFactor, TableWithJoins,
45};
46use thiserror_ext::AsReport;
47use tokio::select;
48use tokio::time::{Duration, sleep};
49
50use crate::catalog::root_catalog::SchemaPath;
51use crate::error::ErrorCode::ProtocolError;
52use crate::error::{ErrorCode, Result as RwResult, RwError};
53use crate::session::{SessionImpl, current};
54use crate::{Binder, HashSet, TableCatalog};
55
pin_project! {
    /// Wrapper struct that converts a stream of DataChunk to a stream of RowSet based on formatting
    /// parameters.
    ///
    /// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(chunk,
    /// format)))` but we need a nameable type as part of [`super::PgResponseStream`], but we cannot
    /// name the type of a closure.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Underlying stream of data chunks; pinned because we poll it in `poll_next`.
        #[pin]
        chunk_stream: VS,
        // Logical type of each output column, used to pick the text/binary encoder.
        column_types: Vec<DataType>,
        // Per-column wire formats requested by the client (empty = default for all).
        pub formats: Vec<Format>,
        // Session settings (e.g. timezone) captured once at stream creation.
        session_data: StaticSessionData,
    }
}
74
// Static session data frozen at the time of the creation of the stream
pub struct StaticSessionData {
    // Session timezone name; used when rendering `timestamptz` values as text.
    pub timezone: String,
}
79
80impl<VS> DataChunkToRowSetAdapter<VS>
81where
82    VS: Stream<Item = Result<DataChunk, BoxedError>>,
83{
84    pub fn new(
85        chunk_stream: VS,
86        column_types: Vec<DataType>,
87        formats: Vec<Format>,
88        session: Arc<SessionImpl>,
89    ) -> Self {
90        let session_data = StaticSessionData {
91            timezone: session.config().timezone(),
92        };
93        Self {
94            chunk_stream,
95            column_types,
96            formats,
97            session_data,
98        }
99    }
100}
101
102impl<VS> Stream for DataChunkToRowSetAdapter<VS>
103where
104    VS: Stream<Item = Result<DataChunk, BoxedError>>,
105{
106    type Item = RowSetResult;
107
108    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
109        let mut this = self.project();
110        match this.chunk_stream.as_mut().poll_next(cx) {
111            Poll::Pending => Poll::Pending,
112            Poll::Ready(chunk) => match chunk {
113                Some(chunk_result) => match chunk_result {
114                    Ok(chunk) => Poll::Ready(Some(
115                        to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
116                            .map_err(|err| err.into()),
117                    )),
118                    Err(err) => Poll::Ready(Some(Err(err))),
119                },
120                None => Poll::Ready(None),
121            },
122        }
123    }
124}
125
126/// Format scalars according to postgres convention.
127pub fn pg_value_format(
128    data_type: &DataType,
129    d: ScalarRefImpl<'_>,
130    format: Format,
131    session_data: &StaticSessionData,
132) -> RwResult<Bytes> {
133    // format == false means TEXT format
134    // format == true means BINARY format
135    match format {
136        Format::Text => {
137            if *data_type == DataType::Timestamptz {
138                Ok(timestamptz_to_string_with_session_data(d, session_data))
139            } else {
140                Ok(d.text_format(data_type).into())
141            }
142        }
143        Format::Binary => Ok(d
144            .binary_format(data_type)
145            .context("failed to format binary value")?),
146    }
147}
148
149fn timestamptz_to_string_with_session_data(
150    d: ScalarRefImpl<'_>,
151    session_data: &StaticSessionData,
152) -> Bytes {
153    let tz = d.into_timestamptz();
154    let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
155    let instant_local = tz.to_datetime_in_zone(time_zone);
156    let mut result_string = BytesMut::new();
157    write_date_time_tz(instant_local, &mut result_string).unwrap();
158    result_string.into()
159}
160
161fn to_pg_rows(
162    column_types: &[DataType],
163    chunk: DataChunk,
164    formats: &[Format],
165    session_data: &StaticSessionData,
166) -> RwResult<Vec<Row>> {
167    assert_eq!(chunk.dimension(), column_types.len());
168    if cfg!(debug_assertions) {
169        let chunk_data_types = chunk.data_types();
170        for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
171            debug_assert!(
172                ty1.equals_datatype(ty2),
173                "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
174            )
175        }
176    }
177
178    chunk
179        .rows()
180        .map(|r| {
181            let format_iter = FormatIterator::new(formats, chunk.dimension())
182                .map_err(ErrorCode::InternalError)?;
183            let row = r
184                .iter()
185                .zip_eq_fast(column_types)
186                .zip_eq_fast(format_iter)
187                .map(|((data, t), format)| match data {
188                    Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
189                    None => Ok(None),
190                })
191                .try_collect()?;
192            Ok(Row::new(row))
193        })
194        .try_collect()
195}
196
197/// Convert from [`Field`] to [`PgFieldDescriptor`].
198pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
199    PgFieldDescriptor::new(
200        f.name.clone(),
201        f.data_type().to_oid(),
202        f.data_type().type_len(),
203    )
204}
205
206#[easy_ext::ext(SourceSchemaCompatExt)]
207impl CompatibleFormatEncode {
208    /// Convert `self` to [`FormatEncodeOptions`] and warn the user if the syntax is deprecated.
209    pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
210        match self {
211            CompatibleFormatEncode::RowFormat(inner) => {
212                // TODO: should be warning
213                current::notice_to_user(
214                    "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
215                );
216                inner.into_format_encode_v2()
217            }
218            CompatibleFormatEncode::V2(inner) => inner,
219        }
220    }
221}
222
223pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
224    let table_factor = TableFactor::Table {
225        name: from_name,
226        alias: None,
227        as_of: None,
228    };
229    let from = vec![TableWithJoins {
230        relation: table_factor,
231        joins: vec![],
232    }];
233    let select = Select {
234        from,
235        projection: vec![SelectItem::Wildcard(None)],
236        ..Default::default()
237    };
238    let body = SetExpr::Select(Box::new(select));
239    Query {
240        with: None,
241        body,
242        order_by: vec![],
243        limit: None,
244        offset: None,
245        fetch: None,
246    }
247}
248
/// Convert a Unix timestamp in milliseconds into the log-store `u64` epoch
/// encoding (the raw inner value of an [`Epoch`]).
pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
    Epoch::from_unix_millis(unix_millis).0
}
252
/// Inverse of [`convert_unix_millis_to_logstore_u64`]: recover the Unix
/// timestamp in milliseconds from a log-store `u64` epoch value.
pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
    Epoch::from(logstore_u64).as_unix_millis()
}
256
257pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
258    let seconds = (Interval::from_str(interval)
259        .map_err(|err| {
260            ErrorCode::InternalError(format!(
261                "Convert interval to u64 error, please check format, error: {:?}",
262                err.to_report_string()
263            ))
264        })?
265        .epoch_in_micros()
266        / 1000000) as u64;
267    Ok(seconds)
268}
269
270pub fn ensure_connection_type_allowed(
271    connection_type: PbConnectionType,
272    allowed_types: &HashSet<PbConnectionType>,
273) -> RwResult<()> {
274    if !allowed_types.contains(&connection_type) {
275        return Err(RwError::from(ProtocolError(format!(
276            "connection type {:?} is not allowed, allowed types: {:?}",
277            connection_type, allowed_types
278        ))));
279    }
280    Ok(())
281}
282
/// Map a connection type to the name of the only connector it may serve.
///
/// # Panics
/// Panics via `unreachable!` for any other connection type — presumably callers
/// have already restricted the type (e.g. via [`ensure_connection_type_allowed`]);
/// confirm before passing new variants here.
fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
    match connection_type {
        PbConnectionType::Kafka => KAFKA_CONNECTOR,
        PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
        PbConnectionType::Elasticsearch => ES_SINK,
        _ => unreachable!(),
    }
}
291
292pub fn check_connector_match_connection_type(
293    connector: &str,
294    connection_type: &PbConnectionType,
295) -> RwResult<()> {
296    if !connector.eq(connection_type_to_connector(connection_type)) {
297        return Err(RwError::from(ProtocolError(format!(
298            "connector {} and connection type {:?} are not compatible",
299            connector, connection_type
300        ))));
301    }
302    Ok(())
303}
304
/// Resolve `table_name` against the session's database and search path, and
/// return the table catalog along with the name of the schema it lives in.
///
/// If the table is not found among created tables, falls back to the staging
/// catalog (tables mid-creation, e.g. the iceberg engine path); otherwise the
/// original lookup error is returned.
pub fn get_table_catalog_by_table_name(
    session: &SessionImpl,
    table_name: &ObjectName,
) -> RwResult<(Arc<TableCatalog>, String)> {
    let db_name = &session.database();
    let (schema_name, real_table_name) =
        Binder::resolve_schema_qualified_name(db_name, table_name)?;
    let search_path = session.config().search_path();
    let user_name = &session.user_name();

    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
    let reader = session.env().catalog_reader().read_guard();
    match reader.get_created_table_by_name(db_name, schema_path, &real_table_name) {
        Ok((table, schema_name)) => Ok((table.clone(), schema_name.to_owned())),
        Err(err) => {
            // Not a created table — check the staging catalog before giving up.
            if let Some(table) = session
                .staging_catalog_manager()
                .get_table(&real_table_name)
            {
                // During CREATE TABLE (iceberg engine), the table is only in staging, but we
                // still need a stable schema name for internal sink planning.
                let schema_name = reader
                    .get_schema_by_id(table.database_id, table.schema_id)
                    .map(|schema| schema.name.clone())?;
                Ok((Arc::new(table.clone()), schema_name))
            } else {
                // Not staged either: surface the original lookup error.
                Err(err.into())
            }
        }
    }
}
336
337pub fn reject_internal_table_dependency(
338    table: &TableCatalog,
339    statement_name: &str,
340) -> RwResult<()> {
341    if table.is_internal_table() {
342        return Err(RwError::from(ErrorCode::InvalidInputSyntax(format!(
343            "{statement_name} does not support internal table \"{}\"",
344            table.name()
345        ))));
346    }
347    Ok(())
348}
349
350pub fn reject_internal_table_dependencies<'a, I>(
351    session: &SessionImpl,
352    dependent_relations: I,
353    statement_name: &str,
354) -> RwResult<()>
355where
356    I: IntoIterator<Item = &'a ObjectId>,
357{
358    let catalog_reader = session.env().catalog_reader().read_guard();
359    for object_id in dependent_relations {
360        if let Ok(table) = catalog_reader.get_any_table_by_id(object_id.as_table_id()) {
361            reject_internal_table_dependency(table.as_ref(), statement_name)?;
362        }
363    }
364    Ok(())
365}
366
/// The remediation hint attached to the "operation is taking long" notice sent
/// by `execute_with_long_running_notification`.
///
/// (The doc comment previously here described that function — with an outdated
/// three-argument example — rather than this enum.)
#[derive(Clone, Copy)]
pub enum LongRunningNotificationAction {
    /// Suggest running `RECOVER` so the pending DDL takes effect immediately.
    SuggestRecover,
    /// Point the user at barrier-latency diagnostics documentation.
    DiagnoseBarrierLatency,
    /// Suggest tracking job progress with `SHOW JOBS` alongside barrier metrics.
    MonitorBackfillJob,
}
391
impl LongRunningNotificationAction {
    /// Render the user-facing notice for `operation_name` having run longer
    /// than `notify_timeout_secs`, with remediation advice specific to `self`.
    fn build_message(self, operation_name: &str, notify_timeout_secs: u32) -> String {
        match self {
            // DDL stuck behind a barrier: suggest RECOVER to apply it immediately.
            LongRunningNotificationAction::SuggestRecover => format!(
                "{} has taken more than {} secs, likely due to high barrier latency.\n\
                You may trigger cluster recovery to let {} take effect immediately.\n\
                Run RECOVER in a separate session to trigger recovery.\n\
                See: https://docs.risingwave.com/sql/commands/sql-recover#recover",
                operation_name, notify_timeout_secs, operation_name
            ),
            // Generic slow DDL: point at barrier-latency diagnostics docs.
            LongRunningNotificationAction::DiagnoseBarrierLatency => format!(
                "{} has taken more than {} secs, likely due to high barrier latency.\n\
                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring for steps to diagnose high barrier latency.",
                operation_name, notify_timeout_secs
            ),
            // Backfilling job: suggest SHOW JOBS to monitor progress.
            LongRunningNotificationAction::MonitorBackfillJob => format!(
                "{} has taken more than {} secs, barrier latency might be high. Please check barrier latency metrics to confirm.\n\
                You can also run SHOW JOBS to track the progress of the job.\n\
                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring and https://docs.risingwave.com/sql/commands/sql-show-jobs",
                operation_name, notify_timeout_secs
            ),
        }
    }
}
416
/// Execute an async operation, sending a notice to the user if it takes too
/// long — useful for operations that might be delayed by high barrier latency.
///
/// The notice timeout is the `slow_ddl_notification_secs` session variable;
/// `0` disables the notice entirely. The operation is never cancelled: once
/// the notice fires, we keep awaiting the operation to completion.
///
/// # Arguments
/// * `operation_fut` - the async operation to execute
/// * `session` - the session to send the notice to
/// * `operation_name` - name used in the notice message (e.g. "DROP TABLE")
/// * `action` - which remediation hint to include in the notice
pub async fn execute_with_long_running_notification<F, T>(
    operation_fut: F,
    session: &SessionImpl,
    operation_name: &str,
    action: LongRunningNotificationAction,
) -> RwResult<T>
where
    F: std::future::Future<Output = RwResult<T>>,
{
    let notify_timeout_secs = session.config().slow_ddl_notification_secs();

    // If timeout is 0, disable notifications and just execute the operation
    if notify_timeout_secs == 0 {
        return operation_fut.await;
    }

    let notify_fut = sleep(Duration::from_secs(notify_timeout_secs as u64));
    // Pin so the future can be polled again after losing the select! race.
    tokio::pin!(operation_fut);

    select! {
        _ = notify_fut => {
            session.notice_to_user(action.build_message(operation_name, notify_timeout_secs));
            // Notice sent; continue waiting for the operation itself.
            operation_fut.await
        }
        result = &mut operation_fut => {
            result
        }
    }
}
446
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    #[test]
    fn test_to_pg_rows() {
        // `.` cells in the pretty format denote NULLs.
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
             2 . .    .
             3 7 7.01 vvv
             4 . .    .  ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        // Empty `formats` slice: all columns use the default (text) format,
        // so the expected cells below are plain strings.
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        // First three columns binary, last one text.
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // Build the expected binary cells using the postgres binary encoding.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        // Special float values use postgres spellings: NaN / Infinity / -Infinity.
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        // Booleans render as 't' / 'f', per postgres text format.
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        // Timestamptz is rendered in the session timezone (UTC here).
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}