1use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33 DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42 CompatibleFormatEncode, Expr, FormatEncodeOptions, Ident, ObjectName, OrderByExpr, Query,
43 Select, SelectItem, SetExpr, TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46
47use crate::catalog::root_catalog::SchemaPath;
48use crate::error::ErrorCode::ProtocolError;
49use crate::error::{ErrorCode, Result as RwResult, RwError};
50use crate::session::{SessionImpl, current};
51use crate::{Binder, HashSet, TableCatalog};
52
pin_project! {
    /// Stream adapter that wraps a stream of [`DataChunk`] results and yields a
    /// [`RowSetResult`] for each item.
    ///
    /// Each successful chunk is converted with `to_pg_rows` using the stored column
    /// types, wire formats and session data; errors from the inner stream are
    /// forwarded unchanged.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // The wrapped chunk stream. Marked `#[pin]` so it is projected as a pinned
        // reference and can be polled in place.
        #[pin]
        chunk_stream: VS,
        // Data type of every column; `to_pg_rows` asserts that its length equals the
        // chunk dimension.
        column_types: Vec<DataType>,
        // Requested wire formats, handed to `FormatIterator::new` for every row.
        pub formats: Vec<Format>,
        // Session settings captured when the adapter was constructed.
        session_data: StaticSessionData,
    }
}
71
/// Session settings captured once and then used while formatting result values.
///
/// The struct only holds plain owned data, so it derives `Debug` and `Clone` to
/// make it easy to log and to hand copies to independent consumers.
#[derive(Debug, Clone)]
pub struct StaticSessionData {
    /// Name of the session time zone. It is looked up with
    /// `Timestamptz::lookup_time_zone` when a `timestamptz` value is rendered as text.
    pub timezone: String,
}
76
77impl<VS> DataChunkToRowSetAdapter<VS>
78where
79 VS: Stream<Item = Result<DataChunk, BoxedError>>,
80{
81 pub fn new(
82 chunk_stream: VS,
83 column_types: Vec<DataType>,
84 formats: Vec<Format>,
85 session: Arc<SessionImpl>,
86 ) -> Self {
87 let session_data = StaticSessionData {
88 timezone: session.config().timezone(),
89 };
90 Self {
91 chunk_stream,
92 column_types,
93 formats,
94 session_data,
95 }
96 }
97}
98
99impl<VS> Stream for DataChunkToRowSetAdapter<VS>
100where
101 VS: Stream<Item = Result<DataChunk, BoxedError>>,
102{
103 type Item = RowSetResult;
104
105 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 let mut this = self.project();
107 match this.chunk_stream.as_mut().poll_next(cx) {
108 Poll::Pending => Poll::Pending,
109 Poll::Ready(chunk) => match chunk {
110 Some(chunk_result) => match chunk_result {
111 Ok(chunk) => Poll::Ready(Some(
112 to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
113 .map_err(|err| err.into()),
114 )),
115 Err(err) => Poll::Ready(Some(Err(err))),
116 },
117 None => Poll::Ready(None),
118 },
119 }
120 }
121}
122
123pub fn pg_value_format(
125 data_type: &DataType,
126 d: ScalarRefImpl<'_>,
127 format: Format,
128 session_data: &StaticSessionData,
129) -> RwResult<Bytes> {
130 match format {
133 Format::Text => {
134 if *data_type == DataType::Timestamptz {
135 Ok(timestamptz_to_string_with_session_data(d, session_data))
136 } else {
137 Ok(d.text_format(data_type).into())
138 }
139 }
140 Format::Binary => Ok(d
141 .binary_format(data_type)
142 .context("failed to format binary value")?),
143 }
144}
145
146fn timestamptz_to_string_with_session_data(
147 d: ScalarRefImpl<'_>,
148 session_data: &StaticSessionData,
149) -> Bytes {
150 let tz = d.into_timestamptz();
151 let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
152 let instant_local = tz.to_datetime_in_zone(time_zone);
153 let mut result_string = BytesMut::new();
154 write_date_time_tz(instant_local, &mut result_string).unwrap();
155 result_string.into()
156}
157
158fn to_pg_rows(
159 column_types: &[DataType],
160 chunk: DataChunk,
161 formats: &[Format],
162 session_data: &StaticSessionData,
163) -> RwResult<Vec<Row>> {
164 assert_eq!(chunk.dimension(), column_types.len());
165 if cfg!(debug_assertions) {
166 let chunk_data_types = chunk.data_types();
167 for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
168 debug_assert!(
169 ty1.equals_datatype(ty2),
170 "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
171 )
172 }
173 }
174
175 chunk
176 .rows()
177 .map(|r| {
178 let format_iter = FormatIterator::new(formats, chunk.dimension())
179 .map_err(ErrorCode::InternalError)?;
180 let row = r
181 .iter()
182 .zip_eq_fast(column_types)
183 .zip_eq_fast(format_iter)
184 .map(|((data, t), format)| match data {
185 Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
186 None => Ok(None),
187 })
188 .try_collect()?;
189 Ok(Row::new(row))
190 })
191 .try_collect()
192}
193
194pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
196 PgFieldDescriptor::new(
197 f.name.clone(),
198 f.data_type().to_oid(),
199 f.data_type().type_len(),
200 )
201}
202
203#[easy_ext::ext(SourceSchemaCompatExt)]
204impl CompatibleFormatEncode {
205 pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
207 match self {
208 CompatibleFormatEncode::RowFormat(inner) => {
209 current::notice_to_user(
211 "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
212 );
213 inner.into_format_encode_v2()
214 }
215 CompatibleFormatEncode::V2(inner) => inner,
216 }
217 }
218}
219
220pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
221 let table_factor = TableFactor::Table {
222 name: from_name,
223 alias: None,
224 as_of: None,
225 };
226 let from = vec![TableWithJoins {
227 relation: table_factor,
228 joins: vec![],
229 }];
230 let select = Select {
231 from,
232 projection: vec![SelectItem::Wildcard(None)],
233 ..Default::default()
234 };
235 let body = SetExpr::Select(Box::new(select));
236 Query {
237 with: None,
238 body,
239 order_by: vec![],
240 limit: None,
241 offset: None,
242 fetch: None,
243 }
244}
245
246pub fn gen_query_from_table_name_order_by(from_name: ObjectName, pk_names: Vec<String>) -> Query {
247 let mut query = gen_query_from_table_name(from_name);
248 query.order_by = pk_names
249 .into_iter()
250 .map(|pk| {
251 let expr = Expr::Identifier(Ident::with_quote_unchecked('"', pk));
252 OrderByExpr {
253 expr,
254 asc: None,
255 nulls_first: None,
256 }
257 })
258 .collect();
259 query
260}
261
262pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
263 Epoch::from_unix_millis(unix_millis).0
264}
265
266pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
267 Epoch::from(logstore_u64).as_unix_millis()
268}
269
270pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
271 let seconds = (Interval::from_str(interval)
272 .map_err(|err| {
273 ErrorCode::InternalError(format!(
274 "Convert interval to u64 error, please check format, error: {:?}",
275 err.to_report_string()
276 ))
277 })?
278 .epoch_in_micros()
279 / 1000000) as u64;
280 Ok(seconds)
281}
282
283pub fn ensure_connection_type_allowed(
284 connection_type: PbConnectionType,
285 allowed_types: &HashSet<PbConnectionType>,
286) -> RwResult<()> {
287 if !allowed_types.contains(&connection_type) {
288 return Err(RwError::from(ProtocolError(format!(
289 "connection type {:?} is not allowed, allowed types: {:?}",
290 connection_type, allowed_types
291 ))));
292 }
293 Ok(())
294}
295
296fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
297 match connection_type {
298 PbConnectionType::Kafka => KAFKA_CONNECTOR,
299 PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
300 PbConnectionType::Elasticsearch => ES_SINK,
301 _ => unreachable!(),
302 }
303}
304
305pub fn check_connector_match_connection_type(
306 connector: &str,
307 connection_type: &PbConnectionType,
308) -> RwResult<()> {
309 if !connector.eq(connection_type_to_connector(connection_type)) {
310 return Err(RwError::from(ProtocolError(format!(
311 "connector {} and connection type {:?} are not compatible",
312 connector, connection_type
313 ))));
314 }
315 Ok(())
316}
317
318pub fn get_table_catalog_by_table_name(
319 session: &SessionImpl,
320 table_name: &ObjectName,
321) -> RwResult<(Arc<TableCatalog>, String)> {
322 let db_name = &session.database();
323 let (schema_name, real_table_name) =
324 Binder::resolve_schema_qualified_name(db_name, table_name.clone())?;
325 let search_path = session.config().search_path();
326 let user_name = &session.user_name();
327
328 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
329 let reader = session.env().catalog_reader().read_guard();
330 let (table, schema_name) =
331 reader.get_created_table_by_name(db_name, schema_path, &real_table_name)?;
332
333 Ok((table.clone(), schema_name.to_owned()))
334}
335
// Unit tests for the row/field conversion helpers defined in this file.
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    // `to_pg_field` keeps the field name and reports the OID of the field's data type.
    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    // With an empty format list, every value is expected as its text rendering and
    // null cells (`.` in the pretty chunk) are expected as `None`.
    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
             2 . . .
             3 7 7.01 vvv
             4 . . . ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // One format per column: the first three columns are requested in binary and the
    // last one in text.
    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // The expected binary values are produced independently via `postgres_types::ToSql`.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // Text rendering of individual scalars: floats (including NaN and infinities),
    // booleans, and a `timestamptz` rendered with a UTC session time zone.
    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        // Shorthand that formats one value and unwraps the result.
        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        // One microsecond before the Unix epoch.
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}