risingwave_frontend/handler/
util.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33    DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42    CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
43    TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46
47use crate::catalog::root_catalog::SchemaPath;
48use crate::error::ErrorCode::ProtocolError;
49use crate::error::{ErrorCode, Result as RwResult, RwError};
50use crate::session::{SessionImpl, current};
51use crate::{Binder, HashSet, TableCatalog};
52
pin_project! {
    /// Wrapper struct that converts a stream of DataChunk to a stream of RowSet based on formatting
    /// parameters.
    ///
    /// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(chunk,
    /// format)))` but we need a nameable type as part of [`super::PgResponseStream`], but we cannot
    /// name the type of a closure.
    ///
    /// Errors yielded by the inner stream are forwarded unchanged. A chunk that fails to convert
    /// produces an `Err` item for that chunk.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Upstream source of data chunks; pinned because it is polled in place.
        #[pin]
        chunk_stream: VS,
        // Data type of each output column, handed to `to_pg_rows` for every chunk.
        column_types: Vec<DataType>,
        // Wire formats handed to `to_pg_rows` for every chunk.
        pub formats: Vec<Format>,
        // Session settings captured once when the adapter is built.
        session_data: StaticSessionData,
    }
}
71
/// Static session data frozen at the time of the creation of the stream.
///
/// [`DataChunkToRowSetAdapter::new`] fills this from the session configuration once; the stored
/// copy is then used for every chunk the adapter formats.
pub struct StaticSessionData {
    /// Time zone name used when rendering `timestamptz` values in text format. It is resolved
    /// with `Timestamptz::lookup_time_zone` each time such a value is formatted.
    pub timezone: String,
}
76
77impl<VS> DataChunkToRowSetAdapter<VS>
78where
79    VS: Stream<Item = Result<DataChunk, BoxedError>>,
80{
81    pub fn new(
82        chunk_stream: VS,
83        column_types: Vec<DataType>,
84        formats: Vec<Format>,
85        session: Arc<SessionImpl>,
86    ) -> Self {
87        let session_data = StaticSessionData {
88            timezone: session.config().timezone(),
89        };
90        Self {
91            chunk_stream,
92            column_types,
93            formats,
94            session_data,
95        }
96    }
97}
98
99impl<VS> Stream for DataChunkToRowSetAdapter<VS>
100where
101    VS: Stream<Item = Result<DataChunk, BoxedError>>,
102{
103    type Item = RowSetResult;
104
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        let mut this = self.project();
107        match this.chunk_stream.as_mut().poll_next(cx) {
108            Poll::Pending => Poll::Pending,
109            Poll::Ready(chunk) => match chunk {
110                Some(chunk_result) => match chunk_result {
111                    Ok(chunk) => Poll::Ready(Some(
112                        to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
113                            .map_err(|err| err.into()),
114                    )),
115                    Err(err) => Poll::Ready(Some(Err(err))),
116                },
117                None => Poll::Ready(None),
118            },
119        }
120    }
121}
122
123/// Format scalars according to postgres convention.
124pub fn pg_value_format(
125    data_type: &DataType,
126    d: ScalarRefImpl<'_>,
127    format: Format,
128    session_data: &StaticSessionData,
129) -> RwResult<Bytes> {
130    // format == false means TEXT format
131    // format == true means BINARY format
132    match format {
133        Format::Text => {
134            if *data_type == DataType::Timestamptz {
135                Ok(timestamptz_to_string_with_session_data(d, session_data))
136            } else {
137                Ok(d.text_format(data_type).into())
138            }
139        }
140        Format::Binary => Ok(d
141            .binary_format(data_type)
142            .context("failed to format binary value")?),
143    }
144}
145
146fn timestamptz_to_string_with_session_data(
147    d: ScalarRefImpl<'_>,
148    session_data: &StaticSessionData,
149) -> Bytes {
150    let tz = d.into_timestamptz();
151    let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
152    let instant_local = tz.to_datetime_in_zone(time_zone);
153    let mut result_string = BytesMut::new();
154    write_date_time_tz(instant_local, &mut result_string).unwrap();
155    result_string.into()
156}
157
158fn to_pg_rows(
159    column_types: &[DataType],
160    chunk: DataChunk,
161    formats: &[Format],
162    session_data: &StaticSessionData,
163) -> RwResult<Vec<Row>> {
164    assert_eq!(chunk.dimension(), column_types.len());
165    if cfg!(debug_assertions) {
166        let chunk_data_types = chunk.data_types();
167        for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
168            debug_assert!(
169                ty1.equals_datatype(ty2),
170                "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
171            )
172        }
173    }
174
175    chunk
176        .rows()
177        .map(|r| {
178            let format_iter = FormatIterator::new(formats, chunk.dimension())
179                .map_err(ErrorCode::InternalError)?;
180            let row = r
181                .iter()
182                .zip_eq_fast(column_types)
183                .zip_eq_fast(format_iter)
184                .map(|((data, t), format)| match data {
185                    Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
186                    None => Ok(None),
187                })
188                .try_collect()?;
189            Ok(Row::new(row))
190        })
191        .try_collect()
192}
193
194/// Convert from [`Field`] to [`PgFieldDescriptor`].
195pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
196    PgFieldDescriptor::new(
197        f.name.clone(),
198        f.data_type().to_oid(),
199        f.data_type().type_len(),
200    )
201}
202
203#[easy_ext::ext(SourceSchemaCompatExt)]
204impl CompatibleFormatEncode {
205    /// Convert `self` to [`FormatEncodeOptions`] and warn the user if the syntax is deprecated.
206    pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
207        match self {
208            CompatibleFormatEncode::RowFormat(inner) => {
209                // TODO: should be warning
210                current::notice_to_user(
211                    "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
212                );
213                inner.into_format_encode_v2()
214            }
215            CompatibleFormatEncode::V2(inner) => inner,
216        }
217    }
218}
219
220pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
221    let table_factor = TableFactor::Table {
222        name: from_name,
223        alias: None,
224        as_of: None,
225    };
226    let from = vec![TableWithJoins {
227        relation: table_factor,
228        joins: vec![],
229    }];
230    let select = Select {
231        from,
232        projection: vec![SelectItem::Wildcard(None)],
233        ..Default::default()
234    };
235    let body = SetExpr::Select(Box::new(select));
236    Query {
237        with: None,
238        body,
239        order_by: vec![],
240        limit: None,
241        offset: None,
242        fetch: None,
243    }
244}
245
246pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
247    Epoch::from_unix_millis(unix_millis).0
248}
249
250pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
251    Epoch::from(logstore_u64).as_unix_millis()
252}
253
254pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
255    let seconds = (Interval::from_str(interval)
256        .map_err(|err| {
257            ErrorCode::InternalError(format!(
258                "Convert interval to u64 error, please check format, error: {:?}",
259                err.to_report_string()
260            ))
261        })?
262        .epoch_in_micros()
263        / 1000000) as u64;
264    Ok(seconds)
265}
266
267pub fn ensure_connection_type_allowed(
268    connection_type: PbConnectionType,
269    allowed_types: &HashSet<PbConnectionType>,
270) -> RwResult<()> {
271    if !allowed_types.contains(&connection_type) {
272        return Err(RwError::from(ProtocolError(format!(
273            "connection type {:?} is not allowed, allowed types: {:?}",
274            connection_type, allowed_types
275        ))));
276    }
277    Ok(())
278}
279
280fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
281    match connection_type {
282        PbConnectionType::Kafka => KAFKA_CONNECTOR,
283        PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
284        PbConnectionType::Elasticsearch => ES_SINK,
285        _ => unreachable!(),
286    }
287}
288
289pub fn check_connector_match_connection_type(
290    connector: &str,
291    connection_type: &PbConnectionType,
292) -> RwResult<()> {
293    if !connector.eq(connection_type_to_connector(connection_type)) {
294        return Err(RwError::from(ProtocolError(format!(
295            "connector {} and connection type {:?} are not compatible",
296            connector, connection_type
297        ))));
298    }
299    Ok(())
300}
301
302pub fn get_table_catalog_by_table_name(
303    session: &SessionImpl,
304    table_name: &ObjectName,
305) -> RwResult<(Arc<TableCatalog>, String)> {
306    let db_name = &session.database();
307    let (schema_name, real_table_name) =
308        Binder::resolve_schema_qualified_name(db_name, table_name.clone())?;
309    let search_path = session.config().search_path();
310    let user_name = &session.user_name();
311
312    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
313    let reader = session.env().catalog_reader().read_guard();
314    let (table, schema_name) =
315        reader.get_created_table_by_name(db_name, schema_path, &real_table_name)?;
316
317    Ok((table.clone(), schema_name.to_owned()))
318}
319
// Unit tests for the pgwire conversion helpers defined in this file.
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    // The descriptor built from a field keeps the field name and the OID of its data type.
    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    // With an empty format list the test expects every value in text form, and nulls (`.` in
    // the pretty-printed chunk) as `None`.
    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
             2 . .    .
             3 7 7.01 vvv
             4 . .    .  ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // One format per column: the first three columns are requested in binary and compared
    // against the encoding produced by `postgres_types::ToSql`; the last column stays text.
    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f    T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // Reference binary encodings for the three binary columns.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // Text formatting of special float values, booleans, and a `timestamptz` just before the
    // Unix epoch rendered with a UTC session time zone.
    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        // Shorthand: format with the UTC session data and unwrap the result.
        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}