1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::io::{Error, ErrorKind};
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::atomic::{AtomicI32, Ordering};
20use std::sync::{Arc, Weak};
21use std::time::{Duration, Instant};
22
23use anyhow::anyhow;
24use bytes::Bytes;
25use either::Either;
26use itertools::Itertools;
27use parking_lot::{Mutex, RwLock, RwLockReadGuard};
28use pgwire::error::{PsqlError, PsqlResult};
29use pgwire::net::{Address, AddressRef};
30use pgwire::pg_field_descriptor::PgFieldDescriptor;
31use pgwire::pg_message::TransactionStatus;
32use pgwire::pg_response::{PgResponse, StatementType};
33use pgwire::pg_server::{
34 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
35 UserAuthenticator,
36};
37use pgwire::types::{Format, FormatIterator};
38use prometheus_http_query::Client as PrometheusClient;
39use rand::RngCore;
40use risingwave_batch::monitor::{BatchSpillMetrics, GLOBAL_BATCH_SPILL_METRICS};
41use risingwave_batch::spill::spill_op::SpillOp;
42use risingwave_batch::task::{ShutdownSender, ShutdownToken};
43use risingwave_batch::worker_manager::worker_node_manager::{
44 WorkerNodeManager, WorkerNodeManagerRef,
45};
46use risingwave_common::acl::AclMode;
47#[cfg(test)]
48use risingwave_common::catalog::{
49 DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
50};
51use risingwave_common::config::{
52 BatchConfig, FrontendConfig, MetaConfig, MetricLevel, StreamingConfig, UdfConfig, load_config,
53};
54use risingwave_common::memory::MemoryContext;
55use risingwave_common::secret::LocalSecretManager;
56use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode};
57use risingwave_common::system_param::local_manager::{
58 LocalSystemParamsManager, LocalSystemParamsManagerRef,
59};
60use risingwave_common::telemetry::manager::TelemetryManager;
61use risingwave_common::telemetry::telemetry_env_enabled;
62use risingwave_common::types::DataType;
63use risingwave_common::util::addr::HostAddr;
64use risingwave_common::util::cluster_limit;
65use risingwave_common::util::cluster_limit::ActorCountPerParallelism;
66use risingwave_common::util::iter_util::ZipEqFast;
67use risingwave_common::util::pretty_bytes::convert;
68use risingwave_common::util::runtime::BackgroundShutdownRuntime;
69use risingwave_common::{GIT_SHA, RW_VERSION};
70use risingwave_common_heap_profiling::HeapProfiler;
71use risingwave_common_service::{MetricsManager, ObserverManager};
72use risingwave_connector::source::monitor::{GLOBAL_SOURCE_METRICS, SourceMetrics};
73use risingwave_pb::common::WorkerType;
74use risingwave_pb::common::worker_node::Property as AddWorkerNodeProperty;
75use risingwave_pb::frontend_service::frontend_service_server::FrontendServiceServer;
76use risingwave_pb::health::health_server::HealthServer;
77use risingwave_pb::user::auth_info::EncryptionType;
78use risingwave_pb::user::grant_privilege::Object;
79use risingwave_rpc_client::{
80 ComputeClientPool, ComputeClientPoolRef, FrontendClientPool, FrontendClientPoolRef, MetaClient,
81};
82use risingwave_sqlparser::ast::{ObjectName, Statement};
83use risingwave_sqlparser::parser::Parser;
84use thiserror::Error;
85use thiserror_ext::AsReport;
86use tokio::runtime::Builder;
87use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
88use tokio::sync::oneshot::Sender;
89use tokio::sync::watch;
90use tokio::task::JoinHandle;
91use tracing::{error, info};
92
93use self::cursor_manager::CursorManager;
94use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError};
95use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl};
96use crate::catalog::connection_catalog::ConnectionCatalog;
97use crate::catalog::root_catalog::{Catalog, SchemaPath};
98use crate::catalog::secret_catalog::SecretCatalog;
99use crate::catalog::source_catalog::SourceCatalog;
100use crate::catalog::subscription_catalog::SubscriptionCatalog;
101use crate::catalog::{
102 CatalogError, DatabaseId, OwnedByUserCatalog, SchemaId, TableId, check_schema_writable,
103};
104use crate::error::{ErrorCode, Result, RwError};
105use crate::handler::describe::infer_describe;
106use crate::handler::extended_handle::{
107 Portal, PrepareStatement, handle_bind, handle_execute, handle_parse,
108};
109use crate::handler::privilege::ObjectCheckItem;
110use crate::handler::show::{infer_show_create_object, infer_show_object};
111use crate::handler::util::to_pg_field;
112use crate::handler::variable::infer_show_variable;
113use crate::handler::{RwPgResponse, handle};
114use crate::health_service::HealthServiceImpl;
115use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl};
116use crate::monitor::{CursorMetrics, FrontendMetrics, GLOBAL_FRONTEND_METRICS};
117use crate::observer::FrontendObserverNode;
118use crate::rpc::FrontendServiceImpl;
119use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef};
120use crate::scheduler::{
121 DistributedQueryMetrics, GLOBAL_DISTRIBUTED_QUERY_METRICS, HummockSnapshotManager,
122 HummockSnapshotManagerRef, QueryManager,
123};
124use crate::telemetry::FrontendTelemetryCreator;
125use crate::user::UserId;
126use crate::user::user_authentication::md5_hash_with_salt;
127use crate::user::user_manager::UserInfoManager;
128use crate::user::user_service::{UserInfoReader, UserInfoWriter, UserInfoWriterImpl};
129use crate::{FrontendOpts, PgResponseStream, TableCatalog};
130
131pub(crate) mod current;
132pub(crate) mod cursor_manager;
133pub(crate) mod transaction;
134
/// Node-wide state of a frontend node.
///
/// Built by `FrontendEnv::init` for a real node, or by `FrontendEnv::mock` for tests.
/// Every `SessionImpl` owns a `FrontendEnv` (presumably a clone of the node's env).
/// Most fields are `Arc`-backed handles, so cloning presumably shares the underlying
/// components rather than copying them.
#[derive(Clone)]
pub(crate) struct FrontendEnv {
    /// Client used to talk to the meta service.
    meta_client: Arc<dyn FrontendMetaClient>,
    /// Catalog writer; only handed out by `catalog_writer(guard)`, which demands a
    /// transaction write guard.
    catalog_writer: Arc<dyn CatalogWriter>,
    /// Read access to the frontend's local catalog.
    catalog_reader: CatalogReader,
    /// User-info writer; only handed out by `user_info_writer(guard)`, which demands a
    /// transaction write guard.
    user_info_writer: Arc<dyn UserInfoWriter>,
    /// Read access to the frontend's local user info.
    user_info_reader: UserInfoReader,
    /// Worker-node manager, shared with `query_manager` (and, in `init`, the observer).
    worker_node_manager: WorkerNodeManagerRef,
    /// Manager for batch queries; `init` configures its limits from `BatchConfig`.
    query_manager: QueryManager,
    hummock_snapshot_manager: HummockSnapshotManagerRef,
    /// Local copy of the system parameters.
    system_params_manager: LocalSystemParamsManagerRef,
    /// Default session parameters; `session_params_snapshot` returns a copy. In `init`
    /// the same handle is given to the frontend observer.
    session_params: Arc<RwLock<SessionConfig>>,

    /// Address this frontend advertises (falls back to the listen address in `init`).
    server_addr: HostAddr,
    /// Pool of RPC clients to compute nodes.
    client_pool: ComputeClientPoolRef,
    /// Pool of RPC clients to frontend nodes.
    frontend_client_pool: FrontendClientPoolRef,

    /// All sessions on this node, keyed by session id.
    sessions_map: SessionMapRef,

    pub frontend_metrics: Arc<FrontendMetrics>,

    pub cursor_metrics: Arc<CursorMetrics>,

    source_metrics: Arc<SourceMetrics>,

    spill_metrics: Arc<BatchSpillMetrics>,

    // Config sections captured once at startup.
    batch_config: BatchConfig,
    frontend_config: FrontendConfig,
    // Stored, but currently not read anywhere (hence the `dead_code` expectation).
    #[expect(dead_code)]
    meta_config: MetaConfig,
    streaming_config: StreamingConfig,
    udf_config: UdfConfig,

    /// Tracks streaming jobs that are being created; see
    /// `cancel_creating_jobs_in_session`.
    creating_streaming_job_tracker: StreamingJobTrackerRef,

    /// Dedicated multi-threaded tokio runtime (threads named `rw-batch-local`).
    compute_runtime: Arc<BackgroundShutdownRuntime>,

    /// Root memory context for batch execution. In `init` its limit is
    /// `FRONTEND_BATCH_MEMORY_PROPORTION` of the frontend's total memory.
    mem_context: MemoryContext,

    /// Address of the serverless backfill controller, taken from the options.
    serverless_backfill_controller_addr: String,

    /// Prometheus client; `None` if no endpoint was configured or it failed to
    /// initialize.
    prometheus_client: Option<PrometheusClient>,

    /// Prometheus selector from the options (empty string if unset).
    prometheus_selector: String,
}
196
/// Shared map of the sessions on this frontend node, keyed by session id (the same
/// `(i32, i32)` pair that `SessionImpl::session_id` returns).
pub type SessionMapRef = Arc<RwLock<HashMap<(i32, i32), Arc<SessionImpl>>>>;
199
/// Share of the frontend's total memory (`frontend_total_memory_bytes`) that
/// `FrontendEnv::init` uses as the limit of the root batch memory context.
const FRONTEND_BATCH_MEMORY_PROPORTION: f64 = 0.5_f64;
202
impl FrontendEnv {
    /// Builds a `FrontendEnv` for tests without registering with a meta node.
    ///
    /// Meta access, catalog writes and user-info writes go through the mocks in
    /// `crate::test_utils`. Configs use their `Default` values, and metrics and pools
    /// use their `for_test`/default variants. A real multi-threaded tokio runtime
    /// (threads named `rw-batch-local`) is still built for `compute_runtime`.
    ///
    /// # Panics
    /// Panics (via `unwrap`) if the hard-coded server address fails to parse or the
    /// runtime cannot be built.
    pub fn mock() -> Self {
        use crate::test_utils::{MockCatalogWriter, MockFrontendMetaClient, MockUserInfoWriter};

        let catalog = Arc::new(RwLock::new(Catalog::default()));
        let meta_client = Arc::new(MockFrontendMetaClient {});
        let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone()));
        // The mock writer and `catalog_reader` below share the same in-memory `catalog`.
        let catalog_writer = Arc::new(MockCatalogWriter::new(
            catalog.clone(),
            hummock_snapshot_manager.clone(),
        ));
        let catalog_reader = CatalogReader::new(catalog);
        // Likewise, the mock user-info writer and the reader share one `UserInfoManager`.
        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
        let user_info_writer = Arc::new(MockUserInfoWriter::new(user_info_manager.clone()));
        let user_info_reader = UserInfoReader::new(user_info_manager);
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let system_params_manager = Arc::new(LocalSystemParamsManager::for_test());
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let frontend_client_pool = Arc::new(FrontendClientPool::for_test());
        // The two trailing `None`s take the place of the distributed-query limit and the
        // per-frontend batch-query limit that `init` reads from the batch config.
        let query_manager = QueryManager::new(
            worker_node_manager.clone(),
            compute_client_pool,
            catalog_reader.clone(),
            Arc::new(DistributedQueryMetrics::for_test()),
            None,
            None,
        );
        let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap();
        // NOTE(review): unlike `init`, this is a second pool, distinct from the one given
        // to `query_manager` above.
        let client_pool = Arc::new(ComputeClientPool::for_test());
        let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone());
        let runtime = {
            let mut builder = Builder::new_multi_thread();
            // Override tokio's default worker count only when the default config sets one.
            if let Some(frontend_compute_runtime_worker_threads) =
                load_config("", FrontendOpts::default())
                    .batch
                    .frontend_compute_runtime_worker_threads
            {
                builder.worker_threads(frontend_compute_runtime_worker_threads);
            }
            builder
                .thread_name("rw-batch-local")
                .enable_all()
                .build()
                .unwrap()
        };
        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));
        let sessions_map = Arc::new(RwLock::new(HashMap::new()));
        Self {
            meta_client,
            catalog_writer,
            catalog_reader,
            user_info_writer,
            user_info_reader,
            worker_node_manager,
            query_manager,
            hummock_snapshot_manager,
            system_params_manager,
            session_params: Default::default(),
            server_addr,
            client_pool,
            frontend_client_pool,
            sessions_map: sessions_map.clone(),
            frontend_metrics: Arc::new(FrontendMetrics::for_test()),
            cursor_metrics: Arc::new(CursorMetrics::for_test()),
            batch_config: BatchConfig::default(),
            frontend_config: FrontendConfig::default(),
            meta_config: MetaConfig::default(),
            streaming_config: StreamingConfig::default(),
            udf_config: UdfConfig::default(),
            source_metrics: Arc::new(SourceMetrics::default()),
            spill_metrics: BatchSpillMetrics::for_test(),
            creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
            compute_runtime,
            mem_context: MemoryContext::none(),
            serverless_backfill_controller_addr: Default::default(),
            prometheus_client: None,
            prometheus_selector: String::new(),
        }
    }

    /// Bootstraps a frontend node and returns its environment.
    ///
    /// Steps, in order:
    /// 1. Load the config.
    /// 2. Register this node with meta as a `WorkerType::Frontend` worker and start the
    ///    heartbeat loop.
    /// 3. Build the catalog, user-info, query-manager and client-pool components, and
    ///    initialize the local secret manager.
    /// 4. Start the frontend observer, then activate the worker in meta.
    /// 5. Optionally boot the metrics service and telemetry.
    /// 6. Spawn the health and frontend RPC server.
    /// 7. Build the dedicated batch runtime and spawn the idle-in-transaction monitor.
    /// 8. Optionally clean the spill directory.
    /// 9. Start the heap profiler, then set up the batch memory context and the
    ///    optional Prometheus client.
    ///
    /// Besides the environment, returns:
    /// - the join handles of the background tasks it tracks: heartbeat, observer,
    ///   telemetry (if started) and the idle-transaction monitor;
    /// - the shutdown senders for the heartbeat and, if started, telemetry.
    ///
    /// # Errors
    /// Fails if activating the worker in meta fails, or if spilling is enabled (and
    /// not under madsim) and the spill directory cannot be cleaned.
    ///
    /// # Panics
    /// Panics (via `unwrap`) if the advertise/listen or frontend RPC listener
    /// addresses do not parse, or if the batch runtime cannot be built.
    pub async fn init(opts: FrontendOpts) -> Result<(Self, Vec<JoinHandle<()>>, Vec<Sender<()>>)> {
        let config = load_config(&opts.config_path, &opts);
        info!("Starting frontend node");
        info!("> config: {:?}", config);
        info!(
            "> debug assertions: {}",
            if cfg!(debug_assertions) { "on" } else { "off" }
        );
        info!("> version: {} ({})", RW_VERSION, GIT_SHA);

        // Address advertised to the cluster; falls back to the listen address.
        let frontend_address: HostAddr = opts
            .advertise_addr
            .as_ref()
            .unwrap_or_else(|| {
                tracing::warn!("advertise addr is not specified, defaulting to listen_addr");
                &opts.listen_addr
            })
            .parse()
            .unwrap();
        info!("advertise addr is {}", frontend_address);

        // The internal RPC address reported to meta combines the advertised host with
        // the port of the frontend RPC listener.
        let rpc_addr: HostAddr = opts.frontend_rpc_listener_addr.parse().unwrap();
        let internal_rpc_host_addr = HostAddr {
            host: frontend_address.host.clone(),
            port: rpc_addr.port,
        };
        // Register this node with meta as a frontend worker.
        let (meta_client, system_params_reader) = MetaClient::register_new(
            opts.meta_addr,
            WorkerType::Frontend,
            &frontend_address,
            AddWorkerNodeProperty {
                internal_rpc_host_addr: internal_rpc_host_addr.to_string(),
                ..Default::default()
            },
            Arc::new(config.meta.clone()),
        )
        .await;

        let worker_id = meta_client.worker_id();
        info!("Assigned worker node id {}", worker_id);

        // Handles and senders collected from here on are returned to the caller.
        let (heartbeat_join_handle, heartbeat_shutdown_sender) = MetaClient::start_heartbeat_loop(
            meta_client.clone(),
            Duration::from_millis(config.server.heartbeat_interval_ms as u64),
        );
        let mut join_handles = vec![heartbeat_join_handle];
        let mut shutdown_senders = vec![heartbeat_shutdown_sender];

        let frontend_meta_client = Arc::new(FrontendMetaClientImpl(meta_client.clone()));
        let hummock_snapshot_manager =
            Arc::new(HummockSnapshotManager::new(frontend_meta_client.clone()));

        // The observer (below) gets the sender of this watch channel, and the catalog and
        // user-info writers get receivers. Presumably this lets a write wait until the
        // observer has applied the resulting update.
        let (catalog_updated_tx, catalog_updated_rx) = watch::channel(0);
        let catalog = Arc::new(RwLock::new(Catalog::default()));
        let catalog_writer = Arc::new(CatalogWriterImpl::new(
            meta_client.clone(),
            catalog_updated_rx.clone(),
            hummock_snapshot_manager.clone(),
        ));
        let catalog_reader = CatalogReader::new(catalog.clone());

        let worker_node_manager = Arc::new(WorkerNodeManager::new());

        let compute_client_pool = Arc::new(ComputeClientPool::new(
            config.batch_exchange_connection_pool_size(),
            config.batch.developer.compute_client_config.clone(),
        ));
        let frontend_client_pool = Arc::new(FrontendClientPool::new(
            1,
            config.batch.developer.frontend_client_config.clone(),
        ));
        let query_manager = QueryManager::new(
            worker_node_manager.clone(),
            compute_client_pool.clone(),
            catalog_reader.clone(),
            Arc::new(GLOBAL_DISTRIBUTED_QUERY_METRICS.clone()),
            config.batch.distributed_query_limit,
            config.batch.max_batch_queries_per_frontend_node,
        );

        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
        let user_info_reader = UserInfoReader::new(user_info_manager.clone());
        let user_info_writer = Arc::new(UserInfoWriterImpl::new(
            meta_client.clone(),
            catalog_updated_rx,
        ));

        let system_params_manager =
            Arc::new(LocalSystemParamsManager::new(system_params_reader.clone()));

        LocalSecretManager::init(
            opts.temp_secret_file_dir,
            meta_client.cluster_id().to_owned(),
            worker_id,
        );

        // Default session parameters; the same handle is given to the observer.
        let session_params = Arc::new(RwLock::new(SessionConfig::default()));
        // One sessions map shared by cursor metrics, the frontend RPC service and the
        // idle-transaction monitor below.
        let sessions_map: SessionMapRef = Arc::new(RwLock::new(HashMap::new()));
        let cursor_metrics = Arc::new(CursorMetrics::init(sessions_map.clone()));

        // The observer receives the shared catalog, user info, snapshot manager, system
        // params and session params, presumably to keep them in sync with meta.
        let frontend_observer_node = FrontendObserverNode::new(
            worker_node_manager.clone(),
            catalog,
            catalog_updated_tx,
            user_info_manager,
            hummock_snapshot_manager.clone(),
            system_params_manager.clone(),
            session_params.clone(),
            compute_client_pool.clone(),
        );
        let observer_manager =
            ObserverManager::new_with_meta_client(meta_client.clone(), frontend_observer_node)
                .await;
        let observer_join_handle = observer_manager.start().await;
        join_handles.push(observer_join_handle);

        // Activation happens only after the observer has been started.
        meta_client.activate(&frontend_address).await?;

        let frontend_metrics = Arc::new(GLOBAL_FRONTEND_METRICS.clone());
        let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
        let spill_metrics = Arc::new(GLOBAL_BATCH_SPILL_METRICS.clone());

        if config.server.metrics_level > MetricLevel::Disabled {
            MetricsManager::boot_metrics_service(opts.prometheus_listener_addr.clone());
        }

        let health_srv = HealthServiceImpl::new();
        let frontend_srv = FrontendServiceImpl::new(sessions_map.clone());
        let frontend_rpc_addr = opts.frontend_rpc_listener_addr.parse().unwrap();

        let telemetry_manager = TelemetryManager::new(
            Arc::new(meta_client.clone()),
            Arc::new(FrontendTelemetryCreator::new()),
        );

        // Telemetry runs only when both the config and the environment allow it.
        if config.server.telemetry_enabled && telemetry_env_enabled() {
            let (join_handle, shutdown_sender) = telemetry_manager.start().await;
            join_handles.push(join_handle);
            shutdown_senders.push(shutdown_sender);
        } else {
            tracing::info!("Telemetry didn't start due to config");
        }

        // Health-check + frontend RPC server. Its handle is not added to `join_handles`,
        // and a serve error panics inside the spawned task.
        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(HealthServer::new(health_srv))
                .add_service(FrontendServiceServer::new(frontend_srv))
                .serve(frontend_rpc_addr)
                .await
                .unwrap();
        });
        info!(
            "Health Check RPC Listener is set up on {}",
            opts.frontend_rpc_listener_addr.clone()
        );

        let creating_streaming_job_tracker =
            Arc::new(StreamingJobTracker::new(frontend_meta_client.clone()));

        // Separate multi-threaded runtime (threads named `rw-batch-local`). Tokio's default
        // worker count is kept unless the batch config sets one.
        let runtime = {
            let mut builder = Builder::new_multi_thread();
            if let Some(frontend_compute_runtime_worker_threads) =
                config.batch.frontend_compute_runtime_worker_threads
            {
                builder.worker_threads(frontend_compute_runtime_worker_threads);
            }
            builder
                .thread_name("rw-batch-local")
                .enable_all()
                .build()
                .unwrap()
        };
        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));

        // Idle-in-transaction monitor: every 10 seconds, run the idle-timeout check on
        // every session.
        // - Missed ticks are skipped rather than bursted.
        // - `reset()` pushes the first check one full period out; a fresh tokio interval
        //   would otherwise tick immediately.
        // - Errors from the per-session check are deliberately ignored.
        let sessions = sessions_map.clone();
        let join_handle = tokio::spawn(async move {
            let mut check_idle_txn_interval =
                tokio::time::interval(core::time::Duration::from_secs(10));
            check_idle_txn_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            check_idle_txn_interval.reset();
            loop {
                check_idle_txn_interval.tick().await;
                sessions.read().values().for_each(|session| {
                    let _ = session.check_idle_in_transaction_timeout();
                })
            }
        });
        join_handles.push(join_handle);

        // Clean the spill directory when spilling is enabled (skipped under madsim).
        #[cfg(not(madsim))]
        if config.batch.enable_spill {
            SpillOp::clean_spill_directory()
                .await
                .map_err(|err| anyhow!(err))?;
        }

        let total_memory_bytes = opts.frontend_total_memory_bytes;
        let heap_profiler =
            HeapProfiler::new(total_memory_bytes, config.server.heap_profiling.clone());
        heap_profiler.start();

        // The root batch memory context is limited to a fixed share of total memory.
        let batch_memory_limit = total_memory_bytes as f64 * FRONTEND_BATCH_MEMORY_PROPORTION;
        let mem_context = MemoryContext::root(
            frontend_metrics.batch_total_mem.clone(),
            batch_memory_limit as u64,
        );

        // The Prometheus client is optional. A missing endpoint or a failed
        // initialization leaves it `None`; a failure is logged, not returned.
        let prometheus_client = if let Some(ref endpoint) = opts.prometheus_endpoint {
            match PrometheusClient::try_from(endpoint.as_str()) {
                Ok(client) => {
                    info!("Prometheus client initialized with endpoint: {}", endpoint);
                    Some(client)
                }
                Err(e) => {
                    error!(
                        error = %e.as_report(),
                        "failed to initialize Prometheus client",
                    );
                    None
                }
            }
        } else {
            None
        };

        let prometheus_selector = opts.prometheus_selector.unwrap_or_default();

        info!(
            "Frontend total_memory: {} batch_memory: {}",
            convert(total_memory_bytes as _),
            convert(batch_memory_limit as _),
        );

        Ok((
            Self {
                catalog_reader,
                catalog_writer,
                user_info_reader,
                user_info_writer,
                worker_node_manager,
                meta_client: frontend_meta_client,
                query_manager,
                hummock_snapshot_manager,
                system_params_manager,
                session_params,
                server_addr: frontend_address,
                client_pool: compute_client_pool,
                frontend_client_pool,
                frontend_metrics,
                cursor_metrics,
                spill_metrics,
                sessions_map,
                batch_config: config.batch,
                frontend_config: config.frontend,
                meta_config: config.meta,
                streaming_config: config.streaming,
                serverless_backfill_controller_addr: opts.serverless_backfill_controller_addr,
                udf_config: config.udf,
                source_metrics,
                creating_streaming_job_tracker,
                compute_runtime,
                mem_context,
                prometheus_client,
                prometheus_selector,
            },
            join_handles,
            shutdown_senders,
        ))
    }

    /// Returns the catalog writer.
    ///
    /// This method is private. The `transaction::WriteGuard` argument is not used; its
    /// only effect is that callers must hold a write guard to obtain the writer,
    /// presumably as proof that writes are permitted in the current transaction.
    fn catalog_writer(&self, _guard: transaction::WriteGuard) -> &dyn CatalogWriter {
        &*self.catalog_writer
    }

    /// Returns read access to the local catalog.
    pub fn catalog_reader(&self) -> &CatalogReader {
        &self.catalog_reader
    }

    /// Returns the user-info writer.
    ///
    /// Like `catalog_writer`, this is private and requires a
    /// `transaction::WriteGuard`, which is otherwise unused.
    fn user_info_writer(&self, _guard: transaction::WriteGuard) -> &dyn UserInfoWriter {
        &*self.user_info_writer
    }

    /// Returns read access to the local user info.
    pub fn user_info_reader(&self) -> &UserInfoReader {
        &self.user_info_reader
    }

    /// Returns a shared handle to the worker-node manager.
    pub fn worker_node_manager_ref(&self) -> WorkerNodeManagerRef {
        self.worker_node_manager.clone()
    }

    /// Borrows the meta client.
    pub fn meta_client(&self) -> &dyn FrontendMetaClient {
        &*self.meta_client
    }

    /// Returns a shared handle to the meta client.
    pub fn meta_client_ref(&self) -> Arc<dyn FrontendMetaClient> {
        self.meta_client.clone()
    }

    pub fn query_manager(&self) -> &QueryManager {
        &self.query_manager
    }

    pub fn hummock_snapshot_manager(&self) -> &HummockSnapshotManagerRef {
        &self.hummock_snapshot_manager
    }

    pub fn system_params_manager(&self) -> &LocalSystemParamsManagerRef {
        &self.system_params_manager
    }

    /// Returns a copy of the current default session parameters.
    pub fn session_params_snapshot(&self) -> SessionConfig {
        self.session_params.read_recursive().clone()
    }

    /// Returns the serverless backfill controller address.
    pub fn sbc_address(&self) -> &String {
        &self.serverless_backfill_controller_addr
    }

    /// Returns the Prometheus client, if one was configured and initialized.
    pub fn prometheus_client(&self) -> Option<&PrometheusClient> {
        self.prometheus_client.as_ref()
    }

    /// Returns the configured Prometheus selector (empty if none was given).
    pub fn prometheus_selector(&self) -> &str {
        &self.prometheus_selector
    }

    /// Returns the address this frontend advertises.
    pub fn server_address(&self) -> &HostAddr {
        &self.server_addr
    }

    /// Returns a shared handle to the compute-node client pool.
    pub fn client_pool(&self) -> ComputeClientPoolRef {
        self.client_pool.clone()
    }

    /// Returns a shared handle to the frontend client pool.
    pub fn frontend_client_pool(&self) -> FrontendClientPoolRef {
        self.frontend_client_pool.clone()
    }

    pub fn batch_config(&self) -> &BatchConfig {
        &self.batch_config
    }

    pub fn frontend_config(&self) -> &FrontendConfig {
        &self.frontend_config
    }

    pub fn streaming_config(&self) -> &StreamingConfig {
        &self.streaming_config
    }

    pub fn udf_config(&self) -> &UdfConfig {
        &self.udf_config
    }

    pub fn source_metrics(&self) -> Arc<SourceMetrics> {
        self.source_metrics.clone()
    }

    pub fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
        self.spill_metrics.clone()
    }

    pub fn creating_streaming_job_tracker(&self) -> &StreamingJobTrackerRef {
        &self.creating_streaming_job_tracker
    }

    /// Returns the shared map of sessions on this node.
    pub fn sessions_map(&self) -> &SessionMapRef {
        &self.sessions_map
    }

    /// Returns a shared handle to the dedicated `rw-batch-local` runtime.
    pub fn compute_runtime(&self) -> Arc<BackgroundShutdownRuntime> {
        self.compute_runtime.clone()
    }

    /// Cancels the queries running in session `session_id`.
    ///
    /// Delegates to the free function `cancel_queries_in_session` with this node's
    /// sessions map. The returned flag is that function's result (presumably whether
    /// the session was found).
    pub fn cancel_queries_in_session(&self, session_id: SessionId) -> bool {
        cancel_queries_in_session(session_id, self.sessions_map.clone())
    }

    /// Cancels the streaming jobs being created by session `session_id`.
    ///
    /// Delegates to the free function `cancel_creating_jobs_in_session` with this
    /// node's sessions map. The returned flag is that function's result.
    pub fn cancel_creating_jobs_in_session(&self, session_id: SessionId) -> bool {
        cancel_creating_jobs_in_session(session_id, self.sessions_map.clone())
    }

    /// Returns a handle to the root batch memory context.
    pub fn mem_context(&self) -> MemoryContext {
        self.mem_context.clone()
    }
}
695
696#[derive(Clone)]
697pub struct AuthContext {
698 pub database: String,
699 pub user_name: String,
700 pub user_id: UserId,
701}
702
703impl AuthContext {
704 pub fn new(database: String, user_name: String, user_id: UserId) -> Self {
705 Self {
706 database,
707 user_name,
708 user_id,
709 }
710 }
711}
/// State of one client session on a frontend node.
///
/// Per-session state (auth context, session config, transaction state, cursors,
/// temporary sources, current execution context) lives here. Node-wide state is
/// reached through `env`.
pub struct SessionImpl {
    /// The frontend environment this session runs in.
    env: FrontendEnv,
    /// Current database, user name and user id; `update_database` rewrites the database.
    auth_context: Arc<RwLock<AuthContext>>,
    /// Authenticator for this session's user.
    user_authenticator: UserAuthenticator,
    /// This session's configuration, seeded from the `session_config` passed to `new`.
    config_map: Arc<RwLock<SessionConfig>>,

    /// Sending half of this session's notice channel.
    notice_tx: UnboundedSender<String>,
    /// Receiving half of the notice channel (presumably drained when notices are
    /// forwarded to the client).
    notice_rx: Mutex<UnboundedReceiver<String>>,

    /// Identifier of this session, as returned by `session_id()`.
    id: (i32, i32),

    /// Address of the connected client.
    peer_addr: AddressRef,

    /// Transaction state of this session.
    txn: Arc<Mutex<transaction::State>>,

    /// Shutdown sender of the currently running query, if any (presumably used to
    /// cancel it).
    current_query_cancel_flag: Mutex<Option<ShutdownSender>>,

    /// Weak reference to the current execution context. `running_sql` and
    /// `elapse_since_running_sql` read it while the context is still alive.
    exec_context: Mutex<Option<Weak<ExecContext>>>,

    /// Instant the session last became idle, if recorded; read by
    /// `elapse_since_last_idle_instant`.
    last_idle_instant: Arc<Mutex<Option<Instant>>>,

    /// Cursors of this session.
    cursor_manager: Arc<CursorManager>,

    /// Temporary sources created in this session.
    temporary_source_manager: Arc<Mutex<TemporarySourceManager>>,
}
752
753#[derive(Default, Clone)]
762pub struct TemporarySourceManager {
763 sources: HashMap<String, SourceCatalog>,
764}
765
766impl TemporarySourceManager {
767 pub fn new() -> Self {
768 Self {
769 sources: HashMap::new(),
770 }
771 }
772
773 pub fn create_source(&mut self, name: String, source: SourceCatalog) {
774 self.sources.insert(name, source);
775 }
776
777 pub fn drop_source(&mut self, name: &str) {
778 self.sources.remove(name);
779 }
780
781 pub fn get_source(&self, name: &str) -> Option<&SourceCatalog> {
782 self.sources.get(name)
783 }
784
785 pub fn keys(&self) -> Vec<String> {
786 self.sources.keys().cloned().collect()
787 }
788}
789
790#[derive(Error, Debug)]
791pub enum CheckRelationError {
792 #[error("{0}")]
793 Resolve(#[from] ResolveQualifiedNameError),
794 #[error("{0}")]
795 Catalog(#[from] CatalogError),
796}
797
798impl From<CheckRelationError> for RwError {
799 fn from(e: CheckRelationError) -> Self {
800 match e {
801 CheckRelationError::Resolve(e) => e.into(),
802 CheckRelationError::Catalog(e) => e.into(),
803 }
804 }
805}
806
807impl SessionImpl {
808 pub(crate) fn new(
809 env: FrontendEnv,
810 auth_context: AuthContext,
811 user_authenticator: UserAuthenticator,
812 id: SessionId,
813 peer_addr: AddressRef,
814 session_config: SessionConfig,
815 ) -> Self {
816 let cursor_metrics = env.cursor_metrics.clone();
817 let (notice_tx, notice_rx) = mpsc::unbounded_channel();
818
819 Self {
820 env,
821 auth_context: Arc::new(RwLock::new(auth_context)),
822 user_authenticator,
823 config_map: Arc::new(RwLock::new(session_config)),
824 id,
825 peer_addr,
826 txn: Default::default(),
827 current_query_cancel_flag: Mutex::new(None),
828 notice_tx,
829 notice_rx: Mutex::new(notice_rx),
830 exec_context: Mutex::new(None),
831 last_idle_instant: Default::default(),
832 cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
833 temporary_source_manager: Default::default(),
834 }
835 }
836
837 #[cfg(test)]
838 pub fn mock() -> Self {
839 let env = FrontendEnv::mock();
840 let (notice_tx, notice_rx) = mpsc::unbounded_channel();
841
842 Self {
843 env: FrontendEnv::mock(),
844 auth_context: Arc::new(RwLock::new(AuthContext::new(
845 DEFAULT_DATABASE_NAME.to_owned(),
846 DEFAULT_SUPER_USER.to_owned(),
847 DEFAULT_SUPER_USER_ID,
848 ))),
849 user_authenticator: UserAuthenticator::None,
850 config_map: Default::default(),
851 id: (0, 0),
853 txn: Default::default(),
854 current_query_cancel_flag: Mutex::new(None),
855 notice_tx,
856 notice_rx: Mutex::new(notice_rx),
857 exec_context: Mutex::new(None),
858 peer_addr: Address::Tcp(SocketAddr::new(
859 IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
860 8080,
861 ))
862 .into(),
863 last_idle_instant: Default::default(),
864 cursor_manager: Arc::new(CursorManager::new(env.cursor_metrics.clone())),
865 temporary_source_manager: Default::default(),
866 }
867 }
868
869 pub(crate) fn env(&self) -> &FrontendEnv {
870 &self.env
871 }
872
873 pub fn auth_context(&self) -> Arc<AuthContext> {
874 let ctx = self.auth_context.read();
875 Arc::new(ctx.clone())
876 }
877
878 pub fn database(&self) -> String {
879 self.auth_context.read().database.clone()
880 }
881
882 pub fn database_id(&self) -> DatabaseId {
883 let db_name = self.database();
884 self.env
885 .catalog_reader()
886 .read_guard()
887 .get_database_by_name(&db_name)
888 .map(|db| db.id())
889 .expect("session database not found")
890 }
891
892 pub fn user_name(&self) -> String {
893 self.auth_context.read().user_name.clone()
894 }
895
896 pub fn user_id(&self) -> UserId {
897 self.auth_context.read().user_id
898 }
899
900 pub fn update_database(&self, database: String) {
901 self.auth_context.write().database = database;
902 }
903
904 pub fn shared_config(&self) -> Arc<RwLock<SessionConfig>> {
905 Arc::clone(&self.config_map)
906 }
907
908 pub fn config(&self) -> RwLockReadGuard<'_, SessionConfig> {
909 self.config_map.read()
910 }
911
912 pub fn set_config(&self, key: &str, value: String) -> Result<String> {
913 self.config_map
914 .write()
915 .set(key, value, &mut ())
916 .map_err(Into::into)
917 }
918
919 pub fn reset_config(&self, key: &str) -> Result<String> {
920 self.config_map
921 .write()
922 .reset(key, &mut ())
923 .map_err(Into::into)
924 }
925
926 pub fn set_config_report(
927 &self,
928 key: &str,
929 value: Option<String>,
930 mut reporter: impl ConfigReporter,
931 ) -> Result<String> {
932 if let Some(value) = value {
933 self.config_map
934 .write()
935 .set(key, value, &mut reporter)
936 .map_err(Into::into)
937 } else {
938 self.config_map
939 .write()
940 .reset(key, &mut reporter)
941 .map_err(Into::into)
942 }
943 }
944
945 pub fn session_id(&self) -> SessionId {
946 self.id
947 }
948
949 pub fn running_sql(&self) -> Option<Arc<str>> {
950 self.exec_context
951 .lock()
952 .as_ref()
953 .and_then(|weak| weak.upgrade())
954 .map(|context| context.running_sql.clone())
955 }
956
957 pub fn get_cursor_manager(&self) -> Arc<CursorManager> {
958 self.cursor_manager.clone()
959 }
960
961 pub fn peer_addr(&self) -> &Address {
962 &self.peer_addr
963 }
964
965 pub fn elapse_since_running_sql(&self) -> Option<u128> {
966 self.exec_context
967 .lock()
968 .as_ref()
969 .and_then(|weak| weak.upgrade())
970 .map(|context| context.last_instant.elapsed().as_millis())
971 }
972
973 pub fn elapse_since_last_idle_instant(&self) -> Option<u128> {
974 self.last_idle_instant
975 .lock()
976 .as_ref()
977 .map(|x| x.elapsed().as_millis())
978 }
979
980 pub fn check_relation_name_duplicated(
981 &self,
982 name: ObjectName,
983 stmt_type: StatementType,
984 if_not_exists: bool,
985 ) -> std::result::Result<Either<(), RwPgResponse>, CheckRelationError> {
986 let db_name = &self.database();
987 let catalog_reader = self.env().catalog_reader().read_guard();
988 let (schema_name, relation_name) = {
989 let (schema_name, relation_name) =
990 Binder::resolve_schema_qualified_name(db_name, &name)?;
991 let search_path = self.config().search_path();
992 let user_name = &self.user_name();
993 let schema_name = match schema_name {
994 Some(schema_name) => schema_name,
995 None => catalog_reader
996 .first_valid_schema(db_name, &search_path, user_name)?
997 .name(),
998 };
999 (schema_name, relation_name)
1000 };
1001 match catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &relation_name) {
1002 Err(CatalogError::Duplicated(_, name, is_creating)) if if_not_exists => {
1003 if !is_creating {
1008 Ok(Either::Right(
1009 PgResponse::builder(stmt_type)
1010 .notice(format!("relation \"{}\" already exists, skipping", name))
1011 .into(),
1012 ))
1013 } else if stmt_type == StatementType::CREATE_SUBSCRIPTION {
1014 Ok(Either::Right(
1017 PgResponse::builder(stmt_type)
1018 .notice(format!(
1019 "relation \"{}\" already exists but still creating, skipping",
1020 name
1021 ))
1022 .into(),
1023 ))
1024 } else {
1025 Ok(Either::Left(()))
1026 }
1027 }
1028 Err(e) => Err(e.into()),
1029 Ok(_) => Ok(Either::Left(())),
1030 }
1031 }
1032
1033 pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> {
1034 let db_name = &self.database();
1035 let catalog_reader = self.env().catalog_reader().read_guard();
1036 let (schema_name, secret_name) = {
1037 let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
1038 let search_path = self.config().search_path();
1039 let user_name = &self.user_name();
1040 let schema_name = match schema_name {
1041 Some(schema_name) => schema_name,
1042 None => catalog_reader
1043 .first_valid_schema(db_name, &search_path, user_name)?
1044 .name(),
1045 };
1046 (schema_name, secret_name)
1047 };
1048 catalog_reader
1049 .check_secret_name_duplicated(db_name, &schema_name, &secret_name)
1050 .map_err(RwError::from)
1051 }
1052
1053 pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> {
1054 let db_name = &self.database();
1055 let catalog_reader = self.env().catalog_reader().read_guard();
1056 let (schema_name, connection_name) = {
1057 let (schema_name, connection_name) =
1058 Binder::resolve_schema_qualified_name(db_name, &name)?;
1059 let search_path = self.config().search_path();
1060 let user_name = &self.user_name();
1061 let schema_name = match schema_name {
1062 Some(schema_name) => schema_name,
1063 None => catalog_reader
1064 .first_valid_schema(db_name, &search_path, user_name)?
1065 .name(),
1066 };
1067 (schema_name, connection_name)
1068 };
1069 catalog_reader
1070 .check_connection_name_duplicated(db_name, &schema_name, &connection_name)
1071 .map_err(RwError::from)
1072 }
1073
1074 pub fn check_function_name_duplicated(
1075 &self,
1076 stmt_type: StatementType,
1077 name: ObjectName,
1078 arg_types: &[DataType],
1079 if_not_exists: bool,
1080 ) -> Result<Either<(), RwPgResponse>> {
1081 let db_name = &self.database();
1082 let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
1083 let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?;
1084
1085 let catalog_reader = self.env().catalog_reader().read_guard();
1086 if catalog_reader
1087 .get_schema_by_id(&database_id, &schema_id)?
1088 .get_function_by_name_args(&function_name, arg_types)
1089 .is_some()
1090 {
1091 let full_name = format!(
1092 "{function_name}({})",
1093 arg_types.iter().map(|t| t.to_string()).join(",")
1094 );
1095 if if_not_exists {
1096 Ok(Either::Right(
1097 PgResponse::builder(stmt_type)
1098 .notice(format!(
1099 "function \"{}\" already exists, skipping",
1100 full_name
1101 ))
1102 .into(),
1103 ))
1104 } else {
1105 Err(CatalogError::duplicated("function", full_name).into())
1106 }
1107 } else {
1108 Ok(Either::Left(()))
1109 }
1110 }
1111
1112 pub fn get_database_and_schema_id_for_create(
1114 &self,
1115 schema_name: Option<String>,
1116 ) -> Result<(DatabaseId, SchemaId)> {
1117 let db_name = &self.database();
1118
1119 let search_path = self.config().search_path();
1120 let user_name = &self.user_name();
1121
1122 let catalog_reader = self.env().catalog_reader().read_guard();
1123 let schema = match schema_name {
1124 Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?,
1125 None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?,
1126 };
1127 let schema_name = schema.name();
1128
1129 check_schema_writable(&schema_name)?;
1130 self.check_privileges(&[ObjectCheckItem::new(
1131 schema.owner(),
1132 AclMode::Create,
1133 schema_name,
1134 Object::SchemaId(schema.id()),
1135 )])?;
1136
1137 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1138 Ok((db_id, schema.id()))
1139 }
1140
1141 pub fn get_connection_by_name(
1142 &self,
1143 schema_name: Option<String>,
1144 connection_name: &str,
1145 ) -> Result<Arc<ConnectionCatalog>> {
1146 let db_name = &self.database();
1147 let search_path = self.config().search_path();
1148 let user_name = &self.user_name();
1149
1150 let catalog_reader = self.env().catalog_reader().read_guard();
1151 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1152 let (connection, _) =
1153 catalog_reader.get_connection_by_name(db_name, schema_path, connection_name)?;
1154
1155 self.check_privileges(&[ObjectCheckItem::new(
1156 connection.owner(),
1157 AclMode::Usage,
1158 connection.name.clone(),
1159 Object::ConnectionId(connection.id),
1160 )])?;
1161
1162 Ok(connection.clone())
1163 }
1164
1165 pub fn get_subscription_by_schema_id_name(
1166 &self,
1167 schema_id: SchemaId,
1168 subscription_name: &str,
1169 ) -> Result<Arc<SubscriptionCatalog>> {
1170 let db_name = &self.database();
1171
1172 let catalog_reader = self.env().catalog_reader().read_guard();
1173 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1174 let schema = catalog_reader.get_schema_by_id(&db_id, &schema_id)?;
1175 let subscription = schema
1176 .get_subscription_by_name(subscription_name)
1177 .ok_or_else(|| {
1178 RwError::from(ErrorCode::ItemNotFound(format!(
1179 "subscription {} not found",
1180 subscription_name
1181 )))
1182 })?;
1183 Ok(subscription.clone())
1184 }
1185
1186 pub fn get_subscription_by_name(
1187 &self,
1188 schema_name: Option<String>,
1189 subscription_name: &str,
1190 ) -> Result<Arc<SubscriptionCatalog>> {
1191 let db_name = &self.database();
1192 let search_path = self.config().search_path();
1193 let user_name = &self.user_name();
1194
1195 let catalog_reader = self.env().catalog_reader().read_guard();
1196 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1197 let (subscription, _) =
1198 catalog_reader.get_subscription_by_name(db_name, schema_path, subscription_name)?;
1199 Ok(subscription.clone())
1200 }
1201
1202 pub fn get_table_by_id(&self, table_id: &TableId) -> Result<Arc<TableCatalog>> {
1203 let catalog_reader = self.env().catalog_reader().read_guard();
1204 Ok(catalog_reader.get_any_table_by_id(table_id)?.clone())
1205 }
1206
1207 pub fn get_table_by_name(
1208 &self,
1209 table_name: &str,
1210 db_id: u32,
1211 schema_id: u32,
1212 ) -> Result<Arc<TableCatalog>> {
1213 let catalog_reader = self.env().catalog_reader().read_guard();
1214 let table = catalog_reader
1215 .get_schema_by_id(&DatabaseId::from(db_id), &SchemaId::from(schema_id))?
1216 .get_created_table_by_name(table_name)
1217 .ok_or_else(|| {
1218 Error::new(
1219 ErrorKind::InvalidInput,
1220 format!("table \"{}\" does not exist", table_name),
1221 )
1222 })?;
1223
1224 self.check_privileges(&[ObjectCheckItem::new(
1225 table.owner(),
1226 AclMode::Select,
1227 table_name.to_owned(),
1228 Object::TableId(table.id.table_id()),
1229 )])?;
1230
1231 Ok(table.clone())
1232 }
1233
1234 pub fn get_secret_by_name(
1235 &self,
1236 schema_name: Option<String>,
1237 secret_name: &str,
1238 ) -> Result<Arc<SecretCatalog>> {
1239 let db_name = &self.database();
1240 let search_path = self.config().search_path();
1241 let user_name = &self.user_name();
1242
1243 let catalog_reader = self.env().catalog_reader().read_guard();
1244 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1245 let (secret, _) = catalog_reader.get_secret_by_name(db_name, schema_path, secret_name)?;
1246
1247 self.check_privileges(&[ObjectCheckItem::new(
1248 secret.owner(),
1249 AclMode::Usage,
1250 secret.name.clone(),
1251 Object::SecretId(secret.id.secret_id()),
1252 )])?;
1253
1254 Ok(secret.clone())
1255 }
1256
1257 pub fn list_change_log_epochs(
1258 &self,
1259 table_id: u32,
1260 min_epoch: u64,
1261 max_count: u32,
1262 ) -> Result<Vec<u64>> {
1263 Ok(self
1264 .env
1265 .hummock_snapshot_manager()
1266 .acquire()
1267 .list_change_log_epochs(table_id, min_epoch, max_count))
1268 }
1269
1270 pub fn clear_cancel_query_flag(&self) {
1271 let mut flag = self.current_query_cancel_flag.lock();
1272 *flag = None;
1273 }
1274
1275 pub fn reset_cancel_query_flag(&self) -> ShutdownToken {
1276 let mut flag = self.current_query_cancel_flag.lock();
1277 let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
1278 *flag = Some(shutdown_tx);
1279 shutdown_rx
1280 }
1281
1282 pub fn cancel_current_query(&self) {
1283 let mut flag_guard = self.current_query_cancel_flag.lock();
1284 if let Some(sender) = flag_guard.take() {
1285 info!("Trying to cancel query in local mode.");
1286 sender.cancel();
1288 info!("Cancel query request sent.");
1289 } else {
1290 info!("Trying to cancel query in distributed mode.");
1291 self.env.query_manager().cancel_queries_in_session(self.id)
1292 }
1293 }
1294
1295 pub fn cancel_current_creating_job(&self) {
1296 self.env.creating_streaming_job_tracker.abort_jobs(self.id);
1297 }
1298
1299 pub async fn run_statement(
1302 self: Arc<Self>,
1303 sql: Arc<str>,
1304 formats: Vec<Format>,
1305 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1306 let mut stmts = Parser::parse_sql(&sql)?;
1308 if stmts.is_empty() {
1309 return Ok(PgResponse::empty_result(
1310 pgwire::pg_response::StatementType::EMPTY,
1311 ));
1312 }
1313 if stmts.len() > 1 {
1314 return Ok(
1315 PgResponse::builder(pgwire::pg_response::StatementType::EMPTY)
1316 .notice("cannot insert multiple commands into statement")
1317 .into(),
1318 );
1319 }
1320 let stmt = stmts.swap_remove(0);
1321 let rsp = handle(self, stmt, sql.clone(), formats).await?;
1322 Ok(rsp)
1323 }
1324
1325 pub fn notice_to_user(&self, str: impl Into<String>) {
1326 let notice = str.into();
1327 tracing::trace!(notice, "notice to user");
1328 self.notice_tx
1329 .send(notice)
1330 .expect("notice channel should not be closed");
1331 }
1332
1333 pub fn is_barrier_read(&self) -> bool {
1334 match self.config().visibility_mode() {
1335 VisibilityMode::Default => self.env.batch_config.enable_barrier_read,
1336 VisibilityMode::All => true,
1337 VisibilityMode::Checkpoint => false,
1338 }
1339 }
1340
1341 pub fn statement_timeout(&self) -> Duration {
1342 if self.config().statement_timeout() == 0 {
1343 Duration::from_secs(self.env.batch_config.statement_timeout_in_sec as u64)
1344 } else {
1345 Duration::from_secs(self.config().statement_timeout() as u64)
1346 }
1347 }
1348
1349 pub fn create_temporary_source(&self, source: SourceCatalog) {
1350 self.temporary_source_manager
1351 .lock()
1352 .create_source(source.name.clone(), source);
1353 }
1354
1355 pub fn get_temporary_source(&self, name: &str) -> Option<SourceCatalog> {
1356 self.temporary_source_manager
1357 .lock()
1358 .get_source(name)
1359 .cloned()
1360 }
1361
1362 pub fn drop_temporary_source(&self, name: &str) {
1363 self.temporary_source_manager.lock().drop_source(name);
1364 }
1365
1366 pub fn temporary_source_manager(&self) -> TemporarySourceManager {
1367 self.temporary_source_manager.lock().clone()
1368 }
1369
1370 pub async fn check_cluster_limits(&self) -> Result<()> {
1371 if self.config().bypass_cluster_limits() {
1372 return Ok(());
1373 }
1374
1375 let gen_message = |ActorCountPerParallelism {
1376 worker_id_to_actor_count,
1377 hard_limit,
1378 soft_limit,
1379 }: ActorCountPerParallelism,
1380 exceed_hard_limit: bool|
1381 -> String {
1382 let (limit_type, action) = if exceed_hard_limit {
1383 ("critical", "Scale the cluster immediately to proceed.")
1384 } else {
1385 (
1386 "recommended",
1387 "Consider scaling the cluster for optimal performance.",
1388 )
1389 };
1390 format!(
1391 r#"Actor count per parallelism exceeds the {limit_type} limit.
1392
1393Depending on your workload, this may overload the cluster and cause performance/stability issues. {action}
1394
1395HINT:
1396- For best practices on managing streaming jobs: https://docs.risingwave.com/operate/manage-a-large-number-of-streaming-jobs
1397- To bypass the check (if the cluster load is acceptable): `[ALTER SYSTEM] SET bypass_cluster_limits TO true`.
1398 See https://docs.risingwave.com/operate/view-configure-runtime-parameters#how-to-configure-runtime-parameters
1399- Contact us via slack or https://risingwave.com/contact-us/ for further enquiry.
1400
1401DETAILS:
1402- hard limit: {hard_limit}
1403- soft limit: {soft_limit}
1404- worker_id_to_actor_count: {worker_id_to_actor_count:?}"#,
1405 )
1406 };
1407
1408 let limits = self.env().meta_client().get_cluster_limits().await?;
1409 for limit in limits {
1410 match limit {
1411 cluster_limit::ClusterLimit::ActorCount(l) => {
1412 if l.exceed_hard_limit() {
1413 return Err(RwError::from(ErrorCode::ProtocolError(gen_message(
1414 l, true,
1415 ))));
1416 } else if l.exceed_soft_limit() {
1417 self.notice_to_user(gen_message(l, false));
1418 }
1419 }
1420 }
1421 }
1422 Ok(())
1423 }
1424}
1425
1426pub static SESSION_MANAGER: std::sync::OnceLock<Arc<SessionManagerImpl>> =
1427 std::sync::OnceLock::new();
1428
1429pub struct SessionManagerImpl {
1430 env: FrontendEnv,
1431 _join_handles: Vec<JoinHandle<()>>,
1432 _shutdown_senders: Vec<Sender<()>>,
1433 number: AtomicI32,
1434}
1435
1436impl SessionManager for SessionManagerImpl {
1437 type Session = SessionImpl;
1438
1439 fn create_dummy_session(
1440 &self,
1441 database_id: u32,
1442 user_id: u32,
1443 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1444 let dummy_addr = Address::Tcp(SocketAddr::new(
1445 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
1446 5691, ));
1448 let user_reader = self.env.user_info_reader();
1449 let reader = user_reader.read_guard();
1450 if let Some(user_name) = reader.get_user_name_by_id(user_id) {
1451 self.connect_inner(database_id, user_name.as_str(), Arc::new(dummy_addr))
1452 } else {
1453 Err(Box::new(Error::new(
1454 ErrorKind::InvalidInput,
1455 format!("Role id {} does not exist", user_id),
1456 )))
1457 }
1458 }
1459
1460 fn connect(
1461 &self,
1462 database: &str,
1463 user_name: &str,
1464 peer_addr: AddressRef,
1465 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1466 let catalog_reader = self.env.catalog_reader();
1467 let reader = catalog_reader.read_guard();
1468 let database_id = reader
1469 .get_database_by_name(database)
1470 .map_err(|_| {
1471 Box::new(Error::new(
1472 ErrorKind::InvalidInput,
1473 format!("database \"{}\" does not exist", database),
1474 ))
1475 })?
1476 .id();
1477
1478 self.connect_inner(database_id, user_name, peer_addr)
1479 }
1480
1481 fn cancel_queries_in_session(&self, session_id: SessionId) {
1483 self.env.cancel_queries_in_session(session_id);
1484 }
1485
1486 fn cancel_creating_jobs_in_session(&self, session_id: SessionId) {
1487 self.env.cancel_creating_jobs_in_session(session_id);
1488 }
1489
1490 fn end_session(&self, session: &Self::Session) {
1491 self.delete_session(&session.session_id());
1492 }
1493
1494 async fn shutdown(&self) {
1495 self.env.sessions_map().write().clear();
1497 self.env.meta_client().try_unregister().await;
1499 }
1500}
1501
1502impl SessionManagerImpl {
1503 pub async fn new(opts: FrontendOpts) -> Result<Self> {
1504 let (env, join_handles, shutdown_senders) = FrontendEnv::init(opts).await?;
1506 Ok(Self {
1507 env,
1508 _join_handles: join_handles,
1509 _shutdown_senders: shutdown_senders,
1510 number: AtomicI32::new(0),
1511 })
1512 }
1513
1514 pub(crate) fn env(&self) -> &FrontendEnv {
1515 &self.env
1516 }
1517
1518 fn insert_session(&self, session: Arc<SessionImpl>) {
1519 let active_sessions = {
1520 let mut write_guard = self.env.sessions_map.write();
1521 write_guard.insert(session.id(), session);
1522 write_guard.len()
1523 };
1524 self.env
1525 .frontend_metrics
1526 .active_sessions
1527 .set(active_sessions as i64);
1528 }
1529
1530 fn delete_session(&self, session_id: &SessionId) {
1531 let active_sessions = {
1532 let mut write_guard = self.env.sessions_map.write();
1533 write_guard.remove(session_id);
1534 write_guard.len()
1535 };
1536 self.env
1537 .frontend_metrics
1538 .active_sessions
1539 .set(active_sessions as i64);
1540 }
1541
1542 fn connect_inner(
1543 &self,
1544 database_id: u32,
1545 user_name: &str,
1546 peer_addr: AddressRef,
1547 ) -> std::result::Result<Arc<SessionImpl>, BoxedError> {
1548 let catalog_reader = self.env.catalog_reader();
1549 let reader = catalog_reader.read_guard();
1550 let (database_name, database_owner) = {
1551 let db = reader.get_database_by_id(&database_id).map_err(|_| {
1552 Box::new(Error::new(
1553 ErrorKind::InvalidInput,
1554 format!("database \"{}\" does not exist", database_id),
1555 ))
1556 })?;
1557 (db.name(), db.owner())
1558 };
1559
1560 let user_reader = self.env.user_info_reader();
1561 let reader = user_reader.read_guard();
1562 if let Some(user) = reader.get_user_by_name(user_name) {
1563 if !user.can_login {
1564 return Err(Box::new(Error::new(
1565 ErrorKind::InvalidInput,
1566 format!("User {} is not allowed to login", user_name),
1567 )));
1568 }
1569 let has_privilege =
1570 user.has_privilege(&Object::DatabaseId(database_id), AclMode::Connect);
1571 if !user.is_super && database_owner != user.id && !has_privilege {
1572 return Err(Box::new(Error::new(
1573 ErrorKind::PermissionDenied,
1574 "User does not have CONNECT privilege.",
1575 )));
1576 }
1577 let user_authenticator = match &user.auth_info {
1578 None => UserAuthenticator::None,
1579 Some(auth_info) => {
1580 if auth_info.encryption_type == EncryptionType::Plaintext as i32 {
1581 UserAuthenticator::ClearText(auth_info.encrypted_value.clone())
1582 } else if auth_info.encryption_type == EncryptionType::Md5 as i32 {
1583 let mut salt = [0; 4];
1584 let mut rng = rand::rng();
1585 rng.fill_bytes(&mut salt);
1586 UserAuthenticator::Md5WithSalt {
1587 encrypted_password: md5_hash_with_salt(
1588 &auth_info.encrypted_value,
1589 &salt,
1590 ),
1591 salt,
1592 }
1593 } else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
1594 UserAuthenticator::OAuth(auth_info.metadata.clone())
1595 } else {
1596 return Err(Box::new(Error::new(
1597 ErrorKind::Unsupported,
1598 format!("Unsupported auth type: {}", auth_info.encryption_type),
1599 )));
1600 }
1601 }
1602 };
1603
1604 let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
1606 let id = (secret_key, secret_key);
1608 let session_config = self.env.session_params_snapshot();
1610
1611 let session_impl: Arc<SessionImpl> = SessionImpl::new(
1612 self.env.clone(),
1613 AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id),
1614 user_authenticator,
1615 id,
1616 peer_addr,
1617 session_config,
1618 )
1619 .into();
1620 self.insert_session(session_impl.clone());
1621
1622 Ok(session_impl)
1623 } else {
1624 Err(Box::new(Error::new(
1625 ErrorKind::InvalidInput,
1626 format!("Role {} does not exist", user_name),
1627 )))
1628 }
1629 }
1630}
1631
1632impl Session for SessionImpl {
1633 type Portal = Portal;
1634 type PreparedStatement = PrepareStatement;
1635 type ValuesStream = PgResponseStream;
1636
1637 async fn run_one_query(
1640 self: Arc<Self>,
1641 stmt: Statement,
1642 format: Format,
1643 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1644 let string = stmt.to_string();
1645 let sql_str = string.as_str();
1646 let sql: Arc<str> = Arc::from(sql_str);
1647 drop(string);
1649 let rsp = handle(self, stmt, sql, vec![format]).await?;
1650 Ok(rsp)
1651 }
1652
1653 fn user_authenticator(&self) -> &UserAuthenticator {
1654 &self.user_authenticator
1655 }
1656
1657 fn id(&self) -> SessionId {
1658 self.id
1659 }
1660
1661 async fn parse(
1662 self: Arc<Self>,
1663 statement: Option<Statement>,
1664 params_types: Vec<Option<DataType>>,
1665 ) -> std::result::Result<PrepareStatement, BoxedError> {
1666 Ok(if let Some(statement) = statement {
1667 handle_parse(self, statement, params_types).await?
1668 } else {
1669 PrepareStatement::Empty
1670 })
1671 }
1672
1673 fn bind(
1674 self: Arc<Self>,
1675 prepare_statement: PrepareStatement,
1676 params: Vec<Option<Bytes>>,
1677 param_formats: Vec<Format>,
1678 result_formats: Vec<Format>,
1679 ) -> std::result::Result<Portal, BoxedError> {
1680 Ok(handle_bind(
1681 prepare_statement,
1682 params,
1683 param_formats,
1684 result_formats,
1685 )?)
1686 }
1687
1688 async fn execute(
1689 self: Arc<Self>,
1690 portal: Portal,
1691 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1692 let rsp = handle_execute(self, portal).await?;
1693 Ok(rsp)
1694 }
1695
1696 fn describe_statement(
1697 self: Arc<Self>,
1698 prepare_statement: PrepareStatement,
1699 ) -> std::result::Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
1700 Ok(match prepare_statement {
1701 PrepareStatement::Empty => (vec![], vec![]),
1702 PrepareStatement::Prepared(prepare_statement) => (
1703 prepare_statement.bound_result.param_types,
1704 infer(
1705 Some(prepare_statement.bound_result.bound),
1706 prepare_statement.statement,
1707 )?,
1708 ),
1709 PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?),
1710 })
1711 }
1712
1713 fn describe_portal(
1714 self: Arc<Self>,
1715 portal: Portal,
1716 ) -> std::result::Result<Vec<PgFieldDescriptor>, BoxedError> {
1717 match portal {
1718 Portal::Empty => Ok(vec![]),
1719 Portal::Portal(portal) => {
1720 let mut columns = infer(Some(portal.bound_result.bound), portal.statement)?;
1721 let formats = FormatIterator::new(&portal.result_formats, columns.len())?;
1722 columns.iter_mut().zip_eq_fast(formats).for_each(|(c, f)| {
1723 if f == Format::Binary {
1724 c.set_to_binary()
1725 }
1726 });
1727 Ok(columns)
1728 }
1729 Portal::PureStatement(statement) => Ok(infer(None, statement)?),
1730 }
1731 }
1732
1733 fn get_config(&self, key: &str) -> std::result::Result<String, BoxedError> {
1734 self.config().get(key).map_err(Into::into)
1735 }
1736
1737 fn set_config(&self, key: &str, value: String) -> std::result::Result<String, BoxedError> {
1738 Self::set_config(self, key, value).map_err(Into::into)
1739 }
1740
1741 async fn next_notice(self: &Arc<Self>) -> String {
1742 std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
1743 .await
1744 .expect("notice channel should not be closed")
1745 }
1746
1747 fn transaction_status(&self) -> TransactionStatus {
1748 match &*self.txn.lock() {
1749 transaction::State::Initial | transaction::State::Implicit(_) => {
1750 TransactionStatus::Idle
1751 }
1752 transaction::State::Explicit(_) => TransactionStatus::InTransaction,
1753 }
1755 }
1756
1757 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
1759 let exec_context = Arc::new(ExecContext {
1760 running_sql: sql,
1761 last_instant: Instant::now(),
1762 last_idle_instant: self.last_idle_instant.clone(),
1763 });
1764 *self.exec_context.lock() = Some(Arc::downgrade(&exec_context));
1765 *self.last_idle_instant.lock() = None;
1767 ExecContextGuard::new(exec_context)
1768 }
1769
1770 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
1773 if matches!(self.transaction_status(), TransactionStatus::InTransaction) {
1775 let idle_in_transaction_session_timeout =
1776 self.config().idle_in_transaction_session_timeout() as u128;
1777 if idle_in_transaction_session_timeout != 0 {
1779 let guard = self.exec_context.lock();
1781 if guard.as_ref().and_then(|weak| weak.upgrade()).is_none() {
1783 if let Some(elapse_since_last_idle_instant) =
1785 self.elapse_since_last_idle_instant()
1786 && elapse_since_last_idle_instant > idle_in_transaction_session_timeout
1787 {
1788 return Err(PsqlError::IdleInTxnTimeout);
1789 }
1790 }
1791 }
1792 }
1793 Ok(())
1794 }
1795}
1796
1797fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
1799 match stmt {
1800 Statement::Query(_)
1801 | Statement::Insert { .. }
1802 | Statement::Delete { .. }
1803 | Statement::Update { .. }
1804 | Statement::FetchCursor { .. } => Ok(bound
1805 .unwrap()
1806 .output_fields()
1807 .iter()
1808 .map(to_pg_field)
1809 .collect()),
1810 Statement::ShowObjects {
1811 object: show_object,
1812 ..
1813 } => Ok(infer_show_object(&show_object)),
1814 Statement::ShowCreateObject { .. } => Ok(infer_show_create_object()),
1815 Statement::ShowTransactionIsolationLevel => {
1816 let name = "transaction_isolation";
1817 Ok(infer_show_variable(name))
1818 }
1819 Statement::ShowVariable { variable } => {
1820 let name = &variable[0].real_value().to_lowercase();
1821 Ok(infer_show_variable(name))
1822 }
1823 Statement::Describe { name: _, kind } => Ok(infer_describe(&kind)),
1824 Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new(
1825 "QUERY PLAN".to_owned(),
1826 DataType::Varchar.to_oid(),
1827 DataType::Varchar.type_len(),
1828 )]),
1829 _ => Ok(vec![]),
1830 }
1831}
1832
/// A worker id paired with a process id, textually written as `worker_id:process_id`.
pub struct WorkerProcessId {
    /// Id of the worker node.
    pub worker_id: u32,
    /// Process id associated with that worker.
    pub process_id: i32,
}
1837
1838impl WorkerProcessId {
1839 pub fn new(worker_id: u32, process_id: i32) -> Self {
1840 Self {
1841 worker_id,
1842 process_id,
1843 }
1844 }
1845}
1846
1847impl Display for WorkerProcessId {
1848 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1849 write!(f, "{}:{}", self.worker_id, self.process_id)
1850 }
1851}
1852
1853impl TryFrom<String> for WorkerProcessId {
1854 type Error = String;
1855
1856 fn try_from(worker_process_id: String) -> std::result::Result<Self, Self::Error> {
1857 const INVALID: &str = "invalid WorkerProcessId";
1858 let splits: Vec<&str> = worker_process_id.split(":").collect();
1859 if splits.len() != 2 {
1860 return Err(INVALID.to_owned());
1861 }
1862 let Ok(worker_id) = splits[0].parse::<u32>() else {
1863 return Err(INVALID.to_owned());
1864 };
1865 let Ok(process_id) = splits[1].parse::<i32>() else {
1866 return Err(INVALID.to_owned());
1867 };
1868 Ok(WorkerProcessId::new(worker_id, process_id))
1869 }
1870}
1871
1872pub fn cancel_queries_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1873 let guard = sessions_map.read();
1874 if let Some(session) = guard.get(&session_id) {
1875 session.cancel_current_query();
1876 true
1877 } else {
1878 info!("Current session finished, ignoring cancel query request");
1879 false
1880 }
1881}
1882
1883pub fn cancel_creating_jobs_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1884 let guard = sessions_map.read();
1885 if let Some(session) = guard.get(&session_id) {
1886 session.cancel_current_creating_job();
1887 true
1888 } else {
1889 info!("Current session finished, ignoring cancel creating request");
1890 false
1891 }
1892}