1use std::collections::HashMap;
16use std::io::{Error, ErrorKind};
17use std::net::{IpAddr, Ipv4Addr, SocketAddr};
18use std::sync::atomic::{AtomicI32, Ordering};
19use std::sync::{Arc, Weak};
20use std::time::{Duration, Instant};
21
22use anyhow::anyhow;
23use bytes::Bytes;
24use either::Either;
25use itertools::Itertools;
26use parking_lot::{Mutex, RwLock, RwLockReadGuard};
27use pgwire::error::{PsqlError, PsqlResult};
28use pgwire::net::{Address, AddressRef};
29use pgwire::pg_field_descriptor::PgFieldDescriptor;
30use pgwire::pg_message::TransactionStatus;
31use pgwire::pg_response::{PgResponse, StatementType};
32use pgwire::pg_server::{
33 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
34 UserAuthenticator,
35};
36use pgwire::types::{Format, FormatIterator};
37use rand::RngCore;
38use risingwave_batch::monitor::{BatchSpillMetrics, GLOBAL_BATCH_SPILL_METRICS};
39use risingwave_batch::spill::spill_op::SpillOp;
40use risingwave_batch::task::{ShutdownSender, ShutdownToken};
41use risingwave_batch::worker_manager::worker_node_manager::{
42 WorkerNodeManager, WorkerNodeManagerRef,
43};
44use risingwave_common::acl::AclMode;
45#[cfg(test)]
46use risingwave_common::catalog::{
47 DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
48};
49use risingwave_common::config::{
50 BatchConfig, FrontendConfig, MetaConfig, MetricLevel, StreamingConfig, UdfConfig, load_config,
51};
52use risingwave_common::memory::MemoryContext;
53use risingwave_common::secret::LocalSecretManager;
54use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode};
55use risingwave_common::system_param::local_manager::{
56 LocalSystemParamsManager, LocalSystemParamsManagerRef,
57};
58use risingwave_common::telemetry::manager::TelemetryManager;
59use risingwave_common::telemetry::telemetry_env_enabled;
60use risingwave_common::types::DataType;
61use risingwave_common::util::addr::HostAddr;
62use risingwave_common::util::cluster_limit;
63use risingwave_common::util::cluster_limit::ActorCountPerParallelism;
64use risingwave_common::util::iter_util::ZipEqFast;
65use risingwave_common::util::pretty_bytes::convert;
66use risingwave_common::util::runtime::BackgroundShutdownRuntime;
67use risingwave_common::{GIT_SHA, RW_VERSION};
68use risingwave_common_heap_profiling::HeapProfiler;
69use risingwave_common_service::{MetricsManager, ObserverManager};
70use risingwave_connector::source::monitor::{GLOBAL_SOURCE_METRICS, SourceMetrics};
71use risingwave_pb::common::WorkerType;
72use risingwave_pb::common::worker_node::Property as AddWorkerNodeProperty;
73use risingwave_pb::frontend_service::frontend_service_server::FrontendServiceServer;
74use risingwave_pb::health::health_server::HealthServer;
75use risingwave_pb::user::auth_info::EncryptionType;
76use risingwave_pb::user::grant_privilege::Object;
77use risingwave_rpc_client::{ComputeClientPool, ComputeClientPoolRef, MetaClient};
78use risingwave_sqlparser::ast::{ObjectName, Statement};
79use risingwave_sqlparser::parser::Parser;
80use thiserror::Error;
81use tokio::runtime::Builder;
82use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
83use tokio::sync::oneshot::Sender;
84use tokio::sync::watch;
85use tokio::task::JoinHandle;
86use tracing::info;
87use tracing::log::error;
88
89use self::cursor_manager::CursorManager;
90use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError};
91use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl};
92use crate::catalog::connection_catalog::ConnectionCatalog;
93use crate::catalog::root_catalog::{Catalog, SchemaPath};
94use crate::catalog::secret_catalog::SecretCatalog;
95use crate::catalog::source_catalog::SourceCatalog;
96use crate::catalog::subscription_catalog::SubscriptionCatalog;
97use crate::catalog::{
98 CatalogError, DatabaseId, OwnedByUserCatalog, SchemaId, TableId, check_schema_writable,
99};
100use crate::error::{ErrorCode, Result, RwError};
101use crate::handler::describe::infer_describe;
102use crate::handler::extended_handle::{
103 Portal, PrepareStatement, handle_bind, handle_execute, handle_parse,
104};
105use crate::handler::privilege::ObjectCheckItem;
106use crate::handler::show::{infer_show_create_object, infer_show_object};
107use crate::handler::util::to_pg_field;
108use crate::handler::variable::infer_show_variable;
109use crate::handler::{RwPgResponse, handle};
110use crate::health_service::HealthServiceImpl;
111use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl};
112use crate::monitor::{CursorMetrics, FrontendMetrics, GLOBAL_FRONTEND_METRICS};
113use crate::observer::FrontendObserverNode;
114use crate::rpc::FrontendServiceImpl;
115use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef};
116use crate::scheduler::{
117 DistributedQueryMetrics, GLOBAL_DISTRIBUTED_QUERY_METRICS, HummockSnapshotManager,
118 HummockSnapshotManagerRef, QueryManager,
119};
120use crate::telemetry::FrontendTelemetryCreator;
121use crate::user::UserId;
122use crate::user::user_authentication::md5_hash_with_salt;
123use crate::user::user_manager::UserInfoManager;
124use crate::user::user_service::{UserInfoReader, UserInfoWriter, UserInfoWriterImpl};
125use crate::{FrontendOpts, PgResponseStream, TableCatalog};
126
127pub(crate) mod current;
128pub(crate) mod cursor_manager;
129pub(crate) mod transaction;
130
/// Node-wide state of a frontend, shared by every session on this node.
///
/// Built by [`FrontendEnv::init`] for a real node and by [`FrontendEnv::mock`] for tests.
/// Many fields are `Arc` handles; the config structs are stored (and cloned) by value.
#[derive(Clone)]
pub(crate) struct FrontendEnv {
    /// Client used for frontend-to-meta calls.
    meta_client: Arc<dyn FrontendMetaClient>,
    /// Catalog writer; only handed out through `catalog_writer()`, which requires a
    /// transaction write guard.
    catalog_writer: Arc<dyn CatalogWriter>,
    /// Read access to the local catalog (in `init`, the same catalog is handed to the
    /// `FrontendObserverNode`).
    catalog_reader: CatalogReader,
    /// User-info writer; only handed out through `user_info_writer()`, which requires a
    /// transaction write guard.
    user_info_writer: Arc<dyn UserInfoWriter>,
    /// Read access to the local user-info manager.
    user_info_reader: UserInfoReader,
    /// Manager of known worker nodes (shared with the observer node in `init`).
    worker_node_manager: WorkerNodeManagerRef,
    /// Batch query manager.
    query_manager: QueryManager,
    hummock_snapshot_manager: HummockSnapshotManagerRef,
    system_params_manager: LocalSystemParamsManagerRef,
    /// Shared session parameters (also handed to the observer node in `init`); read via
    /// `session_params_snapshot()`.
    session_params: Arc<RwLock<SessionConfig>>,

    /// Advertised address of this frontend.
    server_addr: HostAddr,
    /// Pool of clients to compute nodes.
    client_pool: ComputeClientPoolRef,

    /// All sessions currently registered on this node, keyed by session id.
    sessions_map: SessionMapRef,

    pub frontend_metrics: Arc<FrontendMetrics>,

    pub cursor_metrics: Arc<CursorMetrics>,

    source_metrics: Arc<SourceMetrics>,

    spill_metrics: Arc<BatchSpillMetrics>,

    batch_config: BatchConfig,
    frontend_config: FrontendConfig,
    // Stored but currently never read, hence the `expect(dead_code)`.
    #[expect(dead_code)]
    meta_config: MetaConfig,
    streaming_config: StreamingConfig,
    udf_config: UdfConfig,

    /// Tracks streaming jobs that are being created.
    creating_streaming_job_tracker: StreamingJobTrackerRef,

    /// Multi-threaded tokio runtime (threads named `rw-batch-local`) for local batch work.
    compute_runtime: Arc<BackgroundShutdownRuntime>,

    /// Root memory context for batch execution (`MemoryContext::none()` in `mock`).
    mem_context: MemoryContext,

    /// Address of the serverless backfill controller, exposed via `sbc_address()`.
    serverless_backfill_controller_addr: String,
}
185
186pub type SessionMapRef = Arc<RwLock<HashMap<(i32, i32), Arc<SessionImpl>>>>;
188
/// Fraction of `frontend_total_memory_bytes` used as the limit of the root batch
/// memory context created in `FrontendEnv::init`.
const FRONTEND_BATCH_MEMORY_PROPORTION: f64 = 0.5;
191
192impl FrontendEnv {
193 pub fn mock() -> Self {
194 use crate::test_utils::{MockCatalogWriter, MockFrontendMetaClient, MockUserInfoWriter};
195
196 let catalog = Arc::new(RwLock::new(Catalog::default()));
197 let meta_client = Arc::new(MockFrontendMetaClient {});
198 let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone()));
199 let catalog_writer = Arc::new(MockCatalogWriter::new(
200 catalog.clone(),
201 hummock_snapshot_manager.clone(),
202 ));
203 let catalog_reader = CatalogReader::new(catalog);
204 let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
205 let user_info_writer = Arc::new(MockUserInfoWriter::new(user_info_manager.clone()));
206 let user_info_reader = UserInfoReader::new(user_info_manager);
207 let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
208 let system_params_manager = Arc::new(LocalSystemParamsManager::for_test());
209 let compute_client_pool = Arc::new(ComputeClientPool::for_test());
210 let query_manager = QueryManager::new(
211 worker_node_manager.clone(),
212 compute_client_pool,
213 catalog_reader.clone(),
214 Arc::new(DistributedQueryMetrics::for_test()),
215 None,
216 None,
217 );
218 let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap();
219 let client_pool = Arc::new(ComputeClientPool::for_test());
220 let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone());
221 let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(
222 Builder::new_multi_thread()
223 .worker_threads(
224 load_config("", FrontendOpts::default())
225 .batch
226 .frontend_compute_runtime_worker_threads,
227 )
228 .thread_name("rw-batch-local")
229 .enable_all()
230 .build()
231 .unwrap(),
232 ));
233 let sessions_map = Arc::new(RwLock::new(HashMap::new()));
234 Self {
235 meta_client,
236 catalog_writer,
237 catalog_reader,
238 user_info_writer,
239 user_info_reader,
240 worker_node_manager,
241 query_manager,
242 hummock_snapshot_manager,
243 system_params_manager,
244 session_params: Default::default(),
245 server_addr,
246 client_pool,
247 sessions_map: sessions_map.clone(),
248 frontend_metrics: Arc::new(FrontendMetrics::for_test()),
249 cursor_metrics: Arc::new(CursorMetrics::for_test()),
250 batch_config: BatchConfig::default(),
251 frontend_config: FrontendConfig::default(),
252 meta_config: MetaConfig::default(),
253 streaming_config: StreamingConfig::default(),
254 udf_config: UdfConfig::default(),
255 source_metrics: Arc::new(SourceMetrics::default()),
256 spill_metrics: BatchSpillMetrics::for_test(),
257 creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
258 compute_runtime,
259 mem_context: MemoryContext::none(),
260 serverless_backfill_controller_addr: Default::default(),
261 }
262 }
263
264 pub async fn init(opts: FrontendOpts) -> Result<(Self, Vec<JoinHandle<()>>, Vec<Sender<()>>)> {
265 let config = load_config(&opts.config_path, &opts);
266 info!("Starting frontend node");
267 info!("> config: {:?}", config);
268 info!(
269 "> debug assertions: {}",
270 if cfg!(debug_assertions) { "on" } else { "off" }
271 );
272 info!("> version: {} ({})", RW_VERSION, GIT_SHA);
273
274 let frontend_address: HostAddr = opts
275 .advertise_addr
276 .as_ref()
277 .unwrap_or_else(|| {
278 tracing::warn!("advertise addr is not specified, defaulting to listen_addr");
279 &opts.listen_addr
280 })
281 .parse()
282 .unwrap();
283 info!("advertise addr is {}", frontend_address);
284
285 let rpc_addr: HostAddr = opts.frontend_rpc_listener_addr.parse().unwrap();
286 let internal_rpc_host_addr = HostAddr {
287 host: frontend_address.host.clone(),
289 port: rpc_addr.port,
290 };
291 let (meta_client, system_params_reader) = MetaClient::register_new(
293 opts.meta_addr,
294 WorkerType::Frontend,
295 &frontend_address,
296 AddWorkerNodeProperty {
297 internal_rpc_host_addr: internal_rpc_host_addr.to_string(),
298 ..Default::default()
299 },
300 &config.meta,
301 )
302 .await;
303
304 let worker_id = meta_client.worker_id();
305 info!("Assigned worker node id {}", worker_id);
306
307 let (heartbeat_join_handle, heartbeat_shutdown_sender) = MetaClient::start_heartbeat_loop(
308 meta_client.clone(),
309 Duration::from_millis(config.server.heartbeat_interval_ms as u64),
310 );
311 let mut join_handles = vec![heartbeat_join_handle];
312 let mut shutdown_senders = vec![heartbeat_shutdown_sender];
313
314 let frontend_meta_client = Arc::new(FrontendMetaClientImpl(meta_client.clone()));
315 let hummock_snapshot_manager =
316 Arc::new(HummockSnapshotManager::new(frontend_meta_client.clone()));
317
318 let (catalog_updated_tx, catalog_updated_rx) = watch::channel(0);
319 let catalog = Arc::new(RwLock::new(Catalog::default()));
320 let catalog_writer = Arc::new(CatalogWriterImpl::new(
321 meta_client.clone(),
322 catalog_updated_rx,
323 hummock_snapshot_manager.clone(),
324 ));
325 let catalog_reader = CatalogReader::new(catalog.clone());
326
327 let worker_node_manager = Arc::new(WorkerNodeManager::new());
328
329 let compute_client_pool = Arc::new(ComputeClientPool::new(
330 config.batch_exchange_connection_pool_size(),
331 config.batch.developer.compute_client_config.clone(),
332 ));
333 let query_manager = QueryManager::new(
334 worker_node_manager.clone(),
335 compute_client_pool.clone(),
336 catalog_reader.clone(),
337 Arc::new(GLOBAL_DISTRIBUTED_QUERY_METRICS.clone()),
338 config.batch.distributed_query_limit,
339 config.batch.max_batch_queries_per_frontend_node,
340 );
341
342 let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
343 let (user_info_updated_tx, user_info_updated_rx) = watch::channel(0);
344 let user_info_reader = UserInfoReader::new(user_info_manager.clone());
345 let user_info_writer = Arc::new(UserInfoWriterImpl::new(
346 meta_client.clone(),
347 user_info_updated_rx,
348 ));
349
350 let system_params_manager =
351 Arc::new(LocalSystemParamsManager::new(system_params_reader.clone()));
352
353 LocalSecretManager::init(
354 opts.temp_secret_file_dir,
355 meta_client.cluster_id().to_owned(),
356 worker_id,
357 );
358
359 let session_params = Arc::new(RwLock::new(SessionConfig::default()));
361 let sessions_map: SessionMapRef = Arc::new(RwLock::new(HashMap::new()));
362 let cursor_metrics = Arc::new(CursorMetrics::init(sessions_map.clone()));
363
364 let frontend_observer_node = FrontendObserverNode::new(
365 worker_node_manager.clone(),
366 catalog,
367 catalog_updated_tx,
368 user_info_manager,
369 user_info_updated_tx,
370 hummock_snapshot_manager.clone(),
371 system_params_manager.clone(),
372 session_params.clone(),
373 compute_client_pool.clone(),
374 );
375 let observer_manager =
376 ObserverManager::new_with_meta_client(meta_client.clone(), frontend_observer_node)
377 .await;
378 let observer_join_handle = observer_manager.start().await;
379 join_handles.push(observer_join_handle);
380
381 meta_client.activate(&frontend_address).await?;
382
383 let frontend_metrics = Arc::new(GLOBAL_FRONTEND_METRICS.clone());
384 let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
385 let spill_metrics = Arc::new(GLOBAL_BATCH_SPILL_METRICS.clone());
386
387 if config.server.metrics_level > MetricLevel::Disabled {
388 MetricsManager::boot_metrics_service(opts.prometheus_listener_addr.clone());
389 }
390
391 let health_srv = HealthServiceImpl::new();
392 let frontend_srv = FrontendServiceImpl::new();
393 let frontend_rpc_addr = opts.frontend_rpc_listener_addr.parse().unwrap();
394
395 let telemetry_manager = TelemetryManager::new(
396 Arc::new(meta_client.clone()),
397 Arc::new(FrontendTelemetryCreator::new()),
398 );
399
400 if config.server.telemetry_enabled && telemetry_env_enabled() {
403 let (join_handle, shutdown_sender) = telemetry_manager.start().await;
404 join_handles.push(join_handle);
405 shutdown_senders.push(shutdown_sender);
406 } else {
407 tracing::info!("Telemetry didn't start due to config");
408 }
409
410 tokio::spawn(async move {
411 tonic::transport::Server::builder()
412 .add_service(HealthServer::new(health_srv))
413 .add_service(FrontendServiceServer::new(frontend_srv))
414 .serve(frontend_rpc_addr)
415 .await
416 .unwrap();
417 });
418 info!(
419 "Health Check RPC Listener is set up on {}",
420 opts.frontend_rpc_listener_addr.clone()
421 );
422
423 let creating_streaming_job_tracker =
424 Arc::new(StreamingJobTracker::new(frontend_meta_client.clone()));
425
426 let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(
427 Builder::new_multi_thread()
428 .worker_threads(config.batch.frontend_compute_runtime_worker_threads)
429 .thread_name("rw-batch-local")
430 .enable_all()
431 .build()
432 .unwrap(),
433 ));
434
435 let sessions = sessions_map.clone();
436 let join_handle = tokio::spawn(async move {
438 let mut check_idle_txn_interval =
439 tokio::time::interval(core::time::Duration::from_secs(10));
440 check_idle_txn_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
441 check_idle_txn_interval.reset();
442 loop {
443 check_idle_txn_interval.tick().await;
444 sessions.read().values().for_each(|session| {
445 let _ = session.check_idle_in_transaction_timeout();
446 })
447 }
448 });
449 join_handles.push(join_handle);
450
451 #[cfg(not(madsim))]
453 if config.batch.enable_spill {
454 SpillOp::clean_spill_directory()
455 .await
456 .map_err(|err| anyhow!(err))?;
457 }
458
459 let total_memory_bytes = opts.frontend_total_memory_bytes;
460 let heap_profiler =
461 HeapProfiler::new(total_memory_bytes, config.server.heap_profiling.clone());
462 heap_profiler.start();
464
465 let batch_memory_limit = total_memory_bytes as f64 * FRONTEND_BATCH_MEMORY_PROPORTION;
466 let mem_context = MemoryContext::root(
467 frontend_metrics.batch_total_mem.clone(),
468 batch_memory_limit as u64,
469 );
470
471 info!(
472 "Frontend total_memory: {} batch_memory: {}",
473 convert(total_memory_bytes as _),
474 convert(batch_memory_limit as _),
475 );
476
477 Ok((
478 Self {
479 catalog_reader,
480 catalog_writer,
481 user_info_reader,
482 user_info_writer,
483 worker_node_manager,
484 meta_client: frontend_meta_client,
485 query_manager,
486 hummock_snapshot_manager,
487 system_params_manager,
488 session_params,
489 server_addr: frontend_address,
490 client_pool: compute_client_pool,
491 frontend_metrics,
492 cursor_metrics,
493 spill_metrics,
494 sessions_map,
495 batch_config: config.batch,
496 frontend_config: config.frontend,
497 meta_config: config.meta,
498 streaming_config: config.streaming,
499 serverless_backfill_controller_addr: opts.serverless_backfill_controller_addr,
500 udf_config: config.udf,
501 source_metrics,
502 creating_streaming_job_tracker,
503 compute_runtime,
504 mem_context,
505 },
506 join_handles,
507 shutdown_senders,
508 ))
509 }
510
511 fn catalog_writer(&self, _guard: transaction::WriteGuard) -> &dyn CatalogWriter {
516 &*self.catalog_writer
517 }
518
519 pub fn catalog_reader(&self) -> &CatalogReader {
521 &self.catalog_reader
522 }
523
524 fn user_info_writer(&self, _guard: transaction::WriteGuard) -> &dyn UserInfoWriter {
529 &*self.user_info_writer
530 }
531
532 pub fn user_info_reader(&self) -> &UserInfoReader {
534 &self.user_info_reader
535 }
536
537 pub fn worker_node_manager_ref(&self) -> WorkerNodeManagerRef {
538 self.worker_node_manager.clone()
539 }
540
541 pub fn meta_client(&self) -> &dyn FrontendMetaClient {
542 &*self.meta_client
543 }
544
545 pub fn meta_client_ref(&self) -> Arc<dyn FrontendMetaClient> {
546 self.meta_client.clone()
547 }
548
549 pub fn query_manager(&self) -> &QueryManager {
550 &self.query_manager
551 }
552
553 pub fn hummock_snapshot_manager(&self) -> &HummockSnapshotManagerRef {
554 &self.hummock_snapshot_manager
555 }
556
557 pub fn system_params_manager(&self) -> &LocalSystemParamsManagerRef {
558 &self.system_params_manager
559 }
560
561 pub fn session_params_snapshot(&self) -> SessionConfig {
562 self.session_params.read_recursive().clone()
563 }
564
565 pub fn sbc_address(&self) -> &String {
566 &self.serverless_backfill_controller_addr
567 }
568
569 pub fn server_address(&self) -> &HostAddr {
570 &self.server_addr
571 }
572
573 pub fn client_pool(&self) -> ComputeClientPoolRef {
574 self.client_pool.clone()
575 }
576
577 pub fn batch_config(&self) -> &BatchConfig {
578 &self.batch_config
579 }
580
581 pub fn frontend_config(&self) -> &FrontendConfig {
582 &self.frontend_config
583 }
584
585 pub fn streaming_config(&self) -> &StreamingConfig {
586 &self.streaming_config
587 }
588
589 pub fn udf_config(&self) -> &UdfConfig {
590 &self.udf_config
591 }
592
593 pub fn source_metrics(&self) -> Arc<SourceMetrics> {
594 self.source_metrics.clone()
595 }
596
597 pub fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
598 self.spill_metrics.clone()
599 }
600
601 pub fn creating_streaming_job_tracker(&self) -> &StreamingJobTrackerRef {
602 &self.creating_streaming_job_tracker
603 }
604
605 pub fn sessions_map(&self) -> &SessionMapRef {
606 &self.sessions_map
607 }
608
609 pub fn compute_runtime(&self) -> Arc<BackgroundShutdownRuntime> {
610 self.compute_runtime.clone()
611 }
612
613 pub fn cancel_queries_in_session(&self, session_id: SessionId) -> bool {
616 let guard = self.sessions_map.read();
617 if let Some(session) = guard.get(&session_id) {
618 session.cancel_current_query();
619 true
620 } else {
621 info!("Current session finished, ignoring cancel query request");
622 false
623 }
624 }
625
626 pub fn cancel_creating_jobs_in_session(&self, session_id: SessionId) -> bool {
629 let guard = self.sessions_map.read();
630 if let Some(session) = guard.get(&session_id) {
631 session.cancel_current_creating_job();
632 true
633 } else {
634 info!("Current session finished, ignoring cancel creating request");
635 false
636 }
637 }
638
639 pub fn mem_context(&self) -> MemoryContext {
640 self.mem_context.clone()
641 }
642}
643
/// Identity a session is authenticated as: its current database, user name and user id.
#[derive(Clone)]
pub struct AuthContext {
    /// Current database; may change during the session (see `SessionImpl::update_database`).
    pub database: String,
    pub user_name: String,
    pub user_id: UserId,
}
650
651impl AuthContext {
652 pub fn new(database: String, user_name: String, user_id: UserId) -> Self {
653 Self {
654 database,
655 user_name,
656 user_id,
657 }
658 }
659}
/// State of a single client session on this frontend.
pub struct SessionImpl {
    /// Node-wide environment this session runs in.
    env: FrontendEnv,
    /// Authenticated identity; its database can be switched via `update_database`.
    auth_context: Arc<RwLock<AuthContext>>,
    /// Authenticator associated with this session.
    user_authenticator: UserAuthenticator,
    /// Session-local configuration (e.g. `search_path`), also shareable via `shared_config`.
    config_map: Arc<RwLock<SessionConfig>>,

    /// Sending half of this session's unbounded notice channel.
    notice_tx: UnboundedSender<String>,
    /// Receiving half of the notice channel.
    notice_rx: Mutex<UnboundedReceiver<String>>,

    /// Identifier of this session (returned by `session_id`).
    id: (i32, i32),

    /// Address of the connected client.
    peer_addr: AddressRef,

    /// Transaction state of this session.
    txn: Arc<Mutex<transaction::State>>,

    /// Shutdown sender for the currently running query, if any (presumably triggered by
    /// `cancel_current_query` — TODO confirm).
    current_query_cancel_flag: Mutex<Option<ShutdownSender>>,

    /// Weak handle to the context of the statement being executed; `running_sql` and
    /// `elapse_since_running_sql` return `None` once it can no longer be upgraded.
    exec_context: Mutex<Option<Weak<ExecContext>>>,

    /// When the session last became idle, if recorded; read by
    /// `elapse_since_last_idle_instant`.
    last_idle_instant: Arc<Mutex<Option<Instant>>>,

    /// Cursors of this session.
    cursor_manager: Arc<CursorManager>,

    /// Temporary sources created in this session, by name.
    temporary_source_manager: Arc<Mutex<TemporarySourceManager>>,
}
700
/// Per-session registry of temporary sources, keyed by source name.
#[derive(Default, Clone)]
pub struct TemporarySourceManager {
    sources: HashMap<String, SourceCatalog>,
}
713
714impl TemporarySourceManager {
715 pub fn new() -> Self {
716 Self {
717 sources: HashMap::new(),
718 }
719 }
720
721 pub fn create_source(&mut self, name: String, source: SourceCatalog) {
722 self.sources.insert(name, source);
723 }
724
725 pub fn drop_source(&mut self, name: &str) {
726 self.sources.remove(name);
727 }
728
729 pub fn get_source(&self, name: &str) -> Option<&SourceCatalog> {
730 self.sources.get(name)
731 }
732
733 pub fn keys(&self) -> Vec<String> {
734 self.sources.keys().cloned().collect()
735 }
736}
737
/// Error returned by `SessionImpl::check_relation_name_duplicated`.
#[derive(Error, Debug)]
pub enum CheckRelationError {
    /// The (possibly schema-qualified) name could not be resolved.
    #[error("{0}")]
    Resolve(#[from] ResolveQualifiedNameError),
    /// A catalog lookup failed, or the relation already exists.
    #[error("{0}")]
    Catalog(#[from] CatalogError),
}
745
746impl From<CheckRelationError> for RwError {
747 fn from(e: CheckRelationError) -> Self {
748 match e {
749 CheckRelationError::Resolve(e) => e.into(),
750 CheckRelationError::Catalog(e) => e.into(),
751 }
752 }
753}
754
755impl SessionImpl {
756 pub(crate) fn new(
757 env: FrontendEnv,
758 auth_context: AuthContext,
759 user_authenticator: UserAuthenticator,
760 id: SessionId,
761 peer_addr: AddressRef,
762 session_config: SessionConfig,
763 ) -> Self {
764 let cursor_metrics = env.cursor_metrics.clone();
765 let (notice_tx, notice_rx) = mpsc::unbounded_channel();
766
767 Self {
768 env,
769 auth_context: Arc::new(RwLock::new(auth_context)),
770 user_authenticator,
771 config_map: Arc::new(RwLock::new(session_config)),
772 id,
773 peer_addr,
774 txn: Default::default(),
775 current_query_cancel_flag: Mutex::new(None),
776 notice_tx,
777 notice_rx: Mutex::new(notice_rx),
778 exec_context: Mutex::new(None),
779 last_idle_instant: Default::default(),
780 cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
781 temporary_source_manager: Default::default(),
782 }
783 }
784
785 #[cfg(test)]
786 pub fn mock() -> Self {
787 let env = FrontendEnv::mock();
788 let (notice_tx, notice_rx) = mpsc::unbounded_channel();
789
790 Self {
791 env: FrontendEnv::mock(),
792 auth_context: Arc::new(RwLock::new(AuthContext::new(
793 DEFAULT_DATABASE_NAME.to_owned(),
794 DEFAULT_SUPER_USER.to_owned(),
795 DEFAULT_SUPER_USER_ID,
796 ))),
797 user_authenticator: UserAuthenticator::None,
798 config_map: Default::default(),
799 id: (0, 0),
801 txn: Default::default(),
802 current_query_cancel_flag: Mutex::new(None),
803 notice_tx,
804 notice_rx: Mutex::new(notice_rx),
805 exec_context: Mutex::new(None),
806 peer_addr: Address::Tcp(SocketAddr::new(
807 IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
808 8080,
809 ))
810 .into(),
811 last_idle_instant: Default::default(),
812 cursor_manager: Arc::new(CursorManager::new(env.cursor_metrics.clone())),
813 temporary_source_manager: Default::default(),
814 }
815 }
816
    /// Returns the frontend environment this session runs in.
    pub(crate) fn env(&self) -> &FrontendEnv {
        &self.env
    }
820
821 pub fn auth_context(&self) -> Arc<AuthContext> {
822 let ctx = self.auth_context.read();
823 Arc::new(ctx.clone())
824 }
825
826 pub fn database(&self) -> String {
827 self.auth_context.read().database.clone()
828 }
829
830 pub fn database_id(&self) -> DatabaseId {
831 let db_name = self.database();
832 self.env
833 .catalog_reader()
834 .read_guard()
835 .get_database_by_name(&db_name)
836 .map(|db| db.id())
837 .expect("session database not found")
838 }
839
840 pub fn user_name(&self) -> String {
841 self.auth_context.read().user_name.clone()
842 }
843
844 pub fn user_id(&self) -> UserId {
845 self.auth_context.read().user_id
846 }
847
848 pub fn update_database(&self, database: String) {
849 self.auth_context.write().database = database;
850 }
851
852 pub fn shared_config(&self) -> Arc<RwLock<SessionConfig>> {
853 Arc::clone(&self.config_map)
854 }
855
    /// Read-locks the session configuration; the lock is held until the guard is dropped.
    pub fn config(&self) -> RwLockReadGuard<'_, SessionConfig> {
        self.config_map.read()
    }
859
860 pub fn set_config(&self, key: &str, value: String) -> Result<String> {
861 self.config_map
862 .write()
863 .set(key, value, &mut ())
864 .map_err(Into::into)
865 }
866
867 pub fn reset_config(&self, key: &str) -> Result<String> {
868 self.config_map
869 .write()
870 .reset(key, &mut ())
871 .map_err(Into::into)
872 }
873
874 pub fn set_config_report(
875 &self,
876 key: &str,
877 value: Option<String>,
878 mut reporter: impl ConfigReporter,
879 ) -> Result<String> {
880 if let Some(value) = value {
881 self.config_map
882 .write()
883 .set(key, value, &mut reporter)
884 .map_err(Into::into)
885 } else {
886 self.config_map
887 .write()
888 .reset(key, &mut reporter)
889 .map_err(Into::into)
890 }
891 }
892
    /// Returns the identifier of this session.
    pub fn session_id(&self) -> SessionId {
        self.id
    }
896
897 pub fn running_sql(&self) -> Option<Arc<str>> {
898 self.exec_context
899 .lock()
900 .as_ref()
901 .and_then(|weak| weak.upgrade())
902 .map(|context| context.running_sql.clone())
903 }
904
905 pub fn get_cursor_manager(&self) -> Arc<CursorManager> {
906 self.cursor_manager.clone()
907 }
908
    /// Returns the address of the connected client.
    pub fn peer_addr(&self) -> &Address {
        &self.peer_addr
    }
912
913 pub fn elapse_since_running_sql(&self) -> Option<u128> {
914 self.exec_context
915 .lock()
916 .as_ref()
917 .and_then(|weak| weak.upgrade())
918 .map(|context| context.last_instant.elapsed().as_millis())
919 }
920
921 pub fn elapse_since_last_idle_instant(&self) -> Option<u128> {
922 self.last_idle_instant
923 .lock()
924 .as_ref()
925 .map(|x| x.elapsed().as_millis())
926 }
927
928 pub fn check_relation_name_duplicated(
929 &self,
930 name: ObjectName,
931 stmt_type: StatementType,
932 if_not_exists: bool,
933 ) -> std::result::Result<Either<(), RwPgResponse>, CheckRelationError> {
934 let db_name = &self.database();
935 let catalog_reader = self.env().catalog_reader().read_guard();
936 let (schema_name, relation_name) = {
937 let (schema_name, relation_name) =
938 Binder::resolve_schema_qualified_name(db_name, name)?;
939 let search_path = self.config().search_path();
940 let user_name = &self.user_name();
941 let schema_name = match schema_name {
942 Some(schema_name) => schema_name,
943 None => catalog_reader
944 .first_valid_schema(db_name, &search_path, user_name)?
945 .name(),
946 };
947 (schema_name, relation_name)
948 };
949 match catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &relation_name) {
950 Err(CatalogError::Duplicated(_, name, is_creating)) if if_not_exists => {
951 let is_creating_str = if is_creating {
952 " but still creating"
953 } else {
954 ""
955 };
956 Ok(Either::Right(
957 PgResponse::builder(stmt_type)
958 .notice(format!(
959 "relation \"{}\" already exists{}, skipping",
960 name, is_creating_str
961 ))
962 .into(),
963 ))
964 }
965 Err(e) => Err(e.into()),
966 Ok(_) => Ok(Either::Left(())),
967 }
968 }
969
970 pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> {
971 let db_name = &self.database();
972 let catalog_reader = self.env().catalog_reader().read_guard();
973 let (schema_name, secret_name) = {
974 let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
975 let search_path = self.config().search_path();
976 let user_name = &self.user_name();
977 let schema_name = match schema_name {
978 Some(schema_name) => schema_name,
979 None => catalog_reader
980 .first_valid_schema(db_name, &search_path, user_name)?
981 .name(),
982 };
983 (schema_name, secret_name)
984 };
985 catalog_reader
986 .check_secret_name_duplicated(db_name, &schema_name, &secret_name)
987 .map_err(RwError::from)
988 }
989
990 pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> {
991 let db_name = &self.database();
992 let catalog_reader = self.env().catalog_reader().read_guard();
993 let (schema_name, connection_name) = {
994 let (schema_name, connection_name) =
995 Binder::resolve_schema_qualified_name(db_name, name)?;
996 let search_path = self.config().search_path();
997 let user_name = &self.user_name();
998 let schema_name = match schema_name {
999 Some(schema_name) => schema_name,
1000 None => catalog_reader
1001 .first_valid_schema(db_name, &search_path, user_name)?
1002 .name(),
1003 };
1004 (schema_name, connection_name)
1005 };
1006 catalog_reader
1007 .check_connection_name_duplicated(db_name, &schema_name, &connection_name)
1008 .map_err(RwError::from)
1009 }
1010
1011 pub fn check_function_name_duplicated(
1012 &self,
1013 stmt_type: StatementType,
1014 name: ObjectName,
1015 arg_types: &[DataType],
1016 if_not_exists: bool,
1017 ) -> Result<Either<(), RwPgResponse>> {
1018 let db_name = &self.database();
1019 let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
1020 let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?;
1021
1022 let catalog_reader = self.env().catalog_reader().read_guard();
1023 if catalog_reader
1024 .get_schema_by_id(&database_id, &schema_id)?
1025 .get_function_by_name_args(&function_name, arg_types)
1026 .is_some()
1027 {
1028 let full_name = format!(
1029 "{function_name}({})",
1030 arg_types.iter().map(|t| t.to_string()).join(",")
1031 );
1032 if if_not_exists {
1033 Ok(Either::Right(
1034 PgResponse::builder(stmt_type)
1035 .notice(format!(
1036 "function \"{}\" already exists, skipping",
1037 full_name
1038 ))
1039 .into(),
1040 ))
1041 } else {
1042 Err(CatalogError::duplicated("function", full_name).into())
1043 }
1044 } else {
1045 Ok(Either::Left(()))
1046 }
1047 }
1048
1049 pub fn get_database_and_schema_id_for_create(
1051 &self,
1052 schema_name: Option<String>,
1053 ) -> Result<(DatabaseId, SchemaId)> {
1054 let db_name = &self.database();
1055
1056 let search_path = self.config().search_path();
1057 let user_name = &self.user_name();
1058
1059 let catalog_reader = self.env().catalog_reader().read_guard();
1060 let schema = match schema_name {
1061 Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?,
1062 None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?,
1063 };
1064
1065 check_schema_writable(&schema.name())?;
1066 self.check_privileges(&[ObjectCheckItem::new(
1067 schema.owner(),
1068 AclMode::Create,
1069 Object::SchemaId(schema.id()),
1070 )])?;
1071
1072 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1073 Ok((db_id, schema.id()))
1074 }
1075
1076 pub fn get_connection_by_name(
1077 &self,
1078 schema_name: Option<String>,
1079 connection_name: &str,
1080 ) -> Result<Arc<ConnectionCatalog>> {
1081 let db_name = &self.database();
1082 let search_path = self.config().search_path();
1083 let user_name = &self.user_name();
1084
1085 let catalog_reader = self.env().catalog_reader().read_guard();
1086 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1087 let (connection, _) =
1088 catalog_reader.get_connection_by_name(db_name, schema_path, connection_name)?;
1089
1090 self.check_privileges(&[ObjectCheckItem::new(
1091 connection.owner(),
1092 AclMode::Usage,
1093 Object::ConnectionId(connection.id),
1094 )])?;
1095
1096 Ok(connection.clone())
1097 }
1098
1099 pub fn get_subscription_by_schema_id_name(
1100 &self,
1101 schema_id: SchemaId,
1102 subscription_name: &str,
1103 ) -> Result<Arc<SubscriptionCatalog>> {
1104 let db_name = &self.database();
1105
1106 let catalog_reader = self.env().catalog_reader().read_guard();
1107 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1108 let schema = catalog_reader.get_schema_by_id(&db_id, &schema_id)?;
1109 let subscription = schema
1110 .get_subscription_by_name(subscription_name)
1111 .ok_or_else(|| {
1112 RwError::from(ErrorCode::ItemNotFound(format!(
1113 "subscription {} not found",
1114 subscription_name
1115 )))
1116 })?;
1117 Ok(subscription.clone())
1118 }
1119
1120 pub fn get_subscription_by_name(
1121 &self,
1122 schema_name: Option<String>,
1123 subscription_name: &str,
1124 ) -> Result<Arc<SubscriptionCatalog>> {
1125 let db_name = &self.database();
1126 let search_path = self.config().search_path();
1127 let user_name = &self.user_name();
1128
1129 let catalog_reader = self.env().catalog_reader().read_guard();
1130 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1131 let (subscription, _) =
1132 catalog_reader.get_subscription_by_name(db_name, schema_path, subscription_name)?;
1133 Ok(subscription.clone())
1134 }
1135
1136 pub fn get_table_by_id(&self, table_id: &TableId) -> Result<Arc<TableCatalog>> {
1137 let catalog_reader = self.env().catalog_reader().read_guard();
1138 Ok(catalog_reader.get_any_table_by_id(table_id)?.clone())
1139 }
1140
1141 pub fn get_table_by_name(
1142 &self,
1143 table_name: &str,
1144 db_id: u32,
1145 schema_id: u32,
1146 ) -> Result<Arc<TableCatalog>> {
1147 let catalog_reader = self.env().catalog_reader().read_guard();
1148 let table = catalog_reader
1149 .get_schema_by_id(&DatabaseId::from(db_id), &SchemaId::from(schema_id))?
1150 .get_created_table_by_name(table_name)
1151 .ok_or_else(|| {
1152 Error::new(
1153 ErrorKind::InvalidInput,
1154 format!("table \"{}\" does not exist", table_name),
1155 )
1156 })?;
1157
1158 self.check_privileges(&[ObjectCheckItem::new(
1159 table.owner(),
1160 AclMode::Select,
1161 Object::TableId(table.id.table_id()),
1162 )])?;
1163
1164 Ok(table.clone())
1165 }
1166
1167 pub fn get_secret_by_name(
1168 &self,
1169 schema_name: Option<String>,
1170 secret_name: &str,
1171 ) -> Result<Arc<SecretCatalog>> {
1172 let db_name = &self.database();
1173 let search_path = self.config().search_path();
1174 let user_name = &self.user_name();
1175
1176 let catalog_reader = self.env().catalog_reader().read_guard();
1177 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1178 let (secret, _) = catalog_reader.get_secret_by_name(db_name, schema_path, secret_name)?;
1179
1180 self.check_privileges(&[ObjectCheckItem::new(
1181 secret.owner(),
1182 AclMode::Create,
1183 Object::SecretId(secret.id.secret_id()),
1184 )])?;
1185
1186 Ok(secret.clone())
1187 }
1188
1189 pub fn list_change_log_epochs(
1190 &self,
1191 table_id: u32,
1192 min_epoch: u64,
1193 max_count: u32,
1194 ) -> Result<Vec<u64>> {
1195 Ok(self
1196 .env
1197 .hummock_snapshot_manager()
1198 .acquire()
1199 .list_change_log_epochs(table_id, min_epoch, max_count))
1200 }
1201
1202 pub fn clear_cancel_query_flag(&self) {
1203 let mut flag = self.current_query_cancel_flag.lock();
1204 *flag = None;
1205 }
1206
1207 pub fn reset_cancel_query_flag(&self) -> ShutdownToken {
1208 let mut flag = self.current_query_cancel_flag.lock();
1209 let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
1210 *flag = Some(shutdown_tx);
1211 shutdown_rx
1212 }
1213
1214 pub fn cancel_current_query(&self) {
1215 let mut flag_guard = self.current_query_cancel_flag.lock();
1216 if let Some(sender) = flag_guard.take() {
1217 info!("Trying to cancel query in local mode.");
1218 sender.cancel();
1220 info!("Cancel query request sent.");
1221 } else {
1222 info!("Trying to cancel query in distributed mode.");
1223 self.env.query_manager().cancel_queries_in_session(self.id)
1224 }
1225 }
1226
1227 pub fn cancel_current_creating_job(&self) {
1228 self.env.creating_streaming_job_tracker.abort_jobs(self.id);
1229 }
1230
1231 pub async fn run_statement(
1234 self: Arc<Self>,
1235 sql: Arc<str>,
1236 formats: Vec<Format>,
1237 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1238 let mut stmts = Parser::parse_sql(&sql)?;
1240 if stmts.is_empty() {
1241 return Ok(PgResponse::empty_result(
1242 pgwire::pg_response::StatementType::EMPTY,
1243 ));
1244 }
1245 if stmts.len() > 1 {
1246 return Ok(
1247 PgResponse::builder(pgwire::pg_response::StatementType::EMPTY)
1248 .notice("cannot insert multiple commands into statement")
1249 .into(),
1250 );
1251 }
1252 let stmt = stmts.swap_remove(0);
1253 let rsp = handle(self, stmt, sql.clone(), formats).await?;
1254 Ok(rsp)
1255 }
1256
1257 pub fn notice_to_user(&self, str: impl Into<String>) {
1258 let notice = str.into();
1259 tracing::trace!(notice, "notice to user");
1260 self.notice_tx
1261 .send(notice)
1262 .expect("notice channel should not be closed");
1263 }
1264
1265 pub fn is_barrier_read(&self) -> bool {
1266 match self.config().visibility_mode() {
1267 VisibilityMode::Default => self.env.batch_config.enable_barrier_read,
1268 VisibilityMode::All => true,
1269 VisibilityMode::Checkpoint => false,
1270 }
1271 }
1272
1273 pub fn statement_timeout(&self) -> Duration {
1274 if self.config().statement_timeout() == 0 {
1275 Duration::from_secs(self.env.batch_config.statement_timeout_in_sec as u64)
1276 } else {
1277 Duration::from_secs(self.config().statement_timeout() as u64)
1278 }
1279 }
1280
1281 pub fn create_temporary_source(&self, source: SourceCatalog) {
1282 self.temporary_source_manager
1283 .lock()
1284 .create_source(source.name.clone(), source);
1285 }
1286
1287 pub fn get_temporary_source(&self, name: &str) -> Option<SourceCatalog> {
1288 self.temporary_source_manager
1289 .lock()
1290 .get_source(name)
1291 .cloned()
1292 }
1293
1294 pub fn drop_temporary_source(&self, name: &str) {
1295 self.temporary_source_manager.lock().drop_source(name);
1296 }
1297
1298 pub fn temporary_source_manager(&self) -> TemporarySourceManager {
1299 self.temporary_source_manager.lock().clone()
1300 }
1301
1302 pub async fn check_cluster_limits(&self) -> Result<()> {
1303 if self.config().bypass_cluster_limits() {
1304 return Ok(());
1305 }
1306
1307 let gen_message = |ActorCountPerParallelism {
1308 worker_id_to_actor_count,
1309 hard_limit,
1310 soft_limit,
1311 }: ActorCountPerParallelism,
1312 exceed_hard_limit: bool|
1313 -> String {
1314 let (limit_type, action) = if exceed_hard_limit {
1315 ("critical", "Scale the cluster immediately to proceed.")
1316 } else {
1317 (
1318 "recommended",
1319 "Consider scaling the cluster for optimal performance.",
1320 )
1321 };
1322 format!(
1323 r#"Actor count per parallelism exceeds the {limit_type} limit.
1324
1325Depending on your workload, this may overload the cluster and cause performance/stability issues. {action}
1326
1327HINT:
1328- For best practices on managing streaming jobs: https://docs.risingwave.com/operate/manage-a-large-number-of-streaming-jobs
1329- To bypass the check (if the cluster load is acceptable): `[ALTER SYSTEM] SET bypass_cluster_limits TO true`.
1330 See https://docs.risingwave.com/operate/view-configure-runtime-parameters#how-to-configure-runtime-parameters
1331- Contact us via slack or https://risingwave.com/contact-us/ for further enquiry.
1332
1333DETAILS:
1334- hard limit: {hard_limit}
1335- soft limit: {soft_limit}
1336- worker_id_to_actor_count: {worker_id_to_actor_count:?}"#,
1337 )
1338 };
1339
1340 let limits = self.env().meta_client().get_cluster_limits().await?;
1341 for limit in limits {
1342 match limit {
1343 cluster_limit::ClusterLimit::ActorCount(l) => {
1344 if l.exceed_hard_limit() {
1345 return Err(RwError::from(ErrorCode::ProtocolError(gen_message(
1346 l, true,
1347 ))));
1348 } else if l.exceed_soft_limit() {
1349 self.notice_to_user(gen_message(l, false));
1350 }
1351 }
1352 }
1353 }
1354 Ok(())
1355 }
1356}
1357
/// Process-wide handle to the frontend's session manager.
///
/// Being a `OnceLock`, it can be set at most once. Where it is initialized is
/// not visible here.
pub static SESSION_MANAGER: std::sync::OnceLock<Arc<SessionManagerImpl>> =
    std::sync::OnceLock::new();
1360
/// Creates pgwire sessions ([`SessionImpl`]) for this frontend node and tracks
/// them in the environment's sessions map.
pub struct SessionManagerImpl {
    /// Frontend environment shared by every session this manager creates.
    env: FrontendEnv,
    /// Task handles returned by `FrontendEnv::init`. Never read; presumably
    /// kept only so the tasks live as long as the manager (TODO confirm).
    _join_handles: Vec<JoinHandle<()>>,
    /// Shutdown senders returned by `FrontendEnv::init`. Never read;
    /// presumably dropping them tells the background tasks to stop (TODO confirm).
    _shutdown_senders: Vec<Sender<()>>,
    /// Counter from which `connect_inner` derives each new session's id.
    number: AtomicI32,
}
1367
1368impl SessionManager for SessionManagerImpl {
1369 type Session = SessionImpl;
1370
1371 fn create_dummy_session(
1372 &self,
1373 database_id: u32,
1374 user_id: u32,
1375 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1376 let dummy_addr = Address::Tcp(SocketAddr::new(
1377 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
1378 5691, ));
1380 let user_reader = self.env.user_info_reader();
1381 let reader = user_reader.read_guard();
1382 if let Some(user_name) = reader.get_user_name_by_id(user_id) {
1383 self.connect_inner(database_id, user_name.as_str(), Arc::new(dummy_addr))
1384 } else {
1385 Err(Box::new(Error::new(
1386 ErrorKind::InvalidInput,
1387 format!("Role id {} does not exist", user_id),
1388 )))
1389 }
1390 }
1391
1392 fn connect(
1393 &self,
1394 database: &str,
1395 user_name: &str,
1396 peer_addr: AddressRef,
1397 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1398 let catalog_reader = self.env.catalog_reader();
1399 let reader = catalog_reader.read_guard();
1400 let database_id = reader
1401 .get_database_by_name(database)
1402 .map_err(|_| {
1403 Box::new(Error::new(
1404 ErrorKind::InvalidInput,
1405 format!("database \"{}\" does not exist", database),
1406 ))
1407 })?
1408 .id();
1409
1410 self.connect_inner(database_id, user_name, peer_addr)
1411 }
1412
1413 fn cancel_queries_in_session(&self, session_id: SessionId) {
1415 self.env.cancel_queries_in_session(session_id);
1416 }
1417
1418 fn cancel_creating_jobs_in_session(&self, session_id: SessionId) {
1419 self.env.cancel_creating_jobs_in_session(session_id);
1420 }
1421
1422 fn end_session(&self, session: &Self::Session) {
1423 self.delete_session(&session.session_id());
1424 }
1425
1426 async fn shutdown(&self) {
1427 self.env.sessions_map().write().clear();
1429 self.env.meta_client().try_unregister().await;
1431 }
1432}
1433
1434impl SessionManagerImpl {
1435 pub async fn new(opts: FrontendOpts) -> Result<Self> {
1436 let (env, join_handles, shutdown_senders) = FrontendEnv::init(opts).await?;
1438 Ok(Self {
1439 env,
1440 _join_handles: join_handles,
1441 _shutdown_senders: shutdown_senders,
1442 number: AtomicI32::new(0),
1443 })
1444 }
1445
1446 pub(crate) fn env(&self) -> &FrontendEnv {
1447 &self.env
1448 }
1449
1450 fn insert_session(&self, session: Arc<SessionImpl>) {
1451 let active_sessions = {
1452 let mut write_guard = self.env.sessions_map.write();
1453 write_guard.insert(session.id(), session);
1454 write_guard.len()
1455 };
1456 self.env
1457 .frontend_metrics
1458 .active_sessions
1459 .set(active_sessions as i64);
1460 }
1461
1462 fn delete_session(&self, session_id: &SessionId) {
1463 let active_sessions = {
1464 let mut write_guard = self.env.sessions_map.write();
1465 write_guard.remove(session_id);
1466 write_guard.len()
1467 };
1468 self.env
1469 .frontend_metrics
1470 .active_sessions
1471 .set(active_sessions as i64);
1472 }
1473
1474 fn connect_inner(
1475 &self,
1476 database_id: u32,
1477 user_name: &str,
1478 peer_addr: AddressRef,
1479 ) -> std::result::Result<Arc<SessionImpl>, BoxedError> {
1480 let catalog_reader = self.env.catalog_reader();
1481 let reader = catalog_reader.read_guard();
1482 let database_name = reader
1483 .get_database_by_id(&database_id)
1484 .map_err(|_| {
1485 Box::new(Error::new(
1486 ErrorKind::InvalidInput,
1487 format!("database \"{}\" does not exist", database_id),
1488 ))
1489 })?
1490 .name();
1491
1492 let user_reader = self.env.user_info_reader();
1493 let reader = user_reader.read_guard();
1494 if let Some(user) = reader.get_user_by_name(user_name) {
1495 if !user.can_login {
1496 return Err(Box::new(Error::new(
1497 ErrorKind::InvalidInput,
1498 format!("User {} is not allowed to login", user_name),
1499 )));
1500 }
1501 let has_privilege =
1502 user.has_privilege(&Object::DatabaseId(database_id), AclMode::Connect);
1503 if !user.is_super && !has_privilege {
1504 return Err(Box::new(Error::new(
1505 ErrorKind::PermissionDenied,
1506 "User does not have CONNECT privilege.",
1507 )));
1508 }
1509 let user_authenticator = match &user.auth_info {
1510 None => UserAuthenticator::None,
1511 Some(auth_info) => {
1512 if auth_info.encryption_type == EncryptionType::Plaintext as i32 {
1513 UserAuthenticator::ClearText(auth_info.encrypted_value.clone())
1514 } else if auth_info.encryption_type == EncryptionType::Md5 as i32 {
1515 let mut salt = [0; 4];
1516 let mut rng = rand::rng();
1517 rng.fill_bytes(&mut salt);
1518 UserAuthenticator::Md5WithSalt {
1519 encrypted_password: md5_hash_with_salt(
1520 &auth_info.encrypted_value,
1521 &salt,
1522 ),
1523 salt,
1524 }
1525 } else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
1526 UserAuthenticator::OAuth(auth_info.metadata.clone())
1527 } else {
1528 return Err(Box::new(Error::new(
1529 ErrorKind::Unsupported,
1530 format!("Unsupported auth type: {}", auth_info.encryption_type),
1531 )));
1532 }
1533 }
1534 };
1535
1536 let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
1538 let id = (secret_key, secret_key);
1540 let session_config = self.env.session_params_snapshot();
1542
1543 let session_impl: Arc<SessionImpl> = SessionImpl::new(
1544 self.env.clone(),
1545 AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id),
1546 user_authenticator,
1547 id,
1548 peer_addr,
1549 session_config,
1550 )
1551 .into();
1552 self.insert_session(session_impl.clone());
1553
1554 Ok(session_impl)
1555 } else {
1556 Err(Box::new(Error::new(
1557 ErrorKind::InvalidInput,
1558 format!("Role {} does not exist", user_name),
1559 )))
1560 }
1561 }
1562}
1563
1564impl Session for SessionImpl {
1565 type Portal = Portal;
1566 type PreparedStatement = PrepareStatement;
1567 type ValuesStream = PgResponseStream;
1568
1569 async fn run_one_query(
1572 self: Arc<Self>,
1573 stmt: Statement,
1574 format: Format,
1575 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1576 let string = stmt.to_string();
1577 let sql_str = string.as_str();
1578 let sql: Arc<str> = Arc::from(sql_str);
1579 drop(string);
1581 let rsp = handle(self, stmt, sql, vec![format]).await?;
1582 Ok(rsp)
1583 }
1584
1585 fn user_authenticator(&self) -> &UserAuthenticator {
1586 &self.user_authenticator
1587 }
1588
1589 fn id(&self) -> SessionId {
1590 self.id
1591 }
1592
1593 async fn parse(
1594 self: Arc<Self>,
1595 statement: Option<Statement>,
1596 params_types: Vec<Option<DataType>>,
1597 ) -> std::result::Result<PrepareStatement, BoxedError> {
1598 Ok(if let Some(statement) = statement {
1599 handle_parse(self, statement, params_types).await?
1600 } else {
1601 PrepareStatement::Empty
1602 })
1603 }
1604
1605 fn bind(
1606 self: Arc<Self>,
1607 prepare_statement: PrepareStatement,
1608 params: Vec<Option<Bytes>>,
1609 param_formats: Vec<Format>,
1610 result_formats: Vec<Format>,
1611 ) -> std::result::Result<Portal, BoxedError> {
1612 Ok(handle_bind(
1613 prepare_statement,
1614 params,
1615 param_formats,
1616 result_formats,
1617 )?)
1618 }
1619
1620 async fn execute(
1621 self: Arc<Self>,
1622 portal: Portal,
1623 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1624 let rsp = handle_execute(self, portal).await?;
1625 Ok(rsp)
1626 }
1627
1628 fn describe_statement(
1629 self: Arc<Self>,
1630 prepare_statement: PrepareStatement,
1631 ) -> std::result::Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
1632 Ok(match prepare_statement {
1633 PrepareStatement::Empty => (vec![], vec![]),
1634 PrepareStatement::Prepared(prepare_statement) => (
1635 prepare_statement.bound_result.param_types,
1636 infer(
1637 Some(prepare_statement.bound_result.bound),
1638 prepare_statement.statement,
1639 )?,
1640 ),
1641 PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?),
1642 })
1643 }
1644
1645 fn describe_portal(
1646 self: Arc<Self>,
1647 portal: Portal,
1648 ) -> std::result::Result<Vec<PgFieldDescriptor>, BoxedError> {
1649 match portal {
1650 Portal::Empty => Ok(vec![]),
1651 Portal::Portal(portal) => {
1652 let mut columns = infer(Some(portal.bound_result.bound), portal.statement)?;
1653 let formats = FormatIterator::new(&portal.result_formats, columns.len())?;
1654 columns.iter_mut().zip_eq_fast(formats).for_each(|(c, f)| {
1655 if f == Format::Binary {
1656 c.set_to_binary()
1657 }
1658 });
1659 Ok(columns)
1660 }
1661 Portal::PureStatement(statement) => Ok(infer(None, statement)?),
1662 }
1663 }
1664
1665 fn set_config(&self, key: &str, value: String) -> std::result::Result<String, BoxedError> {
1666 Self::set_config(self, key, value).map_err(Into::into)
1667 }
1668
1669 async fn next_notice(self: &Arc<Self>) -> String {
1670 std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
1671 .await
1672 .expect("notice channel should not be closed")
1673 }
1674
1675 fn transaction_status(&self) -> TransactionStatus {
1676 match &*self.txn.lock() {
1677 transaction::State::Initial | transaction::State::Implicit(_) => {
1678 TransactionStatus::Idle
1679 }
1680 transaction::State::Explicit(_) => TransactionStatus::InTransaction,
1681 }
1683 }
1684
1685 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
1687 let exec_context = Arc::new(ExecContext {
1688 running_sql: sql,
1689 last_instant: Instant::now(),
1690 last_idle_instant: self.last_idle_instant.clone(),
1691 });
1692 *self.exec_context.lock() = Some(Arc::downgrade(&exec_context));
1693 *self.last_idle_instant.lock() = None;
1695 ExecContextGuard::new(exec_context)
1696 }
1697
1698 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
1701 if matches!(self.transaction_status(), TransactionStatus::InTransaction) {
1703 let idle_in_transaction_session_timeout =
1704 self.config().idle_in_transaction_session_timeout() as u128;
1705 if idle_in_transaction_session_timeout != 0 {
1707 let guard = self.exec_context.lock();
1709 if guard.as_ref().and_then(|weak| weak.upgrade()).is_none() {
1711 if let Some(elapse_since_last_idle_instant) =
1713 self.elapse_since_last_idle_instant()
1714 {
1715 if elapse_since_last_idle_instant > idle_in_transaction_session_timeout {
1716 return Err(PsqlError::IdleInTxnTimeout);
1717 }
1718 }
1719 }
1720 }
1721 }
1722 Ok(())
1723 }
1724}
1725
1726fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
1728 match stmt {
1729 Statement::Query(_)
1730 | Statement::Insert { .. }
1731 | Statement::Delete { .. }
1732 | Statement::Update { .. }
1733 | Statement::FetchCursor { .. } => Ok(bound
1734 .unwrap()
1735 .output_fields()
1736 .iter()
1737 .map(to_pg_field)
1738 .collect()),
1739 Statement::ShowObjects {
1740 object: show_object,
1741 ..
1742 } => Ok(infer_show_object(&show_object)),
1743 Statement::ShowCreateObject { .. } => Ok(infer_show_create_object()),
1744 Statement::ShowTransactionIsolationLevel => {
1745 let name = "transaction_isolation";
1746 Ok(infer_show_variable(name))
1747 }
1748 Statement::ShowVariable { variable } => {
1749 let name = &variable[0].real_value().to_lowercase();
1750 Ok(infer_show_variable(name))
1751 }
1752 Statement::Describe { name: _ } => Ok(infer_describe()),
1753 Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new(
1754 "QUERY PLAN".to_owned(),
1755 DataType::Varchar.to_oid(),
1756 DataType::Varchar.type_len(),
1757 )]),
1758 _ => Ok(vec![]),
1759 }
1760}