risingwave_frontend/handler/
util.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33    DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42    CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
43    TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46use tokio::select;
47use tokio::time::{Duration, sleep};
48
49use crate::catalog::root_catalog::SchemaPath;
50use crate::error::ErrorCode::ProtocolError;
51use crate::error::{ErrorCode, Result as RwResult, RwError};
52use crate::session::{SessionImpl, current};
53use crate::{Binder, HashSet, TableCatalog};
54
pin_project! {
    /// Wrapper struct that converts a stream of DataChunk to a stream of RowSet based on formatting
    /// parameters.
    ///
    /// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(chunk,
    /// format)))`, but we need a nameable type as part of [`super::PgResponseStream`], and the
    /// type of a closure cannot be named.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Upstream stream of data chunks (or errors); pinned because it is polled in place.
        #[pin]
        chunk_stream: VS,
        // Data type of every column; handed to `to_pg_rows` for each chunk.
        column_types: Vec<DataType>,
        // Requested wire formats; handed to `to_pg_rows` for each chunk.
        pub formats: Vec<Format>,
        // Session values captured once in `new` and reused for every chunk.
        session_data: StaticSessionData,
    }
}
73
/// Static session data frozen at the time of the creation of the stream.
pub struct StaticSessionData {
    /// Time zone name used when rendering `Timestamptz` values in text format; it is resolved
    /// through `Timestamptz::lookup_time_zone` at formatting time.
    pub timezone: String,
}
78
79impl<VS> DataChunkToRowSetAdapter<VS>
80where
81    VS: Stream<Item = Result<DataChunk, BoxedError>>,
82{
83    pub fn new(
84        chunk_stream: VS,
85        column_types: Vec<DataType>,
86        formats: Vec<Format>,
87        session: Arc<SessionImpl>,
88    ) -> Self {
89        let session_data = StaticSessionData {
90            timezone: session.config().timezone(),
91        };
92        Self {
93            chunk_stream,
94            column_types,
95            formats,
96            session_data,
97        }
98    }
99}
100
101impl<VS> Stream for DataChunkToRowSetAdapter<VS>
102where
103    VS: Stream<Item = Result<DataChunk, BoxedError>>,
104{
105    type Item = RowSetResult;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        let mut this = self.project();
109        match this.chunk_stream.as_mut().poll_next(cx) {
110            Poll::Pending => Poll::Pending,
111            Poll::Ready(chunk) => match chunk {
112                Some(chunk_result) => match chunk_result {
113                    Ok(chunk) => Poll::Ready(Some(
114                        to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
115                            .map_err(|err| err.into()),
116                    )),
117                    Err(err) => Poll::Ready(Some(Err(err))),
118                },
119                None => Poll::Ready(None),
120            },
121        }
122    }
123}
124
125/// Format scalars according to postgres convention.
126pub fn pg_value_format(
127    data_type: &DataType,
128    d: ScalarRefImpl<'_>,
129    format: Format,
130    session_data: &StaticSessionData,
131) -> RwResult<Bytes> {
132    // format == false means TEXT format
133    // format == true means BINARY format
134    match format {
135        Format::Text => {
136            if *data_type == DataType::Timestamptz {
137                Ok(timestamptz_to_string_with_session_data(d, session_data))
138            } else {
139                Ok(d.text_format(data_type).into())
140            }
141        }
142        Format::Binary => Ok(d
143            .binary_format(data_type)
144            .context("failed to format binary value")?),
145    }
146}
147
148fn timestamptz_to_string_with_session_data(
149    d: ScalarRefImpl<'_>,
150    session_data: &StaticSessionData,
151) -> Bytes {
152    let tz = d.into_timestamptz();
153    let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
154    let instant_local = tz.to_datetime_in_zone(time_zone);
155    let mut result_string = BytesMut::new();
156    write_date_time_tz(instant_local, &mut result_string).unwrap();
157    result_string.into()
158}
159
160fn to_pg_rows(
161    column_types: &[DataType],
162    chunk: DataChunk,
163    formats: &[Format],
164    session_data: &StaticSessionData,
165) -> RwResult<Vec<Row>> {
166    assert_eq!(chunk.dimension(), column_types.len());
167    if cfg!(debug_assertions) {
168        let chunk_data_types = chunk.data_types();
169        for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
170            debug_assert!(
171                ty1.equals_datatype(ty2),
172                "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
173            )
174        }
175    }
176
177    chunk
178        .rows()
179        .map(|r| {
180            let format_iter = FormatIterator::new(formats, chunk.dimension())
181                .map_err(ErrorCode::InternalError)?;
182            let row = r
183                .iter()
184                .zip_eq_fast(column_types)
185                .zip_eq_fast(format_iter)
186                .map(|((data, t), format)| match data {
187                    Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
188                    None => Ok(None),
189                })
190                .try_collect()?;
191            Ok(Row::new(row))
192        })
193        .try_collect()
194}
195
196/// Convert from [`Field`] to [`PgFieldDescriptor`].
197pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
198    PgFieldDescriptor::new(
199        f.name.clone(),
200        f.data_type().to_oid(),
201        f.data_type().type_len(),
202    )
203}
204
205#[easy_ext::ext(SourceSchemaCompatExt)]
206impl CompatibleFormatEncode {
207    /// Convert `self` to [`FormatEncodeOptions`] and warn the user if the syntax is deprecated.
208    pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
209        match self {
210            CompatibleFormatEncode::RowFormat(inner) => {
211                // TODO: should be warning
212                current::notice_to_user(
213                    "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
214                );
215                inner.into_format_encode_v2()
216            }
217            CompatibleFormatEncode::V2(inner) => inner,
218        }
219    }
220}
221
222pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
223    let table_factor = TableFactor::Table {
224        name: from_name,
225        alias: None,
226        as_of: None,
227    };
228    let from = vec![TableWithJoins {
229        relation: table_factor,
230        joins: vec![],
231    }];
232    let select = Select {
233        from,
234        projection: vec![SelectItem::Wildcard(None)],
235        ..Default::default()
236    };
237    let body = SetExpr::Select(Box::new(select));
238    Query {
239        with: None,
240        body,
241        order_by: vec![],
242        limit: None,
243        offset: None,
244        fetch: None,
245    }
246}
247
248pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
249    Epoch::from_unix_millis(unix_millis).0
250}
251
252pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
253    Epoch::from(logstore_u64).as_unix_millis()
254}
255
256pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
257    let seconds = (Interval::from_str(interval)
258        .map_err(|err| {
259            ErrorCode::InternalError(format!(
260                "Convert interval to u64 error, please check format, error: {:?}",
261                err.to_report_string()
262            ))
263        })?
264        .epoch_in_micros()
265        / 1000000) as u64;
266    Ok(seconds)
267}
268
269pub fn ensure_connection_type_allowed(
270    connection_type: PbConnectionType,
271    allowed_types: &HashSet<PbConnectionType>,
272) -> RwResult<()> {
273    if !allowed_types.contains(&connection_type) {
274        return Err(RwError::from(ProtocolError(format!(
275            "connection type {:?} is not allowed, allowed types: {:?}",
276            connection_type, allowed_types
277        ))));
278    }
279    Ok(())
280}
281
282fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
283    match connection_type {
284        PbConnectionType::Kafka => KAFKA_CONNECTOR,
285        PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
286        PbConnectionType::Elasticsearch => ES_SINK,
287        _ => unreachable!(),
288    }
289}
290
291pub fn check_connector_match_connection_type(
292    connector: &str,
293    connection_type: &PbConnectionType,
294) -> RwResult<()> {
295    if !connector.eq(connection_type_to_connector(connection_type)) {
296        return Err(RwError::from(ProtocolError(format!(
297            "connector {} and connection type {:?} are not compatible",
298            connector, connection_type
299        ))));
300    }
301    Ok(())
302}
303
304pub fn get_table_catalog_by_table_name(
305    session: &SessionImpl,
306    table_name: &ObjectName,
307) -> RwResult<(Arc<TableCatalog>, String)> {
308    let db_name = &session.database();
309    let (schema_name, real_table_name) =
310        Binder::resolve_schema_qualified_name(db_name, table_name)?;
311    let search_path = session.config().search_path();
312    let user_name = &session.user_name();
313
314    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
315    let reader = session.env().catalog_reader().read_guard();
316    match reader.get_created_table_by_name(db_name, schema_path, &real_table_name) {
317        Ok((table, schema_name)) => Ok((table.clone(), schema_name.to_owned())),
318        Err(err) => {
319            if let Some(table) = session
320                .staging_catalog_manager()
321                .get_table(&real_table_name)
322            {
323                // During CREATE TABLE (iceberg engine), the table is only in staging, but we
324                // still need a stable schema name for internal sink planning.
325                let schema_name = reader
326                    .get_schema_by_id(table.database_id, table.schema_id)
327                    .map(|schema| schema.name.clone())?;
328                Ok((Arc::new(table.clone()), schema_name))
329            } else {
330                Err(err.into())
331            }
332        }
333    }
334}
335
/// Follow-up advice included in the notice that [`execute_with_long_running_notification`]
/// sends to the user when an operation runs longer than the `slow_ddl_notification_secs`
/// session variable.
///
/// Each variant selects one message template:
/// * `SuggestRecover` - blames high barrier latency and suggests running `RECOVER` in a
///   separate session.
/// * `DiagnoseBarrierLatency` - blames high barrier latency and links to the barrier
///   monitoring documentation.
/// * `MonitorBackfillJob` - says barrier latency might be high and suggests `SHOW JOBS` to
///   track progress.
///
/// # Example
/// ```ignore
/// execute_with_long_running_notification(
///     catalog_writer.drop_table(source_id, table_id, cascade),
///     &session,
///     "DROP TABLE",
///     LongRunningNotificationAction::SuggestRecover,
/// ).await?;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LongRunningNotificationAction {
    SuggestRecover,
    DiagnoseBarrierLatency,
    MonitorBackfillJob,
}

impl LongRunningNotificationAction {
    /// Renders the notice text for this action.
    ///
    /// `operation_name` is the user-visible name of the slow operation (e.g. "DROP TABLE") and
    /// `notify_timeout_secs` is the number of seconds that has elapsed when the notice is sent.
    fn build_message(self, operation_name: &str, notify_timeout_secs: u32) -> String {
        match self {
            LongRunningNotificationAction::SuggestRecover => format!(
                "{} has taken more than {} secs, likely due to high barrier latency.\n\
                You may trigger cluster recovery to let {} take effect immediately.\n\
                Run RECOVER in a separate session to trigger recovery.\n\
                See: https://docs.risingwave.com/sql/commands/sql-recover#recover",
                operation_name, notify_timeout_secs, operation_name
            ),
            LongRunningNotificationAction::DiagnoseBarrierLatency => format!(
                "{} has taken more than {} secs, likely due to high barrier latency.\n\
                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring for steps to diagnose high barrier latency.",
                operation_name, notify_timeout_secs
            ),
            LongRunningNotificationAction::MonitorBackfillJob => format!(
                "{} has taken more than {} secs, barrier latency might be high. Please check barrier latency metrics to confirm.\n\
                You can also run SHOW JOBS to track the progress of the job.\n\
                See: https://docs.risingwave.com/performance/metrics#barrier-monitoring and https://docs.risingwave.com/sql/commands/sql-show-jobs",
                operation_name, notify_timeout_secs
            ),
        }
    }
}
385
386pub async fn execute_with_long_running_notification<F, T>(
387    operation_fut: F,
388    session: &SessionImpl,
389    operation_name: &str,
390    action: LongRunningNotificationAction,
391) -> RwResult<T>
392where
393    F: std::future::Future<Output = RwResult<T>>,
394{
395    let notify_timeout_secs = session.config().slow_ddl_notification_secs();
396
397    // If timeout is 0, disable notifications and just execute the operation
398    if notify_timeout_secs == 0 {
399        return operation_fut.await;
400    }
401
402    let notify_fut = sleep(Duration::from_secs(notify_timeout_secs as u64));
403    tokio::pin!(operation_fut);
404
405    select! {
406        _ = notify_fut => {
407            session.notice_to_user(action.build_message(operation_name, notify_timeout_secs));
408            operation_fut.await
409        }
410        result = &mut operation_fut => {
411            result
412        }
413    }
414}
415
// Unit tests for the pgwire formatting helpers in this file.
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    // The descriptor must carry the field's name and the OID of its data type.
    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    // With an empty format list every cell is expected in its text form, and the `.` cells of
    // the pretty-printed chunk are expected to come back as `None`.
    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
             2 . .    .
             3 7 7.01 vvv
             4 . .    .  ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // One format per column: the first three columns are requested in binary, the last one in
    // text. The expected binary bytes are produced with `postgres_types::ToSql`.
    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // Reference encodings for the three binary columns.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // Text formatting of special floats, booleans and a timestamptz rendered in UTC.
    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        // Shorthand: format one scalar with the UTC session data and unwrap the result.
        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        // One microsecond before the Unix epoch, rendered with the UTC offset.
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}