1use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33 DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42 CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
43 TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46
47use crate::catalog::root_catalog::SchemaPath;
48use crate::error::ErrorCode::ProtocolError;
49use crate::error::{ErrorCode, Result as RwResult, RwError};
50use crate::session::{SessionImpl, current};
51use crate::{Binder, HashSet, TableCatalog};
52
pin_project! {
    /// Stream adapter that turns a stream of [`DataChunk`]s into a stream of pgwire row sets.
    ///
    /// Every chunk received from `chunk_stream` is converted with `to_pg_rows`, using the
    /// stored column types, output formats and session data. Errors produced by the inner
    /// stream are forwarded unchanged.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Upstream chunk stream; marked `#[pin]` so it can be polled through the projection.
        #[pin]
        chunk_stream: VS,
        // Data type of each output column, in column order (checked against every chunk).
        column_types: Vec<DataType>,
        // Output formats handed to `to_pg_rows` for every chunk.
        pub formats: Vec<Format>,
        // Session values (currently only the time zone) captured when the adapter is built.
        session_data: StaticSessionData,
    }
}
71
/// Snapshot of the session settings needed while formatting result rows.
///
/// The values are copied out of the session when a result stream is created, so formatting
/// does not need to keep a reference to the session itself.
#[derive(Debug, Clone)]
pub struct StaticSessionData {
    /// Session time zone name, used to render `timestamptz` values in text format.
    pub timezone: String,
}
76
77impl<VS> DataChunkToRowSetAdapter<VS>
78where
79 VS: Stream<Item = Result<DataChunk, BoxedError>>,
80{
81 pub fn new(
82 chunk_stream: VS,
83 column_types: Vec<DataType>,
84 formats: Vec<Format>,
85 session: Arc<SessionImpl>,
86 ) -> Self {
87 let session_data = StaticSessionData {
88 timezone: session.config().timezone(),
89 };
90 Self {
91 chunk_stream,
92 column_types,
93 formats,
94 session_data,
95 }
96 }
97}
98
99impl<VS> Stream for DataChunkToRowSetAdapter<VS>
100where
101 VS: Stream<Item = Result<DataChunk, BoxedError>>,
102{
103 type Item = RowSetResult;
104
105 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 let mut this = self.project();
107 match this.chunk_stream.as_mut().poll_next(cx) {
108 Poll::Pending => Poll::Pending,
109 Poll::Ready(chunk) => match chunk {
110 Some(chunk_result) => match chunk_result {
111 Ok(chunk) => Poll::Ready(Some(
112 to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
113 .map_err(|err| err.into()),
114 )),
115 Err(err) => Poll::Ready(Some(Err(err))),
116 },
117 None => Poll::Ready(None),
118 },
119 }
120 }
121}
122
123pub fn pg_value_format(
125 data_type: &DataType,
126 d: ScalarRefImpl<'_>,
127 format: Format,
128 session_data: &StaticSessionData,
129) -> RwResult<Bytes> {
130 match format {
133 Format::Text => {
134 if *data_type == DataType::Timestamptz {
135 Ok(timestamptz_to_string_with_session_data(d, session_data))
136 } else {
137 Ok(d.text_format(data_type).into())
138 }
139 }
140 Format::Binary => Ok(d
141 .binary_format(data_type)
142 .context("failed to format binary value")?),
143 }
144}
145
146fn timestamptz_to_string_with_session_data(
147 d: ScalarRefImpl<'_>,
148 session_data: &StaticSessionData,
149) -> Bytes {
150 let tz = d.into_timestamptz();
151 let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
152 let instant_local = tz.to_datetime_in_zone(time_zone);
153 let mut result_string = BytesMut::new();
154 write_date_time_tz(instant_local, &mut result_string).unwrap();
155 result_string.into()
156}
157
158fn to_pg_rows(
159 column_types: &[DataType],
160 chunk: DataChunk,
161 formats: &[Format],
162 session_data: &StaticSessionData,
163) -> RwResult<Vec<Row>> {
164 assert_eq!(chunk.dimension(), column_types.len());
165 if cfg!(debug_assertions) {
166 let chunk_data_types = chunk.data_types();
167 for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
168 debug_assert!(
169 ty1.equals_datatype(ty2),
170 "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
171 )
172 }
173 }
174
175 chunk
176 .rows()
177 .map(|r| {
178 let format_iter = FormatIterator::new(formats, chunk.dimension())
179 .map_err(ErrorCode::InternalError)?;
180 let row = r
181 .iter()
182 .zip_eq_fast(column_types)
183 .zip_eq_fast(format_iter)
184 .map(|((data, t), format)| match data {
185 Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
186 None => Ok(None),
187 })
188 .try_collect()?;
189 Ok(Row::new(row))
190 })
191 .try_collect()
192}
193
194pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
196 PgFieldDescriptor::new(
197 f.name.clone(),
198 f.data_type().to_oid(),
199 f.data_type().type_len(),
200 )
201}
202
203#[easy_ext::ext(SourceSchemaCompatExt)]
204impl CompatibleFormatEncode {
205 pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
207 match self {
208 CompatibleFormatEncode::RowFormat(inner) => {
209 current::notice_to_user(
211 "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
212 );
213 inner.into_format_encode_v2()
214 }
215 CompatibleFormatEncode::V2(inner) => inner,
216 }
217 }
218}
219
220pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
221 let table_factor = TableFactor::Table {
222 name: from_name,
223 alias: None,
224 as_of: None,
225 };
226 let from = vec![TableWithJoins {
227 relation: table_factor,
228 joins: vec![],
229 }];
230 let select = Select {
231 from,
232 projection: vec![SelectItem::Wildcard(None)],
233 ..Default::default()
234 };
235 let body = SetExpr::Select(Box::new(select));
236 Query {
237 with: None,
238 body,
239 order_by: vec![],
240 limit: None,
241 offset: None,
242 fetch: None,
243 }
244}
245
246pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
247 Epoch::from_unix_millis(unix_millis).0
248}
249
250pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
251 Epoch::from(logstore_u64).as_unix_millis()
252}
253
254pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
255 let seconds = (Interval::from_str(interval)
256 .map_err(|err| {
257 ErrorCode::InternalError(format!(
258 "Convert interval to u64 error, please check format, error: {:?}",
259 err.to_report_string()
260 ))
261 })?
262 .epoch_in_micros()
263 / 1000000) as u64;
264 Ok(seconds)
265}
266
267pub fn ensure_connection_type_allowed(
268 connection_type: PbConnectionType,
269 allowed_types: &HashSet<PbConnectionType>,
270) -> RwResult<()> {
271 if !allowed_types.contains(&connection_type) {
272 return Err(RwError::from(ProtocolError(format!(
273 "connection type {:?} is not allowed, allowed types: {:?}",
274 connection_type, allowed_types
275 ))));
276 }
277 Ok(())
278}
279
280fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
281 match connection_type {
282 PbConnectionType::Kafka => KAFKA_CONNECTOR,
283 PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
284 PbConnectionType::Elasticsearch => ES_SINK,
285 _ => unreachable!(),
286 }
287}
288
289pub fn check_connector_match_connection_type(
290 connector: &str,
291 connection_type: &PbConnectionType,
292) -> RwResult<()> {
293 if !connector.eq(connection_type_to_connector(connection_type)) {
294 return Err(RwError::from(ProtocolError(format!(
295 "connector {} and connection type {:?} are not compatible",
296 connector, connection_type
297 ))));
298 }
299 Ok(())
300}
301
302pub fn get_table_catalog_by_table_name(
303 session: &SessionImpl,
304 table_name: &ObjectName,
305) -> RwResult<(Arc<TableCatalog>, String)> {
306 let db_name = &session.database();
307 let (schema_name, real_table_name) =
308 Binder::resolve_schema_qualified_name(db_name, table_name.clone())?;
309 let search_path = session.config().search_path();
310 let user_name = &session.user_name();
311
312 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
313 let reader = session.env().catalog_reader().read_guard();
314 let (table, schema_name) =
315 reader.get_created_table_by_name(db_name, schema_path, &real_table_name)?;
316
317 Ok((table.clone(), schema_name.to_owned()))
318}
319
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    // `to_pg_field` should carry the field name and the OID of the field's data type.
    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    // With an empty `formats` slice, every column is expected in text format, and NULL
    // datums (`.` in the pretty chunk) must come out as `None`.
    #[test]
    fn test_to_pg_rows() {
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
             2 . . .
             3 7 7.01 vvv
             4 . . . ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // One format per column: the three binary columns must match the encoding produced by
    // `postgres_types::ToSql`, while the text column stays plain text.
    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // Reference binary encodings for the three numeric columns.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // Text rendering of special float values (NaN / +-Infinity), booleans (`t` / `f`) and a
    // `Timestamptz` one microsecond before the Unix epoch under the UTC session time zone.
    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}