risingwave_frontend/handler/
util.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33    DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42    CompatibleFormatEncode, Expr, FormatEncodeOptions, Ident, ObjectName, OrderByExpr, Query,
43    Select, SelectItem, SetExpr, TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46
47use crate::catalog::root_catalog::SchemaPath;
48use crate::error::ErrorCode::ProtocolError;
49use crate::error::{ErrorCode, Result as RwResult, RwError};
50use crate::session::{SessionImpl, current};
51use crate::{Binder, HashSet, TableCatalog};
52
pin_project! {
    /// Wrapper struct that converts a stream of DataChunk to a stream of RowSet based on formatting
    /// parameters.
    ///
    /// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(..)))`,
    /// but [`super::PgResponseStream`] needs a nameable type and the type of a closure cannot be
    /// named, so the mapping is spelled out as a dedicated `Stream` implementation.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Upstream stream of chunks; pinned because it is polled through `Pin<&mut Self>`.
        #[pin]
        chunk_stream: VS,
        // Data type of each output column, in column order.
        column_types: Vec<DataType>,
        // Output formats handed to `FormatIterator::new` when each row is encoded.
        pub formats: Vec<Format>,
        // Session settings captured once when the adapter is constructed.
        session_data: StaticSessionData,
    }
}
71
/// Session data frozen at the time the stream is created.
///
/// The values are copied out of the session when the stream is built, so rows encoded later
/// keep using the settings that were in effect at creation time.
#[derive(Debug, Clone)]
pub struct StaticSessionData {
    /// Session time zone name, used to render `timestamptz` values in text format.
    pub timezone: String,
}
76
77impl<VS> DataChunkToRowSetAdapter<VS>
78where
79    VS: Stream<Item = Result<DataChunk, BoxedError>>,
80{
81    pub fn new(
82        chunk_stream: VS,
83        column_types: Vec<DataType>,
84        formats: Vec<Format>,
85        session: Arc<SessionImpl>,
86    ) -> Self {
87        let session_data = StaticSessionData {
88            timezone: session.config().timezone(),
89        };
90        Self {
91            chunk_stream,
92            column_types,
93            formats,
94            session_data,
95        }
96    }
97}
98
99impl<VS> Stream for DataChunkToRowSetAdapter<VS>
100where
101    VS: Stream<Item = Result<DataChunk, BoxedError>>,
102{
103    type Item = RowSetResult;
104
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        let mut this = self.project();
107        match this.chunk_stream.as_mut().poll_next(cx) {
108            Poll::Pending => Poll::Pending,
109            Poll::Ready(chunk) => match chunk {
110                Some(chunk_result) => match chunk_result {
111                    Ok(chunk) => Poll::Ready(Some(
112                        to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
113                            .map_err(|err| err.into()),
114                    )),
115                    Err(err) => Poll::Ready(Some(Err(err))),
116                },
117                None => Poll::Ready(None),
118            },
119        }
120    }
121}
122
123/// Format scalars according to postgres convention.
124pub fn pg_value_format(
125    data_type: &DataType,
126    d: ScalarRefImpl<'_>,
127    format: Format,
128    session_data: &StaticSessionData,
129) -> RwResult<Bytes> {
130    // format == false means TEXT format
131    // format == true means BINARY format
132    match format {
133        Format::Text => {
134            if *data_type == DataType::Timestamptz {
135                Ok(timestamptz_to_string_with_session_data(d, session_data))
136            } else {
137                Ok(d.text_format(data_type).into())
138            }
139        }
140        Format::Binary => Ok(d
141            .binary_format(data_type)
142            .context("failed to format binary value")?),
143    }
144}
145
146fn timestamptz_to_string_with_session_data(
147    d: ScalarRefImpl<'_>,
148    session_data: &StaticSessionData,
149) -> Bytes {
150    let tz = d.into_timestamptz();
151    let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
152    let instant_local = tz.to_datetime_in_zone(time_zone);
153    let mut result_string = BytesMut::new();
154    write_date_time_tz(instant_local, &mut result_string).unwrap();
155    result_string.into()
156}
157
158fn to_pg_rows(
159    column_types: &[DataType],
160    chunk: DataChunk,
161    formats: &[Format],
162    session_data: &StaticSessionData,
163) -> RwResult<Vec<Row>> {
164    assert_eq!(chunk.dimension(), column_types.len());
165    if cfg!(debug_assertions) {
166        let chunk_data_types = chunk.data_types();
167        for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
168            debug_assert!(
169                ty1.equals_datatype(ty2),
170                "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
171            )
172        }
173    }
174
175    chunk
176        .rows()
177        .map(|r| {
178            let format_iter = FormatIterator::new(formats, chunk.dimension())
179                .map_err(ErrorCode::InternalError)?;
180            let row = r
181                .iter()
182                .zip_eq_fast(column_types)
183                .zip_eq_fast(format_iter)
184                .map(|((data, t), format)| match data {
185                    Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
186                    None => Ok(None),
187                })
188                .try_collect()?;
189            Ok(Row::new(row))
190        })
191        .try_collect()
192}
193
194/// Convert from [`Field`] to [`PgFieldDescriptor`].
195pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
196    PgFieldDescriptor::new(
197        f.name.clone(),
198        f.data_type().to_oid(),
199        f.data_type().type_len(),
200    )
201}
202
203#[easy_ext::ext(SourceSchemaCompatExt)]
204impl CompatibleFormatEncode {
205    /// Convert `self` to [`FormatEncodeOptions`] and warn the user if the syntax is deprecated.
206    pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
207        match self {
208            CompatibleFormatEncode::RowFormat(inner) => {
209                // TODO: should be warning
210                current::notice_to_user(
211                    "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
212                );
213                inner.into_format_encode_v2()
214            }
215            CompatibleFormatEncode::V2(inner) => inner,
216        }
217    }
218}
219
220pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
221    let table_factor = TableFactor::Table {
222        name: from_name,
223        alias: None,
224        as_of: None,
225    };
226    let from = vec![TableWithJoins {
227        relation: table_factor,
228        joins: vec![],
229    }];
230    let select = Select {
231        from,
232        projection: vec![SelectItem::Wildcard(None)],
233        ..Default::default()
234    };
235    let body = SetExpr::Select(Box::new(select));
236    Query {
237        with: None,
238        body,
239        order_by: vec![],
240        limit: None,
241        offset: None,
242        fetch: None,
243    }
244}
245
246pub fn gen_query_from_table_name_order_by(from_name: ObjectName, pk_names: Vec<String>) -> Query {
247    let mut query = gen_query_from_table_name(from_name);
248    query.order_by = pk_names
249        .into_iter()
250        .map(|pk| {
251            let expr = Expr::Identifier(Ident::with_quote_unchecked('"', pk));
252            OrderByExpr {
253                expr,
254                asc: None,
255                nulls_first: None,
256            }
257        })
258        .collect();
259    query
260}
261
262pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
263    Epoch::from_unix_millis(unix_millis).0
264}
265
266pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
267    Epoch::from(logstore_u64).as_unix_millis()
268}
269
270pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
271    let seconds = (Interval::from_str(interval)
272        .map_err(|err| {
273            ErrorCode::InternalError(format!(
274                "Convert interval to u64 error, please check format, error: {:?}",
275                err.to_report_string()
276            ))
277        })?
278        .epoch_in_micros()
279        / 1000000) as u64;
280    Ok(seconds)
281}
282
283pub fn ensure_connection_type_allowed(
284    connection_type: PbConnectionType,
285    allowed_types: &HashSet<PbConnectionType>,
286) -> RwResult<()> {
287    if !allowed_types.contains(&connection_type) {
288        return Err(RwError::from(ProtocolError(format!(
289            "connection type {:?} is not allowed, allowed types: {:?}",
290            connection_type, allowed_types
291        ))));
292    }
293    Ok(())
294}
295
296fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
297    match connection_type {
298        PbConnectionType::Kafka => KAFKA_CONNECTOR,
299        PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
300        PbConnectionType::Elasticsearch => ES_SINK,
301        _ => unreachable!(),
302    }
303}
304
305pub fn check_connector_match_connection_type(
306    connector: &str,
307    connection_type: &PbConnectionType,
308) -> RwResult<()> {
309    if !connector.eq(connection_type_to_connector(connection_type)) {
310        return Err(RwError::from(ProtocolError(format!(
311            "connector {} and connection type {:?} are not compatible",
312            connector, connection_type
313        ))));
314    }
315    Ok(())
316}
317
318pub fn get_table_catalog_by_table_name(
319    session: &SessionImpl,
320    table_name: &ObjectName,
321) -> RwResult<(Arc<TableCatalog>, String)> {
322    let db_name = &session.database();
323    let (schema_name, real_table_name) =
324        Binder::resolve_schema_qualified_name(db_name, table_name.clone())?;
325    let search_path = session.config().search_path();
326    let user_name = &session.user_name();
327
328    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
329    let reader = session.env().catalog_reader().read_guard();
330    let (table, schema_name) =
331        reader.get_created_table_by_name(db_name, schema_path, &real_table_name)?;
332
333    Ok((table.clone(), schema_name.to_owned()))
334}
335
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    /// Session data shared by the tests: values are rendered in UTC.
    fn utc_session_data() -> StaticSessionData {
        StaticSessionData {
            timezone: "UTC".into(),
        }
    }

    /// Column types matching the `i I f T` layout of the pretty chunks below.
    fn int_bigint_float_varchar() -> [DataType; 4] {
        [
            DataType::Int32,
            DataType::Int64,
            DataType::Float32,
            DataType::Varchar,
        ]
    }

    #[test]
    fn test_to_pg_field() {
        let pg_field = to_pg_field(&Field::with_name(DataType::Int32, "v1"));
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
             2 . .    .
             3 7 7.01 vvv
             4 . .    .  ",
        );
        let session_data = utc_session_data();
        let rows = to_pg_rows(&int_bigint_float_varchar(), chunk, &[], &session_data).unwrap();
        let actual = rows
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        let text = |s: &'static str| Some(Bytes::from(s));
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![text("1"), text("6"), text("6.01"), text("aaa")],
            vec![text("2"), None, None, None],
            vec![text("3"), text("7"), text("7.01"), text("vvv")],
            vec![text("4"), None, None, None],
        ];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
            ",
        );
        let session_data = utc_session_data();
        let formats = [Format::Binary, Format::Binary, Format::Binary, Format::Text];
        let rows =
            to_pg_rows(&int_bigint_float_varchar(), chunk, &formats, &session_data).unwrap();
        let actual = rows
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        let mut buffers = [BytesMut::new(), BytesMut::new(), BytesMut::new()];
        1_i32.to_sql(&Type::ANY, &mut buffers[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut buffers[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut buffers[2]).unwrap();
        let [int_bytes, bigint_bytes, float_bytes] = buffers.map(BytesMut::freeze);

        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(int_bytes),
            Some(bigint_bytes),
            Some(float_bytes),
            Some("aaa".into()),
        ]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let session_data = utc_session_data();
        let render = |t, d, f| pg_value_format(t, d, f, &session_data).unwrap();

        let text_cases = [
            (&T::Float32, S::Float32(1_f32.into()), "1"),
            (&T::Float32, S::Float32(f32::NAN.into()), "NaN"),
            (&T::Float64, S::Float64(f64::NAN.into()), "NaN"),
            (&T::Float32, S::Float32(f32::INFINITY.into()), "Infinity"),
            (&T::Float32, S::Float32(f32::NEG_INFINITY.into()), "-Infinity"),
            (&T::Float64, S::Float64(f64::INFINITY.into()), "Infinity"),
            (&T::Float64, S::Float64(f64::NEG_INFINITY.into()), "-Infinity"),
            (&T::Boolean, S::Bool(true), "t"),
            (&T::Boolean, S::Bool(false), "f"),
            (
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                "1969-12-31 23:59:59.999999+00:00",
            ),
        ];
        for (data_type, scalar, expected) in text_cases {
            assert_eq!(&render(data_type, scalar, Format::Text), expected);
        }
    }
}