use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::io::{Error, ErrorKind};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};

use anyhow::anyhow;
use bytes::Bytes;
use either::Either;
use itertools::Itertools;
use parking_lot::{Mutex, RwLock, RwLockReadGuard};
use pgwire::error::{PsqlError, PsqlResult};
use pgwire::net::{Address, AddressRef};
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_message::TransactionStatus;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::pg_server::{
    BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
    UserAuthenticator,
};
use pgwire::types::{Format, FormatIterator};
use rand::RngCore;
use risingwave_batch::monitor::{BatchSpillMetrics, GLOBAL_BATCH_SPILL_METRICS};
use risingwave_batch::spill::spill_op::SpillOp;
use risingwave_batch::task::{ShutdownSender, ShutdownToken};
use risingwave_batch::worker_manager::worker_node_manager::{
    WorkerNodeManager, WorkerNodeManagerRef,
};
use risingwave_common::acl::AclMode;
#[cfg(test)]
use risingwave_common::catalog::{
    DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
};
use risingwave_common::config::{
    BatchConfig, FrontendConfig, MetaConfig, MetricLevel, StreamingConfig, UdfConfig, load_config,
};
use risingwave_common::memory::MemoryContext;
use risingwave_common::secret::LocalSecretManager;
use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode};
use risingwave_common::system_param::local_manager::{
    LocalSystemParamsManager, LocalSystemParamsManagerRef,
};
use risingwave_common::telemetry::manager::TelemetryManager;
use risingwave_common::telemetry::telemetry_env_enabled;
use risingwave_common::types::DataType;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::cluster_limit;
use risingwave_common::util::cluster_limit::ActorCountPerParallelism;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::pretty_bytes::convert;
use risingwave_common::util::runtime::BackgroundShutdownRuntime;
use risingwave_common::{GIT_SHA, RW_VERSION};
use risingwave_common_heap_profiling::HeapProfiler;
use risingwave_common_service::{MetricsManager, ObserverManager};
use risingwave_connector::source::monitor::{GLOBAL_SOURCE_METRICS, SourceMetrics};
use risingwave_pb::common::WorkerType;
use risingwave_pb::common::worker_node::Property as AddWorkerNodeProperty;
use risingwave_pb::frontend_service::frontend_service_server::FrontendServiceServer;
use risingwave_pb::health::health_server::HealthServer;
use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::grant_privilege::Object;
use risingwave_rpc_client::{
    ComputeClientPool, ComputeClientPoolRef, FrontendClientPool, FrontendClientPoolRef, MetaClient,
};
use risingwave_sqlparser::ast::{ObjectName, Statement};
use risingwave_sqlparser::parser::Parser;
use thiserror::Error;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot::Sender;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::info;
use tracing::log::error;

use self::cursor_manager::CursorManager;
use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError};
use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl};
use crate::catalog::connection_catalog::ConnectionCatalog;
use crate::catalog::root_catalog::{Catalog, SchemaPath};
use crate::catalog::secret_catalog::SecretCatalog;
use crate::catalog::source_catalog::SourceCatalog;
use crate::catalog::subscription_catalog::SubscriptionCatalog;
use crate::catalog::{
    CatalogError, DatabaseId, OwnedByUserCatalog, SchemaId, TableId, check_schema_writable,
};
use crate::error::{ErrorCode, Result, RwError};
use crate::handler::describe::infer_describe;
use crate::handler::extended_handle::{
    Portal, PrepareStatement, handle_bind, handle_execute, handle_parse,
};
use crate::handler::privilege::ObjectCheckItem;
use crate::handler::show::{infer_show_create_object, infer_show_object};
use crate::handler::util::to_pg_field;
use crate::handler::variable::infer_show_variable;
use crate::handler::{RwPgResponse, handle};
use crate::health_service::HealthServiceImpl;
use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl};
use crate::monitor::{CursorMetrics, FrontendMetrics, GLOBAL_FRONTEND_METRICS};
use crate::observer::FrontendObserverNode;
use crate::rpc::FrontendServiceImpl;
use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef};
use crate::scheduler::{
    DistributedQueryMetrics, GLOBAL_DISTRIBUTED_QUERY_METRICS, HummockSnapshotManager,
    HummockSnapshotManagerRef, QueryManager,
};
use crate::telemetry::FrontendTelemetryCreator;
use crate::user::UserId;
use crate::user::user_authentication::md5_hash_with_salt;
use crate::user::user_manager::UserInfoManager;
use crate::user::user_service::{UserInfoReader, UserInfoWriter, UserInfoWriterImpl};
use crate::{FrontendOpts, PgResponseStream, TableCatalog};

pub(crate) mod current;
pub(crate) mod cursor_manager;
pub(crate) mod transaction;

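/// The global environment for the frontend node: shared handles to the meta
/// client, catalog and user-info readers/writers, the query manager, metrics,
/// configs, and the map of active sessions.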
#[derive(Clone)]
pub(crate) struct FrontendEnv {
    meta_client: Arc<dyn FrontendMetaClient>,
    catalog_writer: Arc<dyn CatalogWriter>,
    catalog_reader: CatalogReader,
    user_info_writer: Arc<dyn UserInfoWriter>,
    user_info_reader: UserInfoReader,
    worker_node_manager: WorkerNodeManagerRef,
    query_manager: QueryManager,
    hummock_snapshot_manager: HummockSnapshotManagerRef,
    system_params_manager: LocalSystemParamsManagerRef,
    session_params: Arc<RwLock<SessionConfig>>,

    server_addr: HostAddr,
    client_pool: ComputeClientPoolRef,
    frontend_client_pool: FrontendClientPoolRef,

    sessions_map: SessionMapRef,

    pub frontend_metrics: Arc<FrontendMetrics>,

    pub cursor_metrics: Arc<CursorMetrics>,

    source_metrics: Arc<SourceMetrics>,

    spill_metrics: Arc<BatchSpillMetrics>,

    batch_config: BatchConfig,
    frontend_config: FrontendConfig,
    #[expect(dead_code)]
    meta_config: MetaConfig,
    streaming_config: StreamingConfig,
    udf_config: UdfConfig,

    creating_streaming_job_tracker: StreamingJobTrackerRef,

    compute_runtime: Arc<BackgroundShutdownRuntime>,

    mem_context: MemoryContext,

    serverless_backfill_controller_addr: String,
}

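/// Shared map from the session id `(process_id, secret_key)` to the live
/// session, used for query cancellation, idle-transaction checks, and metrics.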
pub type SessionMapRef = Arc<RwLock<HashMap<(i32, i32), Arc<SessionImpl>>>>;

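/// Proportion of the frontend's total memory budget reserved for batch query
/// execution; the remainder is left for the node's other components.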
const FRONTEND_BATCH_MEMORY_PROPORTION: f64 = 0.5;

impl FrontendEnv {
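    /// Builds a mock environment backed by in-memory catalogs and a mock meta
    /// client, for use in tests only.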
    pub fn mock() -> Self {
        use crate::test_utils::{MockCatalogWriter, MockFrontendMetaClient, MockUserInfoWriter};

        let catalog = Arc::new(RwLock::new(Catalog::default()));
        let meta_client = Arc::new(MockFrontendMetaClient {});
        let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone()));
        let catalog_writer = Arc::new(MockCatalogWriter::new(
            catalog.clone(),
            hummock_snapshot_manager.clone(),
        ));
        let catalog_reader = CatalogReader::new(catalog);
        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
        let user_info_writer = Arc::new(MockUserInfoWriter::new(user_info_manager.clone()));
        let user_info_reader = UserInfoReader::new(user_info_manager);
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let system_params_manager = Arc::new(LocalSystemParamsManager::for_test());
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let frontend_client_pool = Arc::new(FrontendClientPool::for_test());
        let query_manager = QueryManager::new(
            worker_node_manager.clone(),
            compute_client_pool,
            catalog_reader.clone(),
            Arc::new(DistributedQueryMetrics::for_test()),
            None,
            None,
        );
        let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap();
        let client_pool = Arc::new(ComputeClientPool::for_test());
        let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone());
        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(
            Builder::new_multi_thread()
                .worker_threads(
                    load_config("", FrontendOpts::default())
                        .batch
                        .frontend_compute_runtime_worker_threads,
                )
                .thread_name("rw-batch-local")
                .enable_all()
                .build()
                .unwrap(),
        ));
        let sessions_map = Arc::new(RwLock::new(HashMap::new()));
        Self {
            meta_client,
            catalog_writer,
            catalog_reader,
            user_info_writer,
            user_info_reader,
            worker_node_manager,
            query_manager,
            hummock_snapshot_manager,
            system_params_manager,
            session_params: Default::default(),
            server_addr,
            client_pool,
            frontend_client_pool,
            sessions_map: sessions_map.clone(),
            frontend_metrics: Arc::new(FrontendMetrics::for_test()),
            cursor_metrics: Arc::new(CursorMetrics::for_test()),
            batch_config: BatchConfig::default(),
            frontend_config: FrontendConfig::default(),
            meta_config: MetaConfig::default(),
            streaming_config: StreamingConfig::default(),
            udf_config: UdfConfig::default(),
            source_metrics: Arc::new(SourceMetrics::default()),
            spill_metrics: BatchSpillMetrics::for_test(),
            creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
            compute_runtime,
            mem_context: MemoryContext::none(),
            serverless_backfill_controller_addr: Default::default(),
        }
    }

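    /// Initializes the real environment: loads the config, registers this
    /// frontend with the meta service, and spawns the background tasks
    /// (heartbeat, observer, telemetry, RPC server, idle-transaction checker).
    /// Returns the env together with the task join handles and their shutdown
    /// senders.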
    pub async fn init(opts: FrontendOpts) -> Result<(Self, Vec<JoinHandle<()>>, Vec<Sender<()>>)> {
        let config = load_config(&opts.config_path, &opts);
        info!("Starting frontend node");
        info!("> config: {:?}", config);
        info!(
            "> debug assertions: {}",
            if cfg!(debug_assertions) { "on" } else { "off" }
        );
        info!("> version: {} ({})", RW_VERSION, GIT_SHA);

        let frontend_address: HostAddr = opts
            .advertise_addr
            .as_ref()
            .unwrap_or_else(|| {
                tracing::warn!("advertise addr is not specified, defaulting to listen_addr");
                &opts.listen_addr
            })
            .parse()
            .unwrap();
        info!("advertise addr is {}", frontend_address);

        let rpc_addr: HostAddr = opts.frontend_rpc_listener_addr.parse().unwrap();
        let internal_rpc_host_addr = HostAddr {
            host: frontend_address.host.clone(),
            port: rpc_addr.port,
        };
        let (meta_client, system_params_reader) = MetaClient::register_new(
            opts.meta_addr,
            WorkerType::Frontend,
            &frontend_address,
            AddWorkerNodeProperty {
                internal_rpc_host_addr: internal_rpc_host_addr.to_string(),
                ..Default::default()
            },
            &config.meta,
        )
        .await;

        let worker_id = meta_client.worker_id();
        info!("Assigned worker node id {}", worker_id);

        let (heartbeat_join_handle, heartbeat_shutdown_sender) = MetaClient::start_heartbeat_loop(
            meta_client.clone(),
            Duration::from_millis(config.server.heartbeat_interval_ms as u64),
        );
        let mut join_handles = vec![heartbeat_join_handle];
        let mut shutdown_senders = vec![heartbeat_shutdown_sender];

        let frontend_meta_client = Arc::new(FrontendMetaClientImpl(meta_client.clone()));
        let hummock_snapshot_manager =
            Arc::new(HummockSnapshotManager::new(frontend_meta_client.clone()));

        let (catalog_updated_tx, catalog_updated_rx) = watch::channel(0);
        let catalog = Arc::new(RwLock::new(Catalog::default()));
        let catalog_writer = Arc::new(CatalogWriterImpl::new(
            meta_client.clone(),
            catalog_updated_rx,
            hummock_snapshot_manager.clone(),
        ));
        let catalog_reader = CatalogReader::new(catalog.clone());

        let worker_node_manager = Arc::new(WorkerNodeManager::new());

        let compute_client_pool = Arc::new(ComputeClientPool::new(
            config.batch_exchange_connection_pool_size(),
            config.batch.developer.compute_client_config.clone(),
        ));
        let frontend_client_pool = Arc::new(FrontendClientPool::new(
            1,
            config.batch.developer.frontend_client_config.clone(),
        ));
        let query_manager = QueryManager::new(
            worker_node_manager.clone(),
            compute_client_pool.clone(),
            catalog_reader.clone(),
            Arc::new(GLOBAL_DISTRIBUTED_QUERY_METRICS.clone()),
            config.batch.distributed_query_limit,
            config.batch.max_batch_queries_per_frontend_node,
        );

        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
        let (user_info_updated_tx, user_info_updated_rx) = watch::channel(0);
        let user_info_reader = UserInfoReader::new(user_info_manager.clone());
        let user_info_writer = Arc::new(UserInfoWriterImpl::new(
            meta_client.clone(),
            user_info_updated_rx,
        ));

        let system_params_manager =
            Arc::new(LocalSystemParamsManager::new(system_params_reader.clone()));

        LocalSecretManager::init(
            opts.temp_secret_file_dir,
            meta_client.cluster_id().to_owned(),
            worker_id,
        );

        let session_params = Arc::new(RwLock::new(SessionConfig::default()));
        let sessions_map: SessionMapRef = Arc::new(RwLock::new(HashMap::new()));
        let cursor_metrics = Arc::new(CursorMetrics::init(sessions_map.clone()));

        let frontend_observer_node = FrontendObserverNode::new(
            worker_node_manager.clone(),
            catalog,
            catalog_updated_tx,
            user_info_manager,
            user_info_updated_tx,
            hummock_snapshot_manager.clone(),
            system_params_manager.clone(),
            session_params.clone(),
            compute_client_pool.clone(),
        );
        let observer_manager =
            ObserverManager::new_with_meta_client(meta_client.clone(), frontend_observer_node)
                .await;
        let observer_join_handle = observer_manager.start().await;
        join_handles.push(observer_join_handle);

        meta_client.activate(&frontend_address).await?;

        let frontend_metrics = Arc::new(GLOBAL_FRONTEND_METRICS.clone());
        let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
        let spill_metrics = Arc::new(GLOBAL_BATCH_SPILL_METRICS.clone());

        if config.server.metrics_level > MetricLevel::Disabled {
            MetricsManager::boot_metrics_service(opts.prometheus_listener_addr.clone());
        }

        let health_srv = HealthServiceImpl::new();
        let frontend_srv = FrontendServiceImpl::new(sessions_map.clone());
        let frontend_rpc_addr = opts.frontend_rpc_listener_addr.parse().unwrap();

        let telemetry_manager = TelemetryManager::new(
            Arc::new(meta_client.clone()),
            Arc::new(FrontendTelemetryCreator::new()),
        );

        if config.server.telemetry_enabled && telemetry_env_enabled() {
            let (join_handle, shutdown_sender) = telemetry_manager.start().await;
            join_handles.push(join_handle);
            shutdown_senders.push(shutdown_sender);
        } else {
            tracing::info!("Telemetry didn't start due to config");
        }

        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(HealthServer::new(health_srv))
                .add_service(FrontendServiceServer::new(frontend_srv))
                .serve(frontend_rpc_addr)
                .await
                .unwrap();
        });
        info!(
            "Health Check RPC Listener is set up on {}",
            opts.frontend_rpc_listener_addr.clone()
        );

        let creating_streaming_job_tracker =
            Arc::new(StreamingJobTracker::new(frontend_meta_client.clone()));

        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(
            Builder::new_multi_thread()
                .worker_threads(config.batch.frontend_compute_runtime_worker_threads)
                .thread_name("rw-batch-local")
                .enable_all()
                .build()
                .unwrap(),
        ));

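        // Background task: every 10 seconds, check all sessions for the
        // idle-in-transaction timeout.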
        let sessions = sessions_map.clone();
        let join_handle = tokio::spawn(async move {
            let mut check_idle_txn_interval =
                tokio::time::interval(core::time::Duration::from_secs(10));
            check_idle_txn_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            check_idle_txn_interval.reset();
            loop {
                check_idle_txn_interval.tick().await;
                sessions.read().values().for_each(|session| {
                    let _ = session.check_idle_in_transaction_timeout();
                })
            }
        });
        join_handles.push(join_handle);

        #[cfg(not(madsim))]
        if config.batch.enable_spill {
            SpillOp::clean_spill_directory()
                .await
                .map_err(|err| anyhow!(err))?;
        }

        let total_memory_bytes = opts.frontend_total_memory_bytes;
        let heap_profiler =
            HeapProfiler::new(total_memory_bytes, config.server.heap_profiling.clone());
        heap_profiler.start();

        let batch_memory_limit = total_memory_bytes as f64 * FRONTEND_BATCH_MEMORY_PROPORTION;
        let mem_context = MemoryContext::root(
            frontend_metrics.batch_total_mem.clone(),
            batch_memory_limit as u64,
        );

        info!(
            "Frontend total_memory: {} batch_memory: {}",
            convert(total_memory_bytes as _),
            convert(batch_memory_limit as _),
        );

        Ok((
            Self {
                catalog_reader,
                catalog_writer,
                user_info_reader,
                user_info_writer,
                worker_node_manager,
                meta_client: frontend_meta_client,
                query_manager,
                hummock_snapshot_manager,
                system_params_manager,
                session_params,
                server_addr: frontend_address,
                client_pool: compute_client_pool,
                frontend_client_pool,
                frontend_metrics,
                cursor_metrics,
                spill_metrics,
                sessions_map,
                batch_config: config.batch,
                frontend_config: config.frontend,
                meta_config: config.meta,
                streaming_config: config.streaming,
                serverless_backfill_controller_addr: opts.serverless_backfill_controller_addr,
                udf_config: config.udf,
                source_metrics,
                creating_streaming_job_tracker,
                compute_runtime,
                mem_context,
            },
            join_handles,
            shutdown_senders,
        ))
    }

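    /// Returns the catalog writer. The [`transaction::WriteGuard`] parameter
    /// acts as a witness that the session is currently in a state where
    /// catalog writes are permitted, which is why this method stays private.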
    fn catalog_writer(&self, _guard: transaction::WriteGuard) -> &dyn CatalogWriter {
        &*self.catalog_writer
    }

    pub fn catalog_reader(&self) -> &CatalogReader {
        &self.catalog_reader
    }

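    /// Returns the user-info writer. Like [`Self::catalog_writer`], the
    /// [`transaction::WriteGuard`] witnesses that writes are currently allowed.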
    fn user_info_writer(&self, _guard: transaction::WriteGuard) -> &dyn UserInfoWriter {
        &*self.user_info_writer
    }

    pub fn user_info_reader(&self) -> &UserInfoReader {
        &self.user_info_reader
    }

    pub fn worker_node_manager_ref(&self) -> WorkerNodeManagerRef {
        self.worker_node_manager.clone()
    }

    pub fn meta_client(&self) -> &dyn FrontendMetaClient {
        &*self.meta_client
    }

    pub fn meta_client_ref(&self) -> Arc<dyn FrontendMetaClient> {
        self.meta_client.clone()
    }

    pub fn query_manager(&self) -> &QueryManager {
        &self.query_manager
    }

    pub fn hummock_snapshot_manager(&self) -> &HummockSnapshotManagerRef {
        &self.hummock_snapshot_manager
    }

    pub fn system_params_manager(&self) -> &LocalSystemParamsManagerRef {
        &self.system_params_manager
    }

    pub fn session_params_snapshot(&self) -> SessionConfig {
        self.session_params.read_recursive().clone()
    }

    pub fn sbc_address(&self) -> &String {
        &self.serverless_backfill_controller_addr
    }

    pub fn server_address(&self) -> &HostAddr {
        &self.server_addr
    }

    pub fn client_pool(&self) -> ComputeClientPoolRef {
        self.client_pool.clone()
    }

    pub fn frontend_client_pool(&self) -> FrontendClientPoolRef {
        self.frontend_client_pool.clone()
    }

    pub fn batch_config(&self) -> &BatchConfig {
        &self.batch_config
    }

    pub fn frontend_config(&self) -> &FrontendConfig {
        &self.frontend_config
    }

    pub fn streaming_config(&self) -> &StreamingConfig {
        &self.streaming_config
    }

    pub fn udf_config(&self) -> &UdfConfig {
        &self.udf_config
    }

    pub fn source_metrics(&self) -> Arc<SourceMetrics> {
        self.source_metrics.clone()
    }

    pub fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
        self.spill_metrics.clone()
    }

    pub fn creating_streaming_job_tracker(&self) -> &StreamingJobTrackerRef {
        &self.creating_streaming_job_tracker
    }

    pub fn sessions_map(&self) -> &SessionMapRef {
        &self.sessions_map
    }

    pub fn compute_runtime(&self) -> Arc<BackgroundShutdownRuntime> {
        self.compute_runtime.clone()
    }

    pub fn cancel_queries_in_session(&self, session_id: SessionId) -> bool {
        cancel_queries_in_session(session_id, self.sessions_map.clone())
    }

    pub fn cancel_creating_jobs_in_session(&self, session_id: SessionId) -> bool {
        cancel_creating_jobs_in_session(session_id, self.sessions_map.clone())
    }

    pub fn mem_context(&self) -> MemoryContext {
        self.mem_context.clone()
    }
}

#[derive(Clone)]
pub struct AuthContext {
    pub database: String,
    pub user_name: String,
    pub user_id: UserId,
}

impl AuthContext {
    pub fn new(database: String, user_name: String, user_id: UserId) -> Self {
        Self {
            database,
            user_name,
            user_id,
        }
    }
}
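
/// A client session on this frontend node. Holds the authentication context,
/// per-session configuration, transaction state, the notice channel, and
/// per-session resources such as cursors and temporary sources.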
pub struct SessionImpl {
    env: FrontendEnv,
    auth_context: Arc<RwLock<AuthContext>>,
    user_authenticator: UserAuthenticator,
    /// Per-session configuration.
    config_map: Arc<RwLock<SessionConfig>>,

    /// Sender used by handlers to emit notices to the client.
    notice_tx: UnboundedSender<String>,
    /// Receiver drained by pgwire to forward notices to the client.
    notice_rx: Mutex<UnboundedReceiver<String>>,

    /// Session id: `(process_id, secret_key)`, as reported to the `SessionManager`.
    id: (i32, i32),

    /// Address of the connected client.
    peer_addr: AddressRef,

    /// Transaction state of this session.
    txn: Arc<Mutex<transaction::State>>,

    /// Shutdown sender of the currently running query, if any; used to cancel it.
    current_query_cancel_flag: Mutex<Option<ShutdownSender>>,

    /// Weak reference to the execution context of the currently running statement.
    exec_context: Mutex<Option<Weak<ExecContext>>>,

    /// Instant at which the session last became idle, if it is currently idle.
    last_idle_instant: Arc<Mutex<Option<Instant>>>,

    cursor_manager: Arc<CursorManager>,

    /// Temporary sources created in this session; never persisted in the catalog.
    temporary_source_manager: Arc<Mutex<TemporarySourceManager>>,
}

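/// Session-scoped manager for temporary sources. Creating or dropping a
/// temporary source only mutates this in-memory map, so temporary sources
/// disappear when the session ends.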
#[derive(Default, Clone)]
pub struct TemporarySourceManager {
    sources: HashMap<String, SourceCatalog>,
}

impl TemporarySourceManager {
    pub fn new() -> Self {
        Self {
            sources: HashMap::new(),
        }
    }

    pub fn create_source(&mut self, name: String, source: SourceCatalog) {
        self.sources.insert(name, source);
    }

    pub fn drop_source(&mut self, name: &str) {
        self.sources.remove(name);
    }

    pub fn get_source(&self, name: &str) -> Option<&SourceCatalog> {
        self.sources.get(name)
    }

    pub fn keys(&self) -> Vec<String> {
        self.sources.keys().cloned().collect()
    }
}

#[derive(Error, Debug)]
pub enum CheckRelationError {
    #[error("{0}")]
    Resolve(#[from] ResolveQualifiedNameError),
    #[error("{0}")]
    Catalog(#[from] CatalogError),
}

impl From<CheckRelationError> for RwError {
    fn from(e: CheckRelationError) -> Self {
        match e {
            CheckRelationError::Resolve(e) => e.into(),
            CheckRelationError::Catalog(e) => e.into(),
        }
    }
}

impl SessionImpl {
    pub(crate) fn new(
        env: FrontendEnv,
        auth_context: AuthContext,
        user_authenticator: UserAuthenticator,
        id: SessionId,
        peer_addr: AddressRef,
        session_config: SessionConfig,
    ) -> Self {
        let cursor_metrics = env.cursor_metrics.clone();
        let (notice_tx, notice_rx) = mpsc::unbounded_channel();

        Self {
            env,
            auth_context: Arc::new(RwLock::new(auth_context)),
            user_authenticator,
            config_map: Arc::new(RwLock::new(session_config)),
            id,
            peer_addr,
            txn: Default::default(),
            current_query_cancel_flag: Mutex::new(None),
            notice_tx,
            notice_rx: Mutex::new(notice_rx),
            exec_context: Mutex::new(None),
            last_idle_instant: Default::default(),
            cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
            temporary_source_manager: Default::default(),
        }
    }

    #[cfg(test)]
    pub fn mock() -> Self {
        let env = FrontendEnv::mock();
        let (notice_tx, notice_rx) = mpsc::unbounded_channel();

        Self {
            // Reuse the same mock env for the session and for the cursor
            // metrics below, instead of constructing a second, disconnected env.
            env: env.clone(),
            auth_context: Arc::new(RwLock::new(AuthContext::new(
                DEFAULT_DATABASE_NAME.to_owned(),
                DEFAULT_SUPER_USER.to_owned(),
                DEFAULT_SUPER_USER_ID,
            ))),
            user_authenticator: UserAuthenticator::None,
            config_map: Default::default(),
            id: (0, 0),
            txn: Default::default(),
            current_query_cancel_flag: Mutex::new(None),
            notice_tx,
            notice_rx: Mutex::new(notice_rx),
            exec_context: Mutex::new(None),
            peer_addr: Address::Tcp(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                8080,
            ))
            .into(),
            last_idle_instant: Default::default(),
            cursor_manager: Arc::new(CursorManager::new(env.cursor_metrics.clone())),
            temporary_source_manager: Default::default(),
        }
    }

    pub(crate) fn env(&self) -> &FrontendEnv {
        &self.env
    }

    pub fn auth_context(&self) -> Arc<AuthContext> {
        let ctx = self.auth_context.read();
        Arc::new(ctx.clone())
    }

    pub fn database(&self) -> String {
        self.auth_context.read().database.clone()
    }

    pub fn database_id(&self) -> DatabaseId {
        let db_name = self.database();
        self.env
            .catalog_reader()
            .read_guard()
            .get_database_by_name(&db_name)
            .map(|db| db.id())
            .expect("session database not found")
    }

    pub fn user_name(&self) -> String {
        self.auth_context.read().user_name.clone()
    }

    pub fn user_id(&self) -> UserId {
        self.auth_context.read().user_id
    }

    pub fn update_database(&self, database: String) {
        self.auth_context.write().database = database;
    }

    pub fn shared_config(&self) -> Arc<RwLock<SessionConfig>> {
        Arc::clone(&self.config_map)
    }

    pub fn config(&self) -> RwLockReadGuard<'_, SessionConfig> {
        self.config_map.read()
    }

    pub fn set_config(&self, key: &str, value: String) -> Result<String> {
        self.config_map
            .write()
            .set(key, value, &mut ())
            .map_err(Into::into)
    }

    pub fn reset_config(&self, key: &str) -> Result<String> {
        self.config_map
            .write()
            .reset(key, &mut ())
            .map_err(Into::into)
    }

    pub fn set_config_report(
        &self,
        key: &str,
        value: Option<String>,
        mut reporter: impl ConfigReporter,
    ) -> Result<String> {
        if let Some(value) = value {
            self.config_map
                .write()
                .set(key, value, &mut reporter)
                .map_err(Into::into)
        } else {
            self.config_map
                .write()
                .reset(key, &mut reporter)
                .map_err(Into::into)
        }
    }

    pub fn session_id(&self) -> SessionId {
        self.id
    }

    pub fn running_sql(&self) -> Option<Arc<str>> {
        self.exec_context
            .lock()
            .as_ref()
            .and_then(|weak| weak.upgrade())
            .map(|context| context.running_sql.clone())
    }

    pub fn get_cursor_manager(&self) -> Arc<CursorManager> {
        self.cursor_manager.clone()
    }

    pub fn peer_addr(&self) -> &Address {
        &self.peer_addr
    }

    pub fn elapse_since_running_sql(&self) -> Option<u128> {
        self.exec_context
            .lock()
            .as_ref()
            .and_then(|weak| weak.upgrade())
            .map(|context| context.last_instant.elapsed().as_millis())
    }

    pub fn elapse_since_last_idle_instant(&self) -> Option<u128> {
        self.last_idle_instant
            .lock()
            .as_ref()
            .map(|x| x.elapsed().as_millis())
    }

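    /// Checks whether the relation name is already taken. Returns
    /// `Either::Left(())` when creation may proceed, or `Either::Right`
    /// carrying a "skipping" notice response when the relation exists and
    /// `if_not_exists` is set. A typical call-site sketch (not from the
    /// original source):
    ///
    /// ```ignore
    /// if let Either::Right(resp) =
    ///     session.check_relation_name_duplicated(name, stmt_type, true)?
    /// {
    ///     return Ok(resp); // relation exists; report the notice and skip creation
    /// }
    /// ```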
    pub fn check_relation_name_duplicated(
        &self,
        name: ObjectName,
        stmt_type: StatementType,
        if_not_exists: bool,
    ) -> std::result::Result<Either<(), RwPgResponse>, CheckRelationError> {
        let db_name = &self.database();
        let catalog_reader = self.env().catalog_reader().read_guard();
        let (schema_name, relation_name) = {
            let (schema_name, relation_name) =
                Binder::resolve_schema_qualified_name(db_name, name)?;
            let search_path = self.config().search_path();
            let user_name = &self.user_name();
            let schema_name = match schema_name {
                Some(schema_name) => schema_name,
                None => catalog_reader
                    .first_valid_schema(db_name, &search_path, user_name)?
                    .name(),
            };
            (schema_name, relation_name)
        };
        match catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &relation_name) {
            Err(CatalogError::Duplicated(_, name, is_creating)) if if_not_exists => {
                if !is_creating {
                    Ok(Either::Right(
                        PgResponse::builder(stmt_type)
                            .notice(format!("relation \"{}\" already exists, skipping", name))
                            .into(),
                    ))
                } else if stmt_type == StatementType::CREATE_SUBSCRIPTION {
                    Ok(Either::Right(
                        PgResponse::builder(stmt_type)
                            .notice(format!(
                                "relation \"{}\" already exists but still creating, skipping",
                                name
                            ))
                            .into(),
                    ))
                } else {
                    Ok(Either::Left(()))
                }
            }
            Err(e) => Err(e.into()),
            Ok(_) => Ok(Either::Left(())),
        }
    }

    pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> {
        let db_name = &self.database();
        let catalog_reader = self.env().catalog_reader().read_guard();
        let (schema_name, secret_name) = {
            let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
            let search_path = self.config().search_path();
            let user_name = &self.user_name();
            let schema_name = match schema_name {
                Some(schema_name) => schema_name,
                None => catalog_reader
                    .first_valid_schema(db_name, &search_path, user_name)?
                    .name(),
            };
            (schema_name, secret_name)
        };
        catalog_reader
            .check_secret_name_duplicated(db_name, &schema_name, &secret_name)
            .map_err(RwError::from)
    }

    pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> {
        let db_name = &self.database();
        let catalog_reader = self.env().catalog_reader().read_guard();
        let (schema_name, connection_name) = {
            let (schema_name, connection_name) =
                Binder::resolve_schema_qualified_name(db_name, name)?;
            let search_path = self.config().search_path();
            let user_name = &self.user_name();
            let schema_name = match schema_name {
                Some(schema_name) => schema_name,
                None => catalog_reader
                    .first_valid_schema(db_name, &search_path, user_name)?
                    .name(),
            };
            (schema_name, connection_name)
        };
        catalog_reader
            .check_connection_name_duplicated(db_name, &schema_name, &connection_name)
            .map_err(RwError::from)
    }

    pub fn check_function_name_duplicated(
        &self,
        stmt_type: StatementType,
        name: ObjectName,
        arg_types: &[DataType],
        if_not_exists: bool,
    ) -> Result<Either<(), RwPgResponse>> {
        let db_name = &self.database();
        let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
        let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?;

        let catalog_reader = self.env().catalog_reader().read_guard();
        if catalog_reader
            .get_schema_by_id(&database_id, &schema_id)?
            .get_function_by_name_args(&function_name, arg_types)
            .is_some()
        {
            let full_name = format!(
                "{function_name}({})",
                arg_types.iter().map(|t| t.to_string()).join(",")
            );
            if if_not_exists {
                Ok(Either::Right(
                    PgResponse::builder(stmt_type)
                        .notice(format!(
                            "function \"{}\" already exists, skipping",
                            full_name
                        ))
                        .into(),
                ))
            } else {
                Err(CatalogError::duplicated("function", full_name).into())
            }
        } else {
            Ok(Either::Left(()))
        }
    }

    pub fn get_database_and_schema_id_for_create(
        &self,
        schema_name: Option<String>,
    ) -> Result<(DatabaseId, SchemaId)> {
        let db_name = &self.database();

        let search_path = self.config().search_path();
        let user_name = &self.user_name();

        let catalog_reader = self.env().catalog_reader().read_guard();
        let schema = match schema_name {
            Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?,
            None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?,
        };

        check_schema_writable(&schema.name())?;
        self.check_privileges(&[ObjectCheckItem::new(
            schema.owner(),
            AclMode::Create,
            Object::SchemaId(schema.id()),
        )])?;

        let db_id = catalog_reader.get_database_by_name(db_name)?.id();
        Ok((db_id, schema.id()))
    }

    pub fn get_connection_by_name(
        &self,
        schema_name: Option<String>,
        connection_name: &str,
    ) -> Result<Arc<ConnectionCatalog>> {
        let db_name = &self.database();
        let search_path = self.config().search_path();
        let user_name = &self.user_name();

        let catalog_reader = self.env().catalog_reader().read_guard();
        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
        let (connection, _) =
            catalog_reader.get_connection_by_name(db_name, schema_path, connection_name)?;

        self.check_privileges(&[ObjectCheckItem::new(
            connection.owner(),
            AclMode::Usage,
            Object::ConnectionId(connection.id),
        )])?;

        Ok(connection.clone())
    }

    pub fn get_subscription_by_schema_id_name(
        &self,
        schema_id: SchemaId,
        subscription_name: &str,
    ) -> Result<Arc<SubscriptionCatalog>> {
        let db_name = &self.database();

        let catalog_reader = self.env().catalog_reader().read_guard();
        let db_id = catalog_reader.get_database_by_name(db_name)?.id();
        let schema = catalog_reader.get_schema_by_id(&db_id, &schema_id)?;
        let subscription = schema
            .get_subscription_by_name(subscription_name)
            .ok_or_else(|| {
                RwError::from(ErrorCode::ItemNotFound(format!(
                    "subscription {} not found",
                    subscription_name
                )))
            })?;
        Ok(subscription.clone())
    }

    pub fn get_subscription_by_name(
        &self,
        schema_name: Option<String>,
        subscription_name: &str,
    ) -> Result<Arc<SubscriptionCatalog>> {
        let db_name = &self.database();
        let search_path = self.config().search_path();
        let user_name = &self.user_name();

        let catalog_reader = self.env().catalog_reader().read_guard();
        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
        let (subscription, _) =
            catalog_reader.get_subscription_by_name(db_name, schema_path, subscription_name)?;
        Ok(subscription.clone())
    }

    pub fn get_table_by_id(&self, table_id: &TableId) -> Result<Arc<TableCatalog>> {
        let catalog_reader = self.env().catalog_reader().read_guard();
        Ok(catalog_reader.get_any_table_by_id(table_id)?.clone())
    }

    pub fn get_table_by_name(
        &self,
        table_name: &str,
        db_id: u32,
        schema_id: u32,
    ) -> Result<Arc<TableCatalog>> {
        let catalog_reader = self.env().catalog_reader().read_guard();
        let table = catalog_reader
            .get_schema_by_id(&DatabaseId::from(db_id), &SchemaId::from(schema_id))?
            .get_created_table_by_name(table_name)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::InvalidInput,
                    format!("table \"{}\" does not exist", table_name),
                )
            })?;

        self.check_privileges(&[ObjectCheckItem::new(
            table.owner(),
            AclMode::Select,
            Object::TableId(table.id.table_id()),
        )])?;

        Ok(table.clone())
    }

    pub fn get_secret_by_name(
        &self,
        schema_name: Option<String>,
        secret_name: &str,
    ) -> Result<Arc<SecretCatalog>> {
        let db_name = &self.database();
        let search_path = self.config().search_path();
        let user_name = &self.user_name();

        let catalog_reader = self.env().catalog_reader().read_guard();
        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
        let (secret, _) = catalog_reader.get_secret_by_name(db_name, schema_path, secret_name)?;

        self.check_privileges(&[ObjectCheckItem::new(
            secret.owner(),
            AclMode::Create,
            Object::SecretId(secret.id.secret_id()),
        )])?;

        Ok(secret.clone())
    }

    pub fn list_change_log_epochs(
        &self,
        table_id: u32,
        min_epoch: u64,
        max_count: u32,
    ) -> Result<Vec<u64>> {
        Ok(self
            .env
            .hummock_snapshot_manager()
            .acquire()
            .list_change_log_epochs(table_id, min_epoch, max_count))
    }

    pub fn clear_cancel_query_flag(&self) {
        let mut flag = self.current_query_cancel_flag.lock();
        *flag = None;
    }

    pub fn reset_cancel_query_flag(&self) -> ShutdownToken {
        let mut flag = self.current_query_cancel_flag.lock();
        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
        *flag = Some(shutdown_tx);
        shutdown_rx
    }

    pub fn cancel_current_query(&self) {
        let mut flag_guard = self.current_query_cancel_flag.lock();
        if let Some(sender) = flag_guard.take() {
            info!("Trying to cancel query in local mode.");
            sender.cancel();
            info!("Cancel query request sent.");
        } else {
            info!("Trying to cancel query in distributed mode.");
            self.env.query_manager().cancel_queries_in_session(self.id)
        }
    }

    pub fn cancel_current_creating_job(&self) {
        self.env.creating_streaming_job_tracker.abort_jobs(self.id);
    }

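    /// Parses and runs a single SQL statement on this session. An empty input
    /// yields an empty response; more than one statement yields a notice.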
    pub async fn run_statement(
        self: Arc<Self>,
        sql: Arc<str>,
        formats: Vec<Format>,
    ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
        let mut stmts = Parser::parse_sql(&sql)?;
        if stmts.is_empty() {
            return Ok(PgResponse::empty_result(
                pgwire::pg_response::StatementType::EMPTY,
            ));
        }
        if stmts.len() > 1 {
            return Ok(
                PgResponse::builder(pgwire::pg_response::StatementType::EMPTY)
                    .notice("cannot insert multiple commands into statement")
                    .into(),
            );
        }
        let stmt = stmts.swap_remove(0);
        let rsp = handle(self, stmt, sql.clone(), formats).await?;
        Ok(rsp)
    }

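    /// Sends a notice to the client through the session's notice channel; the
    /// counterpart `next_notice` (in the `Session` impl below) drains it.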
    pub fn notice_to_user(&self, str: impl Into<String>) {
        let notice = str.into();
        tracing::trace!(notice, "notice to user");
        self.notice_tx
            .send(notice)
            .expect("notice channel should not be closed");
    }

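    /// Whether batch queries should read barrier-level (unchcheckpointed) data:
    /// `Default` defers to the `enable_barrier_read` batch config, `All` forces
    /// barrier reads, and `Checkpoint` restricts reads to checkpointed data.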
    pub fn is_barrier_read(&self) -> bool {
        match self.config().visibility_mode() {
            VisibilityMode::Default => self.env.batch_config.enable_barrier_read,
            VisibilityMode::All => true,
            VisibilityMode::Checkpoint => false,
        }
    }

    pub fn statement_timeout(&self) -> Duration {
        if self.config().statement_timeout() == 0 {
            Duration::from_secs(self.env.batch_config.statement_timeout_in_sec as u64)
        } else {
            Duration::from_secs(self.config().statement_timeout() as u64)
        }
    }

    pub fn create_temporary_source(&self, source: SourceCatalog) {
        self.temporary_source_manager
            .lock()
            .create_source(source.name.clone(), source);
    }

    pub fn get_temporary_source(&self, name: &str) -> Option<SourceCatalog> {
        self.temporary_source_manager
            .lock()
            .get_source(name)
            .cloned()
    }

    pub fn drop_temporary_source(&self, name: &str) {
        self.temporary_source_manager.lock().drop_source(name);
    }

    pub fn temporary_source_manager(&self) -> TemporarySourceManager {
        self.temporary_source_manager.lock().clone()
    }

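    /// Checks cluster limits (currently the actor count per parallelism)
    /// before creating a streaming job. Exceeding the hard limit is an error;
    /// exceeding only the soft limit merely sends a notice to the user.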
    pub async fn check_cluster_limits(&self) -> Result<()> {
        if self.config().bypass_cluster_limits() {
            return Ok(());
        }

        let gen_message = |ActorCountPerParallelism {
                               worker_id_to_actor_count,
                               hard_limit,
                               soft_limit,
                           }: ActorCountPerParallelism,
                           exceed_hard_limit: bool|
         -> String {
            let (limit_type, action) = if exceed_hard_limit {
                ("critical", "Scale the cluster immediately to proceed.")
            } else {
                (
                    "recommended",
                    "Consider scaling the cluster for optimal performance.",
                )
            };
            format!(
                r#"Actor count per parallelism exceeds the {limit_type} limit.

Depending on your workload, this may overload the cluster and cause performance/stability issues. {action}

HINT:
- For best practices on managing streaming jobs: https://docs.risingwave.com/operate/manage-a-large-number-of-streaming-jobs
- To bypass the check (if the cluster load is acceptable): `[ALTER SYSTEM] SET bypass_cluster_limits TO true`.
  See https://docs.risingwave.com/operate/view-configure-runtime-parameters#how-to-configure-runtime-parameters
- Contact us via slack or https://risingwave.com/contact-us/ for further enquiry.

DETAILS:
- hard limit: {hard_limit}
- soft limit: {soft_limit}
- worker_id_to_actor_count: {worker_id_to_actor_count:?}"#,
            )
        };

        let limits = self.env().meta_client().get_cluster_limits().await?;
        for limit in limits {
            match limit {
                cluster_limit::ClusterLimit::ActorCount(l) => {
                    if l.exceed_hard_limit() {
                        return Err(RwError::from(ErrorCode::ProtocolError(gen_message(
                            l, true,
                        ))));
                    } else if l.exceed_soft_limit() {
                        self.notice_to_user(gen_message(l, false));
                    }
                }
            }
        }
        Ok(())
    }
}

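/// Global singleton of the session manager, set once at server startup so
/// other components can reach active sessions. A usage sketch (not from the
/// original source):
///
/// ```ignore
/// if let Some(manager) = SESSION_MANAGER.get() {
///     manager.cancel_queries_in_session(session_id);
/// }
/// ```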
pub static SESSION_MANAGER: std::sync::OnceLock<Arc<SessionManagerImpl>> =
    std::sync::OnceLock::new();

pub struct SessionManagerImpl {
    env: FrontendEnv,
    _join_handles: Vec<JoinHandle<()>>,
    _shutdown_senders: Vec<Sender<()>>,
    number: AtomicI32,
}

impl SessionManager for SessionManagerImpl {
    type Session = SessionImpl;

    fn create_dummy_session(
        &self,
        database_id: u32,
        user_id: u32,
    ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
        let dummy_addr = Address::Tcp(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
            5691,
        ));
        let user_reader = self.env.user_info_reader();
        let reader = user_reader.read_guard();
        if let Some(user_name) = reader.get_user_name_by_id(user_id) {
            self.connect_inner(database_id, user_name.as_str(), Arc::new(dummy_addr))
        } else {
            Err(Box::new(Error::new(
                ErrorKind::InvalidInput,
                format!("Role id {} does not exist", user_id),
            )))
        }
    }

    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
        let catalog_reader = self.env.catalog_reader();
        let reader = catalog_reader.read_guard();
        let database_id = reader
            .get_database_by_name(database)
            .map_err(|_| {
                Box::new(Error::new(
                    ErrorKind::InvalidInput,
                    format!("database \"{}\" does not exist", database),
                ))
            })?
            .id();

        self.connect_inner(database_id, user_name, peer_addr)
    }

    fn cancel_queries_in_session(&self, session_id: SessionId) {
        self.env.cancel_queries_in_session(session_id);
    }

    fn cancel_creating_jobs_in_session(&self, session_id: SessionId) {
        self.env.cancel_creating_jobs_in_session(session_id);
    }

    fn end_session(&self, session: &Self::Session) {
        self.delete_session(&session.session_id());
    }

    async fn shutdown(&self) {
        self.env.sessions_map().write().clear();
        self.env.meta_client().try_unregister().await;
    }
}

impl SessionManagerImpl {
    pub async fn new(opts: FrontendOpts) -> Result<Self> {
        let (env, join_handles, shutdown_senders) = FrontendEnv::init(opts).await?;
        Ok(Self {
            env,
            _join_handles: join_handles,
            _shutdown_senders: shutdown_senders,
            number: AtomicI32::new(0),
        })
    }

    pub(crate) fn env(&self) -> &FrontendEnv {
        &self.env
    }

    fn insert_session(&self, session: Arc<SessionImpl>) {
        let active_sessions = {
            let mut write_guard = self.env.sessions_map.write();
            write_guard.insert(session.id(), session);
            write_guard.len()
        };
        self.env
            .frontend_metrics
            .active_sessions
            .set(active_sessions as i64);
    }

    fn delete_session(&self, session_id: &SessionId) {
        let active_sessions = {
            let mut write_guard = self.env.sessions_map.write();
            write_guard.remove(session_id);
            write_guard.len()
        };
        self.env
            .frontend_metrics
            .active_sessions
            .set(active_sessions as i64);
    }

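    /// Shared logic behind `connect` and `create_dummy_session`: resolves the
    /// database, validates login and CONNECT privilege, builds the matching
    /// `UserAuthenticator`, assigns a session id, and registers the session.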
    fn connect_inner(
        &self,
        database_id: u32,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> std::result::Result<Arc<SessionImpl>, BoxedError> {
        let catalog_reader = self.env.catalog_reader();
        let reader = catalog_reader.read_guard();
        let database_name = reader
            .get_database_by_id(&database_id)
            .map_err(|_| {
                Box::new(Error::new(
                    ErrorKind::InvalidInput,
                    format!("database \"{}\" does not exist", database_id),
                ))
            })?
            .name();

        let user_reader = self.env.user_info_reader();
        let reader = user_reader.read_guard();
        if let Some(user) = reader.get_user_by_name(user_name) {
            if !user.can_login {
                return Err(Box::new(Error::new(
                    ErrorKind::InvalidInput,
                    format!("User {} is not allowed to login", user_name),
                )));
            }
            let has_privilege =
                user.has_privilege(&Object::DatabaseId(database_id), AclMode::Connect);
            if !user.is_super && !has_privilege {
                return Err(Box::new(Error::new(
                    ErrorKind::PermissionDenied,
                    "User does not have CONNECT privilege.",
                )));
            }
            let user_authenticator = match &user.auth_info {
                None => UserAuthenticator::None,
                Some(auth_info) => {
                    if auth_info.encryption_type == EncryptionType::Plaintext as i32 {
                        UserAuthenticator::ClearText(auth_info.encrypted_value.clone())
                    } else if auth_info.encryption_type == EncryptionType::Md5 as i32 {
                        let mut salt = [0; 4];
                        let mut rng = rand::rng();
                        rng.fill_bytes(&mut salt);
                        UserAuthenticator::Md5WithSalt {
                            encrypted_password: md5_hash_with_salt(
                                &auth_info.encrypted_value,
                                &salt,
                            ),
                            salt,
                        }
                    } else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
                        UserAuthenticator::OAuth(auth_info.metadata.clone())
                    } else {
                        return Err(Box::new(Error::new(
                            ErrorKind::Unsupported,
                            format!("Unsupported auth type: {}", auth_info.encryption_type),
                        )));
                    }
                }
            };

            let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
            let id = (secret_key, secret_key);
            let session_config = self.env.session_params_snapshot();

            let session_impl: Arc<SessionImpl> = SessionImpl::new(
                self.env.clone(),
                AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id),
                user_authenticator,
                id,
                peer_addr,
                session_config,
            )
            .into();
            self.insert_session(session_impl.clone());

            Ok(session_impl)
        } else {
            Err(Box::new(Error::new(
                ErrorKind::InvalidInput,
                format!("Role {} does not exist", user_name),
            )))
        }
    }
}

impl Session for SessionImpl {
    type Portal = Portal;
    type PreparedStatement = PrepareStatement;
    type ValuesStream = PgResponseStream;

    async fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
        let string = stmt.to_string();
        let sql_str = string.as_str();
        let sql: Arc<str> = Arc::from(sql_str);
        drop(string);
        let rsp = handle(self, stmt, sql, vec![format]).await?;
        Ok(rsp)
    }

    fn user_authenticator(&self) -> &UserAuthenticator {
        &self.user_authenticator
    }

    fn id(&self) -> SessionId {
        self.id
    }

    async fn parse(
        self: Arc<Self>,
        statement: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> std::result::Result<PrepareStatement, BoxedError> {
        Ok(if let Some(statement) = statement {
            handle_parse(self, statement, params_types).await?
        } else {
            PrepareStatement::Empty
        })
    }

    fn bind(
        self: Arc<Self>,
        prepare_statement: PrepareStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> std::result::Result<Portal, BoxedError> {
        Ok(handle_bind(
            prepare_statement,
            params,
            param_formats,
            result_formats,
        )?)
    }

    async fn execute(
        self: Arc<Self>,
        portal: Portal,
    ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
        let rsp = handle_execute(self, portal).await?;
        Ok(rsp)
    }

    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: PrepareStatement,
    ) -> std::result::Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
        Ok(match prepare_statement {
            PrepareStatement::Empty => (vec![], vec![]),
            PrepareStatement::Prepared(prepare_statement) => (
                prepare_statement.bound_result.param_types,
                infer(
                    Some(prepare_statement.bound_result.bound),
                    prepare_statement.statement,
                )?,
            ),
            PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?),
        })
    }

    fn describe_portal(
        self: Arc<Self>,
        portal: Portal,
    ) -> std::result::Result<Vec<PgFieldDescriptor>, BoxedError> {
        match portal {
            Portal::Empty => Ok(vec![]),
            Portal::Portal(portal) => {
                let mut columns = infer(Some(portal.bound_result.bound), portal.statement)?;
                let formats = FormatIterator::new(&portal.result_formats, columns.len())?;
                columns.iter_mut().zip_eq_fast(formats).for_each(|(c, f)| {
                    if f == Format::Binary {
                        c.set_to_binary()
                    }
                });
                Ok(columns)
            }
            Portal::PureStatement(statement) => Ok(infer(None, statement)?),
        }
    }

    fn set_config(&self, key: &str, value: String) -> std::result::Result<String, BoxedError> {
        Self::set_config(self, key, value).map_err(Into::into)
    }

    async fn next_notice(self: &Arc<Self>) -> String {
        std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
            .await
            .expect("notice channel should not be closed")
    }

    fn transaction_status(&self) -> TransactionStatus {
        match &*self.txn.lock() {
            transaction::State::Initial | transaction::State::Implicit(_) => {
                TransactionStatus::Idle
            }
            transaction::State::Explicit(_) => TransactionStatus::InTransaction,
        }
    }

    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
        let exec_context = Arc::new(ExecContext {
            running_sql: sql,
            last_instant: Instant::now(),
            last_idle_instant: self.last_idle_instant.clone(),
        });
        *self.exec_context.lock() = Some(Arc::downgrade(&exec_context));
        *self.last_idle_instant.lock() = None;
        ExecContextGuard::new(exec_context)
    }

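    /// Errors with `PsqlError::IdleInTxnTimeout` if the session has an open
    /// transaction, no statement is currently executing, and the idle time
    /// exceeds `idle_in_transaction_session_timeout` (0 disables the check).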
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
        if matches!(self.transaction_status(), TransactionStatus::InTransaction) {
            let idle_in_transaction_session_timeout =
                self.config().idle_in_transaction_session_timeout() as u128;
            if idle_in_transaction_session_timeout != 0 {
                let guard = self.exec_context.lock();
                if guard.as_ref().and_then(|weak| weak.upgrade()).is_none() {
                    if let Some(elapse_since_last_idle_instant) =
                        self.elapse_since_last_idle_instant()
                    {
                        if elapse_since_last_idle_instant > idle_in_transaction_session_timeout {
                            return Err(PsqlError::IdleInTxnTimeout);
                        }
                    }
                }
            }
        }
        Ok(())
    }
}

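/// Infers the output field descriptors of a statement for `Describe`: DML and
/// queries take their fields from the bound statement (which must be present),
/// while SHOW/DESCRIBE/EXPLAIN statements have statically known schemas.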
fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
    match stmt {
        Statement::Query(_)
        | Statement::Insert { .. }
        | Statement::Delete { .. }
        | Statement::Update { .. }
        | Statement::FetchCursor { .. } => Ok(bound
            .unwrap()
            .output_fields()
            .iter()
            .map(to_pg_field)
            .collect()),
        Statement::ShowObjects {
            object: show_object,
            ..
        } => Ok(infer_show_object(&show_object)),
        Statement::ShowCreateObject { .. } => Ok(infer_show_create_object()),
        Statement::ShowTransactionIsolationLevel => {
            let name = "transaction_isolation";
            Ok(infer_show_variable(name))
        }
        Statement::ShowVariable { variable } => {
            let name = &variable[0].real_value().to_lowercase();
            Ok(infer_show_variable(name))
        }
        Statement::Describe { name: _, kind } => Ok(infer_describe(&kind)),
        Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new(
            "QUERY PLAN".to_owned(),
            DataType::Varchar.to_oid(),
            DataType::Varchar.type_len(),
        )]),
        _ => Ok(vec![]),
    }
}

pub struct WorkerProcessId {
    pub worker_id: u32,
    pub process_id: i32,
}

impl WorkerProcessId {
    pub fn new(worker_id: u32, process_id: i32) -> Self {
        Self {
            worker_id,
            process_id,
        }
    }
}

impl Display for WorkerProcessId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.worker_id, self.process_id)
    }
}

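/// Parses a `WorkerProcessId` from its `"worker_id:process_id"` display form.
/// A minimal round-trip sketch (not from the original source):
///
/// ```ignore
/// let id = WorkerProcessId::try_from("42:1001".to_owned()).unwrap();
/// assert_eq!(id.to_string(), "42:1001");
/// ```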
impl TryFrom<String> for WorkerProcessId {
    type Error = String;

    fn try_from(worker_process_id: String) -> std::result::Result<Self, Self::Error> {
        const INVALID: &str = "invalid WorkerProcessId";
        let splits: Vec<&str> = worker_process_id.split(":").collect();
        if splits.len() != 2 {
            return Err(INVALID.to_owned());
        }
        let Ok(worker_id) = splits[0].parse::<u32>() else {
            return Err(INVALID.to_owned());
        };
        let Ok(process_id) = splits[1].parse::<i32>() else {
            return Err(INVALID.to_owned());
        };
        Ok(WorkerProcessId::new(worker_id, process_id))
    }
}

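/// Cancels the running query of the given session, if the session still
/// exists. Returns whether a session was found.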
pub fn cancel_queries_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
    let guard = sessions_map.read();
    if let Some(session) = guard.get(&session_id) {
        session.cancel_current_query();
        true
    } else {
        info!("Current session finished, ignoring cancel query request");
        false
    }
}

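/// Cancels the creating streaming jobs of the given session, if the session
/// still exists. Returns whether a session was found.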
pub fn cancel_creating_jobs_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
    let guard = sessions_map.read();
    if let Some(session) = guard.get(&session_id) {
        session.cancel_current_creating_job();
        true
    } else {
        info!("Current session finished, ignoring cancel creating request");
        false
    }
}