1use core::str::FromStr;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use bytes::{Bytes, BytesMut};
22use futures::Stream;
23use itertools::Itertools;
24use pgwire::pg_field_descriptor::PgFieldDescriptor;
25use pgwire::pg_response::RowSetResult;
26use pgwire::pg_server::BoxedError;
27use pgwire::types::{Format, FormatIterator, Row};
28use pin_project_lite::pin_project;
29use risingwave_common::array::DataChunk;
30use risingwave_common::catalog::Field;
31use risingwave_common::row::Row as _;
32use risingwave_common::types::{
33 DataType, Interval, ScalarRefImpl, Timestamptz, write_date_time_tz,
34};
35use risingwave_common::util::epoch::Epoch;
36use risingwave_common::util::iter_util::ZipEqFast;
37use risingwave_connector::sink::elasticsearch_opensearch::elasticsearch::ES_SINK;
38use risingwave_connector::source::KAFKA_CONNECTOR;
39use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
40use risingwave_pb::catalog::connection_params::PbConnectionType;
41use risingwave_sqlparser::ast::{
42 CompatibleFormatEncode, FormatEncodeOptions, ObjectName, Query, Select, SelectItem, SetExpr,
43 TableFactor, TableWithJoins,
44};
45use thiserror_ext::AsReport;
46use tokio::select;
47use tokio::time::{Duration, sleep};
48
49use crate::catalog::root_catalog::SchemaPath;
50use crate::error::ErrorCode::ProtocolError;
51use crate::error::{ErrorCode, Result as RwResult, RwError};
52use crate::session::{SessionImpl, current};
53use crate::{Binder, HashSet, TableCatalog};
54
pin_project! {
    // Adapts a stream of `DataChunk`s into a stream of pgwire row sets:
    // each polled chunk is converted to `Row`s via `to_pg_rows`, using the
    // session settings captured at construction time.
    pub struct DataChunkToRowSetAdapter<VS>
    where
        VS: Stream<Item = Result<DataChunk, BoxedError>>,
    {
        // Underlying chunk stream; pinned so it can be polled through `project()`.
        #[pin]
        chunk_stream: VS,
        // Output data type of each column, used when formatting values.
        column_types: Vec<DataType>,
        // Per-column pgwire output formats; length conventions are handled by
        // `FormatIterator` in `to_pg_rows`.
        pub formats: Vec<Format>,
        // Session settings snapshot (see `StaticSessionData`).
        session_data: StaticSessionData,
    }
}
73
/// Session configuration captured once (at adapter construction), so that row
/// formatting does not have to go back to the live session for every value.
pub struct StaticSessionData {
    // Time zone name used to render `timestamptz` values in text mode.
    pub timezone: String,
}
78
79impl<VS> DataChunkToRowSetAdapter<VS>
80where
81 VS: Stream<Item = Result<DataChunk, BoxedError>>,
82{
83 pub fn new(
84 chunk_stream: VS,
85 column_types: Vec<DataType>,
86 formats: Vec<Format>,
87 session: Arc<SessionImpl>,
88 ) -> Self {
89 let session_data = StaticSessionData {
90 timezone: session.config().timezone(),
91 };
92 Self {
93 chunk_stream,
94 column_types,
95 formats,
96 session_data,
97 }
98 }
99}
100
101impl<VS> Stream for DataChunkToRowSetAdapter<VS>
102where
103 VS: Stream<Item = Result<DataChunk, BoxedError>>,
104{
105 type Item = RowSetResult;
106
107 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108 let mut this = self.project();
109 match this.chunk_stream.as_mut().poll_next(cx) {
110 Poll::Pending => Poll::Pending,
111 Poll::Ready(chunk) => match chunk {
112 Some(chunk_result) => match chunk_result {
113 Ok(chunk) => Poll::Ready(Some(
114 to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
115 .map_err(|err| err.into()),
116 )),
117 Err(err) => Poll::Ready(Some(Err(err))),
118 },
119 None => Poll::Ready(None),
120 },
121 }
122 }
123}
124
125pub fn pg_value_format(
127 data_type: &DataType,
128 d: ScalarRefImpl<'_>,
129 format: Format,
130 session_data: &StaticSessionData,
131) -> RwResult<Bytes> {
132 match format {
135 Format::Text => {
136 if *data_type == DataType::Timestamptz {
137 Ok(timestamptz_to_string_with_session_data(d, session_data))
138 } else {
139 Ok(d.text_format(data_type).into())
140 }
141 }
142 Format::Binary => Ok(d
143 .binary_format(data_type)
144 .context("failed to format binary value")?),
145 }
146}
147
148fn timestamptz_to_string_with_session_data(
149 d: ScalarRefImpl<'_>,
150 session_data: &StaticSessionData,
151) -> Bytes {
152 let tz = d.into_timestamptz();
153 let time_zone = Timestamptz::lookup_time_zone(&session_data.timezone).unwrap();
154 let instant_local = tz.to_datetime_in_zone(time_zone);
155 let mut result_string = BytesMut::new();
156 write_date_time_tz(instant_local, &mut result_string).unwrap();
157 result_string.into()
158}
159
160fn to_pg_rows(
161 column_types: &[DataType],
162 chunk: DataChunk,
163 formats: &[Format],
164 session_data: &StaticSessionData,
165) -> RwResult<Vec<Row>> {
166 assert_eq!(chunk.dimension(), column_types.len());
167 if cfg!(debug_assertions) {
168 let chunk_data_types = chunk.data_types();
169 for (ty1, ty2) in chunk_data_types.iter().zip_eq_fast(column_types) {
170 debug_assert!(
171 ty1.equals_datatype(ty2),
172 "chunk_data_types: {chunk_data_types:?}, column_types: {column_types:?}"
173 )
174 }
175 }
176
177 chunk
178 .rows()
179 .map(|r| {
180 let format_iter = FormatIterator::new(formats, chunk.dimension())
181 .map_err(ErrorCode::InternalError)?;
182 let row = r
183 .iter()
184 .zip_eq_fast(column_types)
185 .zip_eq_fast(format_iter)
186 .map(|((data, t), format)| match data {
187 Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
188 None => Ok(None),
189 })
190 .try_collect()?;
191 Ok(Row::new(row))
192 })
193 .try_collect()
194}
195
196pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
198 PgFieldDescriptor::new(
199 f.name.clone(),
200 f.data_type().to_oid(),
201 f.data_type().type_len(),
202 )
203}
204
205#[easy_ext::ext(SourceSchemaCompatExt)]
206impl CompatibleFormatEncode {
207 pub fn into_v2_with_warning(self) -> FormatEncodeOptions {
209 match self {
210 CompatibleFormatEncode::RowFormat(inner) => {
211 current::notice_to_user(
213 "RisingWave will stop supporting the syntax \"ROW FORMAT\" in future versions, which will be changed to \"FORMAT ... ENCODE ...\" syntax.",
214 );
215 inner.into_format_encode_v2()
216 }
217 CompatibleFormatEncode::V2(inner) => inner,
218 }
219 }
220}
221
222pub fn gen_query_from_table_name(from_name: ObjectName) -> Query {
223 let table_factor = TableFactor::Table {
224 name: from_name,
225 alias: None,
226 as_of: None,
227 };
228 let from = vec![TableWithJoins {
229 relation: table_factor,
230 joins: vec![],
231 }];
232 let select = Select {
233 from,
234 projection: vec![SelectItem::Wildcard(None)],
235 ..Default::default()
236 };
237 let body = SetExpr::Select(Box::new(select));
238 Query {
239 with: None,
240 body,
241 order_by: vec![],
242 limit: None,
243 offset: None,
244 fetch: None,
245 }
246}
247
248pub fn convert_unix_millis_to_logstore_u64(unix_millis: u64) -> u64 {
249 Epoch::from_unix_millis(unix_millis).0
250}
251
252pub fn convert_logstore_u64_to_unix_millis(logstore_u64: u64) -> u64 {
253 Epoch::from(logstore_u64).as_unix_millis()
254}
255
256pub fn convert_interval_to_u64_seconds(interval: &String) -> RwResult<u64> {
257 let seconds = (Interval::from_str(interval)
258 .map_err(|err| {
259 ErrorCode::InternalError(format!(
260 "Convert interval to u64 error, please check format, error: {:?}",
261 err.to_report_string()
262 ))
263 })?
264 .epoch_in_micros()
265 / 1000000) as u64;
266 Ok(seconds)
267}
268
269pub fn ensure_connection_type_allowed(
270 connection_type: PbConnectionType,
271 allowed_types: &HashSet<PbConnectionType>,
272) -> RwResult<()> {
273 if !allowed_types.contains(&connection_type) {
274 return Err(RwError::from(ProtocolError(format!(
275 "connection type {:?} is not allowed, allowed types: {:?}",
276 connection_type, allowed_types
277 ))));
278 }
279 Ok(())
280}
281
282fn connection_type_to_connector(connection_type: &PbConnectionType) -> &str {
283 match connection_type {
284 PbConnectionType::Kafka => KAFKA_CONNECTOR,
285 PbConnectionType::Iceberg => ICEBERG_CONNECTOR,
286 PbConnectionType::Elasticsearch => ES_SINK,
287 _ => unreachable!(),
288 }
289}
290
291pub fn check_connector_match_connection_type(
292 connector: &str,
293 connection_type: &PbConnectionType,
294) -> RwResult<()> {
295 if !connector.eq(connection_type_to_connector(connection_type)) {
296 return Err(RwError::from(ProtocolError(format!(
297 "connector {} and connection type {:?} are not compatible",
298 connector, connection_type
299 ))));
300 }
301 Ok(())
302}
303
/// Resolves `table_name` against the session's current database and search
/// path, returning the table's catalog entry together with the name of the
/// schema it belongs to.
pub fn get_table_catalog_by_table_name(
    session: &SessionImpl,
    table_name: &ObjectName,
) -> RwResult<(Arc<TableCatalog>, String)> {
    let db_name = &session.database();
    let (schema_name, real_table_name) =
        Binder::resolve_schema_qualified_name(db_name, table_name)?;
    let search_path = session.config().search_path();
    let user_name = &session.user_name();

    let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
    let reader = session.env().catalog_reader().read_guard();
    match reader.get_created_table_by_name(db_name, schema_path, &real_table_name) {
        Ok((table, schema_name)) => Ok((table.clone(), schema_name.to_owned())),
        Err(err) => {
            // Not found among created tables: fall back to the staging catalog.
            // NOTE(review): presumably this covers tables whose creation is
            // still in flight — confirm against the staging catalog manager.
            if let Some(table) = session
                .staging_catalog_manager()
                .get_table(&real_table_name)
            {
                // Recover the schema name from the staged table's ids; a
                // failed schema lookup is propagated via `?`.
                let schema_name = reader
                    .get_schema_by_id(table.database_id, table.schema_id)
                    .map(|schema| schema.name.clone())?;
                Ok((Arc::new(table.clone()), schema_name))
            } else {
                // Staging lookup also failed: surface the original error.
                Err(err.into())
            }
        }
    }
}
335
/// Which piece of advice to show the user when an operation exceeds the
/// slow-DDL notification threshold.
#[derive(Clone, Copy)]
pub enum LongRunningNotificationAction {
    /// Suggest running `RECOVER` so the change takes effect immediately.
    SuggestRecover,
    /// Point at the docs for diagnosing high barrier latency.
    DiagnoseBarrierLatency,
    /// Point at `SHOW JOBS` and the barrier-monitoring docs.
    MonitorBackfillJob,
}

impl LongRunningNotificationAction {
    /// Renders the user-facing notice for `operation_name` after it has been
    /// running longer than `notify_timeout_secs` seconds.
    fn build_message(self, operation_name: &str, notify_timeout_secs: u32) -> String {
        match self {
            Self::SuggestRecover => format!(
                "{operation_name} has taken more than {notify_timeout_secs} secs, likely due to high barrier latency.\n\
                 You may trigger cluster recovery to let {operation_name} take effect immediately.\n\
                 Run RECOVER in a separate session to trigger recovery.\n\
                 See: https://docs.risingwave.com/sql/commands/sql-recover#recover"
            ),
            Self::DiagnoseBarrierLatency => format!(
                "{operation_name} has taken more than {notify_timeout_secs} secs, likely due to high barrier latency.\n\
                 See: https://docs.risingwave.com/performance/metrics#barrier-monitoring for steps to diagnose high barrier latency."
            ),
            Self::MonitorBackfillJob => format!(
                "{operation_name} has taken more than {notify_timeout_secs} secs, barrier latency might be high. Please check barrier latency metrics to confirm.\n\
                 You can also run SHOW JOBS to track the progress of the job.\n\
                 See: https://docs.risingwave.com/performance/metrics#barrier-monitoring and https://docs.risingwave.com/sql/commands/sql-show-jobs"
            ),
        }
    }
}
385
/// Awaits `operation_fut`, sending the user a single notice (built from
/// `action`) if the operation is still running after the session's
/// `slow_ddl_notification_secs` setting has elapsed.
///
/// A setting of `0` disables the notification entirely.
pub async fn execute_with_long_running_notification<F, T>(
    operation_fut: F,
    session: &SessionImpl,
    operation_name: &str,
    action: LongRunningNotificationAction,
) -> RwResult<T>
where
    F: std::future::Future<Output = RwResult<T>>,
{
    let notify_timeout_secs = session.config().slow_ddl_notification_secs();

    if notify_timeout_secs == 0 {
        // Notifications disabled: just drive the operation to completion.
        return operation_fut.await;
    }

    let notify_fut = sleep(Duration::from_secs(notify_timeout_secs as u64));
    tokio::pin!(operation_fut);

    select! {
        // Timer fired first: notify once, then keep driving the operation.
        _ = notify_fut => {
            session.notice_to_user(action.build_message(operation_name, notify_timeout_secs));
            operation_fut.await
        }
        // Operation finished before the timer: no notice is sent.
        result = &mut operation_fut => {
            result
        }
    }
}
415
#[cfg(test)]
mod tests {
    use postgres_types::{ToSql, Type};
    use risingwave_common::array::*;

    use super::*;

    // A field's pgwire descriptor carries its name and its type's oid.
    #[test]
    fn test_to_pg_field() {
        let field = Field::with_name(DataType::Int32, "v1");
        let pg_field = to_pg_field(&field);
        assert_eq!(pg_field.get_name(), "v1");
        assert_eq!(pg_field.get_type_oid(), DataType::Int32.to_oid());
    }

    #[test]
    fn test_to_pg_rows() {
        // In `from_pretty`, `.` denotes NULL; the header declares column types
        // (i = int32, I = int64, f = float32, T = varchar).
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
             2 . . .
             3 7 7.01 vvv
             4 . . . ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        // An empty `formats` slice means every column is rendered as text.
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[],
            &static_session,
        );
        // NULLs come back as `None`; everything else as its text rendering.
        let expected: Vec<Vec<Option<Bytes>>> = vec![
            vec![
                Some("1".into()),
                Some("6".into()),
                Some("6.01".into()),
                Some("aaa".into()),
            ],
            vec![Some("2".into()), None, None, None],
            vec![
                Some("3".into()),
                Some("7".into()),
                Some("7.01".into()),
                Some("vvv".into()),
            ],
            vec![Some("4".into()), None, None, None],
        ];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    #[test]
    fn test_to_pg_rows_mix_format() {
        let chunk = DataChunk::from_pretty(
            "i I f T
             1 6 6.01 aaa
            ",
        );
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };
        // First three columns binary, last column text.
        let rows = to_pg_rows(
            &[
                DataType::Int32,
                DataType::Int64,
                DataType::Float32,
                DataType::Varchar,
            ],
            chunk,
            &[Format::Binary, Format::Binary, Format::Binary, Format::Text],
            &static_session,
        );
        // Build the expected binary encodings with the `postgres_types`
        // reference implementation.
        let mut raw_params = vec![BytesMut::new(); 3];
        1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
        6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
        6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
        let raw_params = raw_params
            .into_iter()
            .map(|b| b.freeze())
            .collect::<Vec<_>>();
        let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
            Some(raw_params[0].clone()),
            Some(raw_params[1].clone()),
            Some(raw_params[2].clone()),
            Some("aaa".into()),
        ]];
        let vec = rows
            .unwrap()
            .into_iter()
            .map(|r| r.values().iter().cloned().collect_vec())
            .collect_vec();

        assert_eq!(vec, expected);
    }

    // Text rendering of special float values, booleans, and timestamptz.
    #[test]
    fn test_value_format() {
        use {DataType as T, ScalarRefImpl as S};
        let static_session = StaticSessionData {
            timezone: "UTC".into(),
        };

        let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
        assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
            "NaN"
        );
        assert_eq!(
            &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float32,
                S::Float32(f32::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(
            &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
            "Infinity"
        );
        assert_eq!(
            &f(
                &T::Float64,
                S::Float64(f64::NEG_INFINITY.into()),
                Format::Text
            ),
            "-Infinity"
        );
        assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
        assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
        // Text-mode timestamptz is rendered in the session time zone (UTC here).
        assert_eq!(
            &f(
                &T::Timestamptz,
                S::Timestamptz(Timestamptz::from_micros(-1)),
                Format::Text
            ),
            "1969-12-31 23:59:59.999999+00:00"
        );
    }
}