1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::io::{Error, ErrorKind};
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::atomic::{AtomicI32, Ordering};
20use std::sync::{Arc, Weak};
21use std::time::{Duration, Instant};
22
23use anyhow::anyhow;
24use bytes::Bytes;
25use either::Either;
26use itertools::Itertools;
27use parking_lot::{Mutex, RwLock, RwLockReadGuard};
28use pgwire::error::{PsqlError, PsqlResult};
29use pgwire::net::{Address, AddressRef};
30use pgwire::pg_field_descriptor::PgFieldDescriptor;
31use pgwire::pg_message::TransactionStatus;
32use pgwire::pg_response::{PgResponse, StatementType};
33use pgwire::pg_server::{
34 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
35 UserAuthenticator,
36};
37use pgwire::types::{Format, FormatIterator};
38use rand::RngCore;
39use risingwave_batch::monitor::{BatchSpillMetrics, GLOBAL_BATCH_SPILL_METRICS};
40use risingwave_batch::spill::spill_op::SpillOp;
41use risingwave_batch::task::{ShutdownSender, ShutdownToken};
42use risingwave_batch::worker_manager::worker_node_manager::{
43 WorkerNodeManager, WorkerNodeManagerRef,
44};
45use risingwave_common::acl::AclMode;
46#[cfg(test)]
47use risingwave_common::catalog::{
48 DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
49};
50use risingwave_common::config::{
51 BatchConfig, FrontendConfig, MetaConfig, MetricLevel, StreamingConfig, UdfConfig, load_config,
52};
53use risingwave_common::memory::MemoryContext;
54use risingwave_common::secret::LocalSecretManager;
55use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode};
56use risingwave_common::system_param::local_manager::{
57 LocalSystemParamsManager, LocalSystemParamsManagerRef,
58};
59use risingwave_common::telemetry::manager::TelemetryManager;
60use risingwave_common::telemetry::telemetry_env_enabled;
61use risingwave_common::types::DataType;
62use risingwave_common::util::addr::HostAddr;
63use risingwave_common::util::cluster_limit;
64use risingwave_common::util::cluster_limit::ActorCountPerParallelism;
65use risingwave_common::util::iter_util::ZipEqFast;
66use risingwave_common::util::pretty_bytes::convert;
67use risingwave_common::util::runtime::BackgroundShutdownRuntime;
68use risingwave_common::{GIT_SHA, RW_VERSION};
69use risingwave_common_heap_profiling::HeapProfiler;
70use risingwave_common_service::{MetricsManager, ObserverManager};
71use risingwave_connector::source::monitor::{GLOBAL_SOURCE_METRICS, SourceMetrics};
72use risingwave_pb::common::WorkerType;
73use risingwave_pb::common::worker_node::Property as AddWorkerNodeProperty;
74use risingwave_pb::frontend_service::frontend_service_server::FrontendServiceServer;
75use risingwave_pb::health::health_server::HealthServer;
76use risingwave_pb::user::auth_info::EncryptionType;
77use risingwave_pb::user::grant_privilege::Object;
78use risingwave_rpc_client::{
79 ComputeClientPool, ComputeClientPoolRef, FrontendClientPool, FrontendClientPoolRef, MetaClient,
80};
81use risingwave_sqlparser::ast::{ObjectName, Statement};
82use risingwave_sqlparser::parser::Parser;
83use thiserror::Error;
84use tokio::runtime::Builder;
85use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
86use tokio::sync::oneshot::Sender;
87use tokio::sync::watch;
88use tokio::task::JoinHandle;
89use tracing::info;
90use tracing::log::error;
91
92use self::cursor_manager::CursorManager;
93use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError};
94use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl};
95use crate::catalog::connection_catalog::ConnectionCatalog;
96use crate::catalog::root_catalog::{Catalog, SchemaPath};
97use crate::catalog::secret_catalog::SecretCatalog;
98use crate::catalog::source_catalog::SourceCatalog;
99use crate::catalog::subscription_catalog::SubscriptionCatalog;
100use crate::catalog::{
101 CatalogError, DatabaseId, OwnedByUserCatalog, SchemaId, TableId, check_schema_writable,
102};
103use crate::error::{ErrorCode, Result, RwError};
104use crate::handler::describe::infer_describe;
105use crate::handler::extended_handle::{
106 Portal, PrepareStatement, handle_bind, handle_execute, handle_parse,
107};
108use crate::handler::privilege::ObjectCheckItem;
109use crate::handler::show::{infer_show_create_object, infer_show_object};
110use crate::handler::util::to_pg_field;
111use crate::handler::variable::infer_show_variable;
112use crate::handler::{RwPgResponse, handle};
113use crate::health_service::HealthServiceImpl;
114use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl};
115use crate::monitor::{CursorMetrics, FrontendMetrics, GLOBAL_FRONTEND_METRICS};
116use crate::observer::FrontendObserverNode;
117use crate::rpc::FrontendServiceImpl;
118use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef};
119use crate::scheduler::{
120 DistributedQueryMetrics, GLOBAL_DISTRIBUTED_QUERY_METRICS, HummockSnapshotManager,
121 HummockSnapshotManagerRef, QueryManager,
122};
123use crate::telemetry::FrontendTelemetryCreator;
124use crate::user::UserId;
125use crate::user::user_authentication::md5_hash_with_salt;
126use crate::user::user_manager::UserInfoManager;
127use crate::user::user_service::{UserInfoReader, UserInfoWriter, UserInfoWriterImpl};
128use crate::{FrontendOpts, PgResponseStream, TableCatalog};
129
130pub(crate) mod current;
131pub(crate) mod cursor_manager;
132pub(crate) mod transaction;
133
/// Process-wide environment of a frontend node, shared by every session.
///
/// The struct derives `Clone`; most members are `Arc`-backed handles, and each
/// `SessionImpl` stores its own `FrontendEnv` value.
134#[derive(Clone)]
136pub(crate) struct FrontendEnv {
    /// Client used to talk to the meta service (a mock in `FrontendEnv::mock`).
137 meta_client: Arc<dyn FrontendMetaClient>,
    /// Write side of the catalog; only handed out through `catalog_writer`,
    /// which requires a `transaction::WriteGuard`.
140 catalog_writer: Arc<dyn CatalogWriter>,
    /// Read side of the catalog.
141 catalog_reader: CatalogReader,
    /// Write side of user information; only handed out through
    /// `user_info_writer`, which requires a `transaction::WriteGuard`.
142 user_info_writer: Arc<dyn UserInfoWriter>,
    /// Read side of user information.
143 user_info_reader: UserInfoReader,
    /// Shared worker node manager.
144 worker_node_manager: WorkerNodeManagerRef,
    /// Batch query manager.
145 query_manager: QueryManager,
    /// Shared Hummock snapshot manager.
146 hummock_snapshot_manager: HummockSnapshotManagerRef,
    /// Local view of the system parameters.
147 system_params_manager: LocalSystemParamsManagerRef,
    /// Shared session configuration; `session_params_snapshot` returns a copy.
148 session_params: Arc<RwLock<SessionConfig>>,
149
    /// Address of this frontend (the advertise address in `init`).
150 server_addr: HostAddr,
    /// Pool of clients to compute nodes.
151 client_pool: ComputeClientPoolRef,
    /// Pool of clients to other frontend nodes.
152 frontend_client_pool: FrontendClientPoolRef,
153
    /// All sessions known to this frontend, keyed by session id.
154 sessions_map: SessionMapRef,
158
    /// Frontend metrics.
159 pub frontend_metrics: Arc<FrontendMetrics>,
160
    /// Cursor metrics; every session's `CursorManager` is built from a clone.
161 pub cursor_metrics: Arc<CursorMetrics>,
162
    /// Source connector metrics.
163 source_metrics: Arc<SourceMetrics>,
164
    /// Batch spill metrics.
165 spill_metrics: Arc<BatchSpillMetrics>,
167
    /// `batch` section of the loaded configuration.
168 batch_config: BatchConfig,
    /// `frontend` section of the loaded configuration.
169 frontend_config: FrontendConfig,
    /// `meta` section of the loaded configuration; not read anywhere yet,
    /// hence the `dead_code` expectation.
170 #[expect(dead_code)]
171 meta_config: MetaConfig,
    /// `streaming` section of the loaded configuration.
172 streaming_config: StreamingConfig,
    /// `udf` section of the loaded configuration.
173 udf_config: UdfConfig,
174
    /// Tracker for streaming jobs that are currently being created.
175 creating_streaming_job_tracker: StreamingJobTrackerRef,
178
    /// Dedicated tokio runtime (threads named `rw-batch-local`).
179 compute_runtime: Arc<BackgroundShutdownRuntime>,
182
    /// Memory context; in `init` it is a root context limited to the batch
    /// share of the frontend memory, in `mock` it is `MemoryContext::none()`.
183 mem_context: MemoryContext,
185
    /// Address of the serverless backfill controller, taken from the options.
186 serverless_backfill_controller_addr: String,
188}
189
190pub type SessionMapRef = Arc<RwLock<HashMap<(i32, i32), Arc<SessionImpl>>>>;
192
/// Fraction of the frontend's total memory that `FrontendEnv::init` uses as
/// the limit of the root batch memory context.
193const FRONTEND_BATCH_MEMORY_PROPORTION: f64 = 0.5;
195
196impl FrontendEnv {
197 pub fn mock() -> Self {
198 use crate::test_utils::{MockCatalogWriter, MockFrontendMetaClient, MockUserInfoWriter};
199
200 let catalog = Arc::new(RwLock::new(Catalog::default()));
201 let meta_client = Arc::new(MockFrontendMetaClient {});
202 let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone()));
203 let catalog_writer = Arc::new(MockCatalogWriter::new(
204 catalog.clone(),
205 hummock_snapshot_manager.clone(),
206 ));
207 let catalog_reader = CatalogReader::new(catalog);
208 let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
209 let user_info_writer = Arc::new(MockUserInfoWriter::new(user_info_manager.clone()));
210 let user_info_reader = UserInfoReader::new(user_info_manager);
211 let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
212 let system_params_manager = Arc::new(LocalSystemParamsManager::for_test());
213 let compute_client_pool = Arc::new(ComputeClientPool::for_test());
214 let frontend_client_pool = Arc::new(FrontendClientPool::for_test());
215 let query_manager = QueryManager::new(
216 worker_node_manager.clone(),
217 compute_client_pool,
218 catalog_reader.clone(),
219 Arc::new(DistributedQueryMetrics::for_test()),
220 None,
221 None,
222 );
223 let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap();
224 let client_pool = Arc::new(ComputeClientPool::for_test());
225 let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone());
226 let runtime = {
227 let mut builder = Builder::new_multi_thread();
228 if let Some(frontend_compute_runtime_worker_threads) =
229 load_config("", FrontendOpts::default())
230 .batch
231 .frontend_compute_runtime_worker_threads
232 {
233 builder.worker_threads(frontend_compute_runtime_worker_threads);
234 }
235 builder
236 .thread_name("rw-batch-local")
237 .enable_all()
238 .build()
239 .unwrap()
240 };
241 let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));
242 let sessions_map = Arc::new(RwLock::new(HashMap::new()));
243 Self {
244 meta_client,
245 catalog_writer,
246 catalog_reader,
247 user_info_writer,
248 user_info_reader,
249 worker_node_manager,
250 query_manager,
251 hummock_snapshot_manager,
252 system_params_manager,
253 session_params: Default::default(),
254 server_addr,
255 client_pool,
256 frontend_client_pool,
257 sessions_map: sessions_map.clone(),
258 frontend_metrics: Arc::new(FrontendMetrics::for_test()),
259 cursor_metrics: Arc::new(CursorMetrics::for_test()),
260 batch_config: BatchConfig::default(),
261 frontend_config: FrontendConfig::default(),
262 meta_config: MetaConfig::default(),
263 streaming_config: StreamingConfig::default(),
264 udf_config: UdfConfig::default(),
265 source_metrics: Arc::new(SourceMetrics::default()),
266 spill_metrics: BatchSpillMetrics::for_test(),
267 creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
268 compute_runtime,
269 mem_context: MemoryContext::none(),
270 serverless_backfill_controller_addr: Default::default(),
271 }
272 }
273
/// Bootstraps a frontend node from `opts`.
///
/// In order, this loads the configuration, registers the node with the meta
/// service, starts the heartbeat loop and the observer, activates the node,
/// optionally boots the metrics service and telemetry, spawns the gRPC server
/// (health + frontend services), creates the dedicated `rw-batch-local`
/// runtime, spawns the periodic idle-in-transaction check, cleans the spill
/// directory when spilling is enabled, starts the heap profiler and creates the
/// root batch memory context.
///
/// Returns the environment together with the join handles of the spawned
/// background tasks and the senders used to shut some of them down.
///
/// # Errors
///
/// Returns an error if activating the node on the meta service fails or if
/// cleaning the spill directory fails.
///
/// # Panics
///
/// Panics if the advertise/listen address or the RPC listener address cannot
/// be parsed, or if the tokio runtime cannot be built. The spawned gRPC server
/// task panics if serving fails.
274 pub async fn init(opts: FrontendOpts) -> Result<(Self, Vec<JoinHandle<()>>, Vec<Sender<()>>)> {
275 let config = load_config(&opts.config_path, &opts);
276 info!("Starting frontend node");
277 info!("> config: {:?}", config);
278 info!(
279 "> debug assertions: {}",
280 if cfg!(debug_assertions) { "on" } else { "off" }
281 );
282 info!("> version: {} ({})", RW_VERSION, GIT_SHA);
283
    // The advertised address falls back to the listen address when unset.
284 let frontend_address: HostAddr = opts
285 .advertise_addr
286 .as_ref()
287 .unwrap_or_else(|| {
288 tracing::warn!("advertise addr is not specified, defaulting to listen_addr");
289 &opts.listen_addr
290 })
291 .parse()
292 .unwrap();
293 info!("advertise addr is {}", frontend_address);
294
    // The internal RPC address combines the advertised host with the port of
    // the RPC listener address.
295 let rpc_addr: HostAddr = opts.frontend_rpc_listener_addr.parse().unwrap();
296 let internal_rpc_host_addr = HostAddr {
297 host: frontend_address.host.clone(),
299 port: rpc_addr.port,
300 };
    // Register this node as a frontend worker on the meta service.
301 let (meta_client, system_params_reader) = MetaClient::register_new(
303 opts.meta_addr,
304 WorkerType::Frontend,
305 &frontend_address,
306 AddWorkerNodeProperty {
307 internal_rpc_host_addr: internal_rpc_host_addr.to_string(),
308 ..Default::default()
309 },
310 &config.meta,
311 )
312 .await;
313
314 let worker_id = meta_client.worker_id();
315 info!("Assigned worker node id {}", worker_id);
316
    // The heartbeat loop contributes the first join handle / shutdown sender.
317 let (heartbeat_join_handle, heartbeat_shutdown_sender) = MetaClient::start_heartbeat_loop(
318 meta_client.clone(),
319 Duration::from_millis(config.server.heartbeat_interval_ms as u64),
320 );
321 let mut join_handles = vec![heartbeat_join_handle];
322 let mut shutdown_senders = vec![heartbeat_shutdown_sender];
323
324 let frontend_meta_client = Arc::new(FrontendMetaClientImpl(meta_client.clone()));
325 let hummock_snapshot_manager =
326 Arc::new(HummockSnapshotManager::new(frontend_meta_client.clone()));
327
    // The sender half of this watch channel goes to the observer node below;
    // the catalog and user-info writers each hold a receiver.
328 let (catalog_updated_tx, catalog_updated_rx) = watch::channel(0);
329 let catalog = Arc::new(RwLock::new(Catalog::default()));
330 let catalog_writer = Arc::new(CatalogWriterImpl::new(
331 meta_client.clone(),
332 catalog_updated_rx.clone(),
333 hummock_snapshot_manager.clone(),
334 ));
335 let catalog_reader = CatalogReader::new(catalog.clone());
336
337 let worker_node_manager = Arc::new(WorkerNodeManager::new());
338
339 let compute_client_pool = Arc::new(ComputeClientPool::new(
340 config.batch_exchange_connection_pool_size(),
341 config.batch.developer.compute_client_config.clone(),
342 ));
343 let frontend_client_pool = Arc::new(FrontendClientPool::new(
344 1,
345 config.batch.developer.frontend_client_config.clone(),
346 ));
347 let query_manager = QueryManager::new(
348 worker_node_manager.clone(),
349 compute_client_pool.clone(),
350 catalog_reader.clone(),
351 Arc::new(GLOBAL_DISTRIBUTED_QUERY_METRICS.clone()),
352 config.batch.distributed_query_limit,
353 config.batch.max_batch_queries_per_frontend_node,
354 );
355
356 let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
357 let user_info_reader = UserInfoReader::new(user_info_manager.clone());
358 let user_info_writer = Arc::new(UserInfoWriterImpl::new(
359 meta_client.clone(),
360 catalog_updated_rx,
361 ));
362
363 let system_params_manager =
364 Arc::new(LocalSystemParamsManager::new(system_params_reader.clone()));
365
366 LocalSecretManager::init(
367 opts.temp_secret_file_dir,
368 meta_client.cluster_id().to_owned(),
369 worker_id,
370 );
371
372 let session_params = Arc::new(RwLock::new(SessionConfig::default()));
374 let sessions_map: SessionMapRef = Arc::new(RwLock::new(HashMap::new()));
375 let cursor_metrics = Arc::new(CursorMetrics::init(sessions_map.clone()));
376
    // The observer node is handed the shared catalog, user info, snapshot
    // manager, system params, session params and compute client pool.
377 let frontend_observer_node = FrontendObserverNode::new(
378 worker_node_manager.clone(),
379 catalog,
380 catalog_updated_tx,
381 user_info_manager,
382 hummock_snapshot_manager.clone(),
383 system_params_manager.clone(),
384 session_params.clone(),
385 compute_client_pool.clone(),
386 );
387 let observer_manager =
388 ObserverManager::new_with_meta_client(meta_client.clone(), frontend_observer_node)
389 .await;
390 let observer_join_handle = observer_manager.start().await;
391 join_handles.push(observer_join_handle);
392
    // Activation happens only after the observer has been started.
393 meta_client.activate(&frontend_address).await?;
394
395 let frontend_metrics = Arc::new(GLOBAL_FRONTEND_METRICS.clone());
396 let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
397 let spill_metrics = Arc::new(GLOBAL_BATCH_SPILL_METRICS.clone());
398
399 if config.server.metrics_level > MetricLevel::Disabled {
400 MetricsManager::boot_metrics_service(opts.prometheus_listener_addr.clone());
401 }
402
403 let health_srv = HealthServiceImpl::new();
404 let frontend_srv = FrontendServiceImpl::new(sessions_map.clone());
405 let frontend_rpc_addr = opts.frontend_rpc_listener_addr.parse().unwrap();
406
407 let telemetry_manager = TelemetryManager::new(
408 Arc::new(meta_client.clone()),
409 Arc::new(FrontendTelemetryCreator::new()),
410 );
411
    // Telemetry starts only when both the config flag and the environment
    // check allow it.
412 if config.server.telemetry_enabled && telemetry_env_enabled() {
415 let (join_handle, shutdown_sender) = telemetry_manager.start().await;
416 join_handles.push(join_handle);
417 shutdown_senders.push(shutdown_sender);
418 } else {
419 tracing::info!("Telemetry didn't start due to config");
420 }
421
    // The gRPC server task is detached: its handle is not added to
    // `join_handles`, and a serve error panics inside that task.
422 tokio::spawn(async move {
423 tonic::transport::Server::builder()
424 .add_service(HealthServer::new(health_srv))
425 .add_service(FrontendServiceServer::new(frontend_srv))
426 .serve(frontend_rpc_addr)
427 .await
428 .unwrap();
429 });
430 info!(
431 "Health Check RPC Listener is set up on {}",
432 opts.frontend_rpc_listener_addr.clone()
433 );
434
435 let creating_streaming_job_tracker =
436 Arc::new(StreamingJobTracker::new(frontend_meta_client.clone()));
437
    // Dedicated runtime; the worker thread count is only overridden when the
    // batch config specifies one.
438 let runtime = {
439 let mut builder = Builder::new_multi_thread();
440 if let Some(frontend_compute_runtime_worker_threads) =
441 config.batch.frontend_compute_runtime_worker_threads
442 {
443 builder.worker_threads(frontend_compute_runtime_worker_threads);
444 }
445 builder
446 .thread_name("rw-batch-local")
447 .enable_all()
448 .build()
449 .unwrap()
450 };
451 let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));
452
    // Every 10 seconds, run the idle-in-transaction timeout check on every
    // session; the result of each check is deliberately ignored. Missed ticks
    // are skipped, and `reset` is called on the interval before the loop.
453 let sessions = sessions_map.clone();
454 let join_handle = tokio::spawn(async move {
456 let mut check_idle_txn_interval =
457 tokio::time::interval(core::time::Duration::from_secs(10));
458 check_idle_txn_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
459 check_idle_txn_interval.reset();
460 loop {
461 check_idle_txn_interval.tick().await;
462 sessions.read().values().for_each(|session| {
463 let _ = session.check_idle_in_transaction_timeout();
464 })
465 }
466 });
467 join_handles.push(join_handle);
468
    // Not compiled under `madsim`.
469 #[cfg(not(madsim))]
471 if config.batch.enable_spill {
472 SpillOp::clean_spill_directory()
473 .await
474 .map_err(|err| anyhow!(err))?;
475 }
476
477 let total_memory_bytes = opts.frontend_total_memory_bytes;
478 let heap_profiler =
479 HeapProfiler::new(total_memory_bytes, config.server.heap_profiling.clone());
480 heap_profiler.start();
482
    // The batch memory limit is a fixed proportion of the total memory.
483 let batch_memory_limit = total_memory_bytes as f64 * FRONTEND_BATCH_MEMORY_PROPORTION;
484 let mem_context = MemoryContext::root(
485 frontend_metrics.batch_total_mem.clone(),
486 batch_memory_limit as u64,
487 );
488
489 info!(
490 "Frontend total_memory: {} batch_memory: {}",
491 convert(total_memory_bytes as _),
492 convert(batch_memory_limit as _),
493 );
494
495 Ok((
496 Self {
497 catalog_reader,
498 catalog_writer,
499 user_info_reader,
500 user_info_writer,
501 worker_node_manager,
502 meta_client: frontend_meta_client,
503 query_manager,
504 hummock_snapshot_manager,
505 system_params_manager,
506 session_params,
507 server_addr: frontend_address,
508 client_pool: compute_client_pool,
509 frontend_client_pool,
510 frontend_metrics,
511 cursor_metrics,
512 spill_metrics,
513 sessions_map,
514 batch_config: config.batch,
515 frontend_config: config.frontend,
516 meta_config: config.meta,
517 streaming_config: config.streaming,
518 serverless_backfill_controller_addr: opts.serverless_backfill_controller_addr,
519 udf_config: config.udf,
520 source_metrics,
521 creating_streaming_job_tracker,
522 compute_runtime,
523 mem_context,
524 },
525 join_handles,
526 shutdown_senders,
527 ))
528 }
529
530 fn catalog_writer(&self, _guard: transaction::WriteGuard) -> &dyn CatalogWriter {
535 &*self.catalog_writer
536 }
537
538 pub fn catalog_reader(&self) -> &CatalogReader {
540 &self.catalog_reader
541 }
542
543 fn user_info_writer(&self, _guard: transaction::WriteGuard) -> &dyn UserInfoWriter {
548 &*self.user_info_writer
549 }
550
551 pub fn user_info_reader(&self) -> &UserInfoReader {
553 &self.user_info_reader
554 }
555
556 pub fn worker_node_manager_ref(&self) -> WorkerNodeManagerRef {
557 self.worker_node_manager.clone()
558 }
559
560 pub fn meta_client(&self) -> &dyn FrontendMetaClient {
561 &*self.meta_client
562 }
563
564 pub fn meta_client_ref(&self) -> Arc<dyn FrontendMetaClient> {
565 self.meta_client.clone()
566 }
567
568 pub fn query_manager(&self) -> &QueryManager {
569 &self.query_manager
570 }
571
572 pub fn hummock_snapshot_manager(&self) -> &HummockSnapshotManagerRef {
573 &self.hummock_snapshot_manager
574 }
575
576 pub fn system_params_manager(&self) -> &LocalSystemParamsManagerRef {
577 &self.system_params_manager
578 }
579
580 pub fn session_params_snapshot(&self) -> SessionConfig {
581 self.session_params.read_recursive().clone()
582 }
583
584 pub fn sbc_address(&self) -> &String {
585 &self.serverless_backfill_controller_addr
586 }
587
588 pub fn server_address(&self) -> &HostAddr {
589 &self.server_addr
590 }
591
592 pub fn client_pool(&self) -> ComputeClientPoolRef {
593 self.client_pool.clone()
594 }
595
596 pub fn frontend_client_pool(&self) -> FrontendClientPoolRef {
597 self.frontend_client_pool.clone()
598 }
599
600 pub fn batch_config(&self) -> &BatchConfig {
601 &self.batch_config
602 }
603
604 pub fn frontend_config(&self) -> &FrontendConfig {
605 &self.frontend_config
606 }
607
608 pub fn streaming_config(&self) -> &StreamingConfig {
609 &self.streaming_config
610 }
611
612 pub fn udf_config(&self) -> &UdfConfig {
613 &self.udf_config
614 }
615
616 pub fn source_metrics(&self) -> Arc<SourceMetrics> {
617 self.source_metrics.clone()
618 }
619
620 pub fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
621 self.spill_metrics.clone()
622 }
623
624 pub fn creating_streaming_job_tracker(&self) -> &StreamingJobTrackerRef {
625 &self.creating_streaming_job_tracker
626 }
627
628 pub fn sessions_map(&self) -> &SessionMapRef {
629 &self.sessions_map
630 }
631
632 pub fn compute_runtime(&self) -> Arc<BackgroundShutdownRuntime> {
633 self.compute_runtime.clone()
634 }
635
636 pub fn cancel_queries_in_session(&self, session_id: SessionId) -> bool {
639 cancel_queries_in_session(session_id, self.sessions_map.clone())
640 }
641
642 pub fn cancel_creating_jobs_in_session(&self, session_id: SessionId) -> bool {
645 cancel_creating_jobs_in_session(session_id, self.sessions_map.clone())
646 }
647
648 pub fn mem_context(&self) -> MemoryContext {
649 self.mem_context.clone()
650 }
651}
652
653#[derive(Clone)]
654pub struct AuthContext {
655 pub database: String,
656 pub user_name: String,
657 pub user_id: UserId,
658}
659
660impl AuthContext {
661 pub fn new(database: String, user_name: String, user_id: UserId) -> Self {
662 Self {
663 database,
664 user_name,
665 user_id,
666 }
667 }
668}
/// State of one client session on this frontend.
669pub struct SessionImpl {
    /// The frontend environment this session runs in.
670 env: FrontendEnv,
    /// Current database / user of the session; the database can be changed
    /// through `update_database`.
671 auth_context: Arc<RwLock<AuthContext>>,
    /// Authenticator supplied when the session was created.
672 user_authenticator: UserAuthenticator,
    /// Session-level configuration.
674 config_map: Arc<RwLock<SessionConfig>>,
676
    /// Sending half of the session's notice channel.
677 notice_tx: UnboundedSender<String>,
    /// Receiving half of the session's notice channel.
679 notice_rx: Mutex<UnboundedReceiver<String>>,
681
    /// Session id, as returned by `session_id`.
682 id: (i32, i32),
684
    /// Address of the connected peer.
685 peer_addr: AddressRef,
687
    /// Transaction state of the session.
688 txn: Arc<Mutex<transaction::State>>,
692
    /// NOTE(review): presumably the shutdown sender used to cancel the query
    /// that is currently running, `None` when idle — confirm against its users.
693 current_query_cancel_flag: Mutex<Option<ShutdownSender>>,
697
    /// Weak reference to the current execution context; `running_sql` and
    /// `elapse_since_running_sql` read through it.
698 exec_context: Mutex<Option<Weak<ExecContext>>>,
700
    /// Instant read by `elapse_since_last_idle_instant`, if any is recorded.
701 last_idle_instant: Arc<Mutex<Option<Instant>>>,
703
    /// Cursor manager of the session.
704 cursor_manager: Arc<CursorManager>,
705
    /// Session-local temporary sources.
706 temporary_source_manager: Arc<Mutex<TemporarySourceManager>>,
708}
709
710#[derive(Default, Clone)]
719pub struct TemporarySourceManager {
720 sources: HashMap<String, SourceCatalog>,
721}
722
723impl TemporarySourceManager {
724 pub fn new() -> Self {
725 Self {
726 sources: HashMap::new(),
727 }
728 }
729
730 pub fn create_source(&mut self, name: String, source: SourceCatalog) {
731 self.sources.insert(name, source);
732 }
733
734 pub fn drop_source(&mut self, name: &str) {
735 self.sources.remove(name);
736 }
737
738 pub fn get_source(&self, name: &str) -> Option<&SourceCatalog> {
739 self.sources.get(name)
740 }
741
742 pub fn keys(&self) -> Vec<String> {
743 self.sources.keys().cloned().collect()
744 }
745}
746
/// Error returned by `SessionImpl::check_relation_name_duplicated`.
///
/// Both variants wrap an underlying error (convertible with `?` thanks to
/// `#[from]`) and display exactly as the wrapped error does.
747#[derive(Error, Debug)]
748pub enum CheckRelationError {
    /// Resolving the qualified name failed.
749 #[error("{0}")]
750 Resolve(#[from] ResolveQualifiedNameError),
    /// The catalog reported an error.
751 #[error("{0}")]
752 Catalog(#[from] CatalogError),
753}
754
755impl From<CheckRelationError> for RwError {
756 fn from(e: CheckRelationError) -> Self {
757 match e {
758 CheckRelationError::Resolve(e) => e.into(),
759 CheckRelationError::Catalog(e) => e.into(),
760 }
761 }
762}
763
764impl SessionImpl {
765 pub(crate) fn new(
766 env: FrontendEnv,
767 auth_context: AuthContext,
768 user_authenticator: UserAuthenticator,
769 id: SessionId,
770 peer_addr: AddressRef,
771 session_config: SessionConfig,
772 ) -> Self {
773 let cursor_metrics = env.cursor_metrics.clone();
774 let (notice_tx, notice_rx) = mpsc::unbounded_channel();
775
776 Self {
777 env,
778 auth_context: Arc::new(RwLock::new(auth_context)),
779 user_authenticator,
780 config_map: Arc::new(RwLock::new(session_config)),
781 id,
782 peer_addr,
783 txn: Default::default(),
784 current_query_cancel_flag: Mutex::new(None),
785 notice_tx,
786 notice_rx: Mutex::new(notice_rx),
787 exec_context: Mutex::new(None),
788 last_idle_instant: Default::default(),
789 cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
790 temporary_source_manager: Default::default(),
791 }
792 }
793
/// Builds a session for tests: a mock environment, the default database and
/// super user, no authentication, session id `(0, 0)` and a loopback peer
/// address (`127.0.0.1:8080`).
///
/// The mock environment is built exactly once, and the cursor manager uses the
/// cursor metrics of that same environment. (Previously two environments were
/// created, each with its own runtime, and the cursor manager was tied to the
/// one that was not stored in the session.)
#[cfg(test)]
pub fn mock() -> Self {
    let env = FrontendEnv::mock();
    // Clone the metrics handle before `env` is moved into the struct.
    let cursor_metrics = env.cursor_metrics.clone();
    let (notice_tx, notice_rx) = mpsc::unbounded_channel();

    Self {
        env,
        auth_context: Arc::new(RwLock::new(AuthContext::new(
            DEFAULT_DATABASE_NAME.to_owned(),
            DEFAULT_SUPER_USER.to_owned(),
            DEFAULT_SUPER_USER_ID,
        ))),
        user_authenticator: UserAuthenticator::None,
        config_map: Default::default(),
        id: (0, 0),
        txn: Default::default(),
        current_query_cancel_flag: Mutex::new(None),
        notice_tx,
        notice_rx: Mutex::new(notice_rx),
        exec_context: Mutex::new(None),
        peer_addr: Address::Tcp(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            8080,
        ))
        .into(),
        last_idle_instant: Default::default(),
        cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
        temporary_source_manager: Default::default(),
    }
}
825
826 pub(crate) fn env(&self) -> &FrontendEnv {
827 &self.env
828 }
829
830 pub fn auth_context(&self) -> Arc<AuthContext> {
831 let ctx = self.auth_context.read();
832 Arc::new(ctx.clone())
833 }
834
835 pub fn database(&self) -> String {
836 self.auth_context.read().database.clone()
837 }
838
839 pub fn database_id(&self) -> DatabaseId {
840 let db_name = self.database();
841 self.env
842 .catalog_reader()
843 .read_guard()
844 .get_database_by_name(&db_name)
845 .map(|db| db.id())
846 .expect("session database not found")
847 }
848
849 pub fn user_name(&self) -> String {
850 self.auth_context.read().user_name.clone()
851 }
852
853 pub fn user_id(&self) -> UserId {
854 self.auth_context.read().user_id
855 }
856
857 pub fn update_database(&self, database: String) {
858 self.auth_context.write().database = database;
859 }
860
861 pub fn shared_config(&self) -> Arc<RwLock<SessionConfig>> {
862 Arc::clone(&self.config_map)
863 }
864
865 pub fn config(&self) -> RwLockReadGuard<'_, SessionConfig> {
866 self.config_map.read()
867 }
868
869 pub fn set_config(&self, key: &str, value: String) -> Result<String> {
870 self.config_map
871 .write()
872 .set(key, value, &mut ())
873 .map_err(Into::into)
874 }
875
876 pub fn reset_config(&self, key: &str) -> Result<String> {
877 self.config_map
878 .write()
879 .reset(key, &mut ())
880 .map_err(Into::into)
881 }
882
883 pub fn set_config_report(
884 &self,
885 key: &str,
886 value: Option<String>,
887 mut reporter: impl ConfigReporter,
888 ) -> Result<String> {
889 if let Some(value) = value {
890 self.config_map
891 .write()
892 .set(key, value, &mut reporter)
893 .map_err(Into::into)
894 } else {
895 self.config_map
896 .write()
897 .reset(key, &mut reporter)
898 .map_err(Into::into)
899 }
900 }
901
902 pub fn session_id(&self) -> SessionId {
903 self.id
904 }
905
906 pub fn running_sql(&self) -> Option<Arc<str>> {
907 self.exec_context
908 .lock()
909 .as_ref()
910 .and_then(|weak| weak.upgrade())
911 .map(|context| context.running_sql.clone())
912 }
913
914 pub fn get_cursor_manager(&self) -> Arc<CursorManager> {
915 self.cursor_manager.clone()
916 }
917
918 pub fn peer_addr(&self) -> &Address {
919 &self.peer_addr
920 }
921
922 pub fn elapse_since_running_sql(&self) -> Option<u128> {
923 self.exec_context
924 .lock()
925 .as_ref()
926 .and_then(|weak| weak.upgrade())
927 .map(|context| context.last_instant.elapsed().as_millis())
928 }
929
930 pub fn elapse_since_last_idle_instant(&self) -> Option<u128> {
931 self.last_idle_instant
932 .lock()
933 .as_ref()
934 .map(|x| x.elapsed().as_millis())
935 }
936
937 pub fn check_relation_name_duplicated(
938 &self,
939 name: ObjectName,
940 stmt_type: StatementType,
941 if_not_exists: bool,
942 ) -> std::result::Result<Either<(), RwPgResponse>, CheckRelationError> {
943 let db_name = &self.database();
944 let catalog_reader = self.env().catalog_reader().read_guard();
945 let (schema_name, relation_name) = {
946 let (schema_name, relation_name) =
947 Binder::resolve_schema_qualified_name(db_name, name)?;
948 let search_path = self.config().search_path();
949 let user_name = &self.user_name();
950 let schema_name = match schema_name {
951 Some(schema_name) => schema_name,
952 None => catalog_reader
953 .first_valid_schema(db_name, &search_path, user_name)?
954 .name(),
955 };
956 (schema_name, relation_name)
957 };
958 match catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &relation_name) {
959 Err(CatalogError::Duplicated(_, name, is_creating)) if if_not_exists => {
960 if !is_creating {
965 Ok(Either::Right(
966 PgResponse::builder(stmt_type)
967 .notice(format!("relation \"{}\" already exists, skipping", name))
968 .into(),
969 ))
970 } else if stmt_type == StatementType::CREATE_SUBSCRIPTION {
971 Ok(Either::Right(
974 PgResponse::builder(stmt_type)
975 .notice(format!(
976 "relation \"{}\" already exists but still creating, skipping",
977 name
978 ))
979 .into(),
980 ))
981 } else {
982 Ok(Either::Left(()))
983 }
984 }
985 Err(e) => Err(e.into()),
986 Ok(_) => Ok(Either::Left(())),
987 }
988 }
989
990 pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> {
991 let db_name = &self.database();
992 let catalog_reader = self.env().catalog_reader().read_guard();
993 let (schema_name, secret_name) = {
994 let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
995 let search_path = self.config().search_path();
996 let user_name = &self.user_name();
997 let schema_name = match schema_name {
998 Some(schema_name) => schema_name,
999 None => catalog_reader
1000 .first_valid_schema(db_name, &search_path, user_name)?
1001 .name(),
1002 };
1003 (schema_name, secret_name)
1004 };
1005 catalog_reader
1006 .check_secret_name_duplicated(db_name, &schema_name, &secret_name)
1007 .map_err(RwError::from)
1008 }
1009
1010 pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> {
1011 let db_name = &self.database();
1012 let catalog_reader = self.env().catalog_reader().read_guard();
1013 let (schema_name, connection_name) = {
1014 let (schema_name, connection_name) =
1015 Binder::resolve_schema_qualified_name(db_name, name)?;
1016 let search_path = self.config().search_path();
1017 let user_name = &self.user_name();
1018 let schema_name = match schema_name {
1019 Some(schema_name) => schema_name,
1020 None => catalog_reader
1021 .first_valid_schema(db_name, &search_path, user_name)?
1022 .name(),
1023 };
1024 (schema_name, connection_name)
1025 };
1026 catalog_reader
1027 .check_connection_name_duplicated(db_name, &schema_name, &connection_name)
1028 .map_err(RwError::from)
1029 }
1030
1031 pub fn check_function_name_duplicated(
1032 &self,
1033 stmt_type: StatementType,
1034 name: ObjectName,
1035 arg_types: &[DataType],
1036 if_not_exists: bool,
1037 ) -> Result<Either<(), RwPgResponse>> {
1038 let db_name = &self.database();
1039 let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
1040 let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?;
1041
1042 let catalog_reader = self.env().catalog_reader().read_guard();
1043 if catalog_reader
1044 .get_schema_by_id(&database_id, &schema_id)?
1045 .get_function_by_name_args(&function_name, arg_types)
1046 .is_some()
1047 {
1048 let full_name = format!(
1049 "{function_name}({})",
1050 arg_types.iter().map(|t| t.to_string()).join(",")
1051 );
1052 if if_not_exists {
1053 Ok(Either::Right(
1054 PgResponse::builder(stmt_type)
1055 .notice(format!(
1056 "function \"{}\" already exists, skipping",
1057 full_name
1058 ))
1059 .into(),
1060 ))
1061 } else {
1062 Err(CatalogError::duplicated("function", full_name).into())
1063 }
1064 } else {
1065 Ok(Either::Left(()))
1066 }
1067 }
1068
1069 pub fn get_database_and_schema_id_for_create(
1071 &self,
1072 schema_name: Option<String>,
1073 ) -> Result<(DatabaseId, SchemaId)> {
1074 let db_name = &self.database();
1075
1076 let search_path = self.config().search_path();
1077 let user_name = &self.user_name();
1078
1079 let catalog_reader = self.env().catalog_reader().read_guard();
1080 let schema = match schema_name {
1081 Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?,
1082 None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?,
1083 };
1084 let schema_name = schema.name();
1085
1086 check_schema_writable(&schema_name)?;
1087 self.check_privileges(&[ObjectCheckItem::new(
1088 schema.owner(),
1089 AclMode::Create,
1090 schema_name,
1091 Object::SchemaId(schema.id()),
1092 )])?;
1093
1094 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1095 Ok((db_id, schema.id()))
1096 }
1097
1098 pub fn get_connection_by_name(
1099 &self,
1100 schema_name: Option<String>,
1101 connection_name: &str,
1102 ) -> Result<Arc<ConnectionCatalog>> {
1103 let db_name = &self.database();
1104 let search_path = self.config().search_path();
1105 let user_name = &self.user_name();
1106
1107 let catalog_reader = self.env().catalog_reader().read_guard();
1108 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1109 let (connection, _) =
1110 catalog_reader.get_connection_by_name(db_name, schema_path, connection_name)?;
1111
1112 self.check_privileges(&[ObjectCheckItem::new(
1113 connection.owner(),
1114 AclMode::Usage,
1115 connection.name.clone(),
1116 Object::ConnectionId(connection.id),
1117 )])?;
1118
1119 Ok(connection.clone())
1120 }
1121
1122 pub fn get_subscription_by_schema_id_name(
1123 &self,
1124 schema_id: SchemaId,
1125 subscription_name: &str,
1126 ) -> Result<Arc<SubscriptionCatalog>> {
1127 let db_name = &self.database();
1128
1129 let catalog_reader = self.env().catalog_reader().read_guard();
1130 let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1131 let schema = catalog_reader.get_schema_by_id(&db_id, &schema_id)?;
1132 let subscription = schema
1133 .get_subscription_by_name(subscription_name)
1134 .ok_or_else(|| {
1135 RwError::from(ErrorCode::ItemNotFound(format!(
1136 "subscription {} not found",
1137 subscription_name
1138 )))
1139 })?;
1140 Ok(subscription.clone())
1141 }
1142
1143 pub fn get_subscription_by_name(
1144 &self,
1145 schema_name: Option<String>,
1146 subscription_name: &str,
1147 ) -> Result<Arc<SubscriptionCatalog>> {
1148 let db_name = &self.database();
1149 let search_path = self.config().search_path();
1150 let user_name = &self.user_name();
1151
1152 let catalog_reader = self.env().catalog_reader().read_guard();
1153 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1154 let (subscription, _) =
1155 catalog_reader.get_subscription_by_name(db_name, schema_path, subscription_name)?;
1156 Ok(subscription.clone())
1157 }
1158
1159 pub fn get_table_by_id(&self, table_id: &TableId) -> Result<Arc<TableCatalog>> {
1160 let catalog_reader = self.env().catalog_reader().read_guard();
1161 Ok(catalog_reader.get_any_table_by_id(table_id)?.clone())
1162 }
1163
1164 pub fn get_table_by_name(
1165 &self,
1166 table_name: &str,
1167 db_id: u32,
1168 schema_id: u32,
1169 ) -> Result<Arc<TableCatalog>> {
1170 let catalog_reader = self.env().catalog_reader().read_guard();
1171 let table = catalog_reader
1172 .get_schema_by_id(&DatabaseId::from(db_id), &SchemaId::from(schema_id))?
1173 .get_created_table_by_name(table_name)
1174 .ok_or_else(|| {
1175 Error::new(
1176 ErrorKind::InvalidInput,
1177 format!("table \"{}\" does not exist", table_name),
1178 )
1179 })?;
1180
1181 self.check_privileges(&[ObjectCheckItem::new(
1182 table.owner(),
1183 AclMode::Select,
1184 table_name.to_owned(),
1185 Object::TableId(table.id.table_id()),
1186 )])?;
1187
1188 Ok(table.clone())
1189 }
1190
1191 pub fn get_secret_by_name(
1192 &self,
1193 schema_name: Option<String>,
1194 secret_name: &str,
1195 ) -> Result<Arc<SecretCatalog>> {
1196 let db_name = &self.database();
1197 let search_path = self.config().search_path();
1198 let user_name = &self.user_name();
1199
1200 let catalog_reader = self.env().catalog_reader().read_guard();
1201 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1202 let (secret, _) = catalog_reader.get_secret_by_name(db_name, schema_path, secret_name)?;
1203
1204 self.check_privileges(&[ObjectCheckItem::new(
1205 secret.owner(),
1206 AclMode::Create,
1207 secret.name.clone(),
1208 Object::SecretId(secret.id.secret_id()),
1209 )])?;
1210
1211 Ok(secret.clone())
1212 }
1213
1214 pub fn list_change_log_epochs(
1215 &self,
1216 table_id: u32,
1217 min_epoch: u64,
1218 max_count: u32,
1219 ) -> Result<Vec<u64>> {
1220 Ok(self
1221 .env
1222 .hummock_snapshot_manager()
1223 .acquire()
1224 .list_change_log_epochs(table_id, min_epoch, max_count))
1225 }
1226
1227 pub fn clear_cancel_query_flag(&self) {
1228 let mut flag = self.current_query_cancel_flag.lock();
1229 *flag = None;
1230 }
1231
1232 pub fn reset_cancel_query_flag(&self) -> ShutdownToken {
1233 let mut flag = self.current_query_cancel_flag.lock();
1234 let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
1235 *flag = Some(shutdown_tx);
1236 shutdown_rx
1237 }
1238
1239 pub fn cancel_current_query(&self) {
1240 let mut flag_guard = self.current_query_cancel_flag.lock();
1241 if let Some(sender) = flag_guard.take() {
1242 info!("Trying to cancel query in local mode.");
1243 sender.cancel();
1245 info!("Cancel query request sent.");
1246 } else {
1247 info!("Trying to cancel query in distributed mode.");
1248 self.env.query_manager().cancel_queries_in_session(self.id)
1249 }
1250 }
1251
1252 pub fn cancel_current_creating_job(&self) {
1253 self.env.creating_streaming_job_tracker.abort_jobs(self.id);
1254 }
1255
1256 pub async fn run_statement(
1259 self: Arc<Self>,
1260 sql: Arc<str>,
1261 formats: Vec<Format>,
1262 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1263 let mut stmts = Parser::parse_sql(&sql)?;
1265 if stmts.is_empty() {
1266 return Ok(PgResponse::empty_result(
1267 pgwire::pg_response::StatementType::EMPTY,
1268 ));
1269 }
1270 if stmts.len() > 1 {
1271 return Ok(
1272 PgResponse::builder(pgwire::pg_response::StatementType::EMPTY)
1273 .notice("cannot insert multiple commands into statement")
1274 .into(),
1275 );
1276 }
1277 let stmt = stmts.swap_remove(0);
1278 let rsp = handle(self, stmt, sql.clone(), formats).await?;
1279 Ok(rsp)
1280 }
1281
1282 pub fn notice_to_user(&self, str: impl Into<String>) {
1283 let notice = str.into();
1284 tracing::trace!(notice, "notice to user");
1285 self.notice_tx
1286 .send(notice)
1287 .expect("notice channel should not be closed");
1288 }
1289
1290 pub fn is_barrier_read(&self) -> bool {
1291 match self.config().visibility_mode() {
1292 VisibilityMode::Default => self.env.batch_config.enable_barrier_read,
1293 VisibilityMode::All => true,
1294 VisibilityMode::Checkpoint => false,
1295 }
1296 }
1297
1298 pub fn statement_timeout(&self) -> Duration {
1299 if self.config().statement_timeout() == 0 {
1300 Duration::from_secs(self.env.batch_config.statement_timeout_in_sec as u64)
1301 } else {
1302 Duration::from_secs(self.config().statement_timeout() as u64)
1303 }
1304 }
1305
1306 pub fn create_temporary_source(&self, source: SourceCatalog) {
1307 self.temporary_source_manager
1308 .lock()
1309 .create_source(source.name.clone(), source);
1310 }
1311
1312 pub fn get_temporary_source(&self, name: &str) -> Option<SourceCatalog> {
1313 self.temporary_source_manager
1314 .lock()
1315 .get_source(name)
1316 .cloned()
1317 }
1318
1319 pub fn drop_temporary_source(&self, name: &str) {
1320 self.temporary_source_manager.lock().drop_source(name);
1321 }
1322
1323 pub fn temporary_source_manager(&self) -> TemporarySourceManager {
1324 self.temporary_source_manager.lock().clone()
1325 }
1326
1327 pub async fn check_cluster_limits(&self) -> Result<()> {
1328 if self.config().bypass_cluster_limits() {
1329 return Ok(());
1330 }
1331
1332 let gen_message = |ActorCountPerParallelism {
1333 worker_id_to_actor_count,
1334 hard_limit,
1335 soft_limit,
1336 }: ActorCountPerParallelism,
1337 exceed_hard_limit: bool|
1338 -> String {
1339 let (limit_type, action) = if exceed_hard_limit {
1340 ("critical", "Scale the cluster immediately to proceed.")
1341 } else {
1342 (
1343 "recommended",
1344 "Consider scaling the cluster for optimal performance.",
1345 )
1346 };
1347 format!(
1348 r#"Actor count per parallelism exceeds the {limit_type} limit.
1349
1350Depending on your workload, this may overload the cluster and cause performance/stability issues. {action}
1351
1352HINT:
1353- For best practices on managing streaming jobs: https://docs.risingwave.com/operate/manage-a-large-number-of-streaming-jobs
1354- To bypass the check (if the cluster load is acceptable): `[ALTER SYSTEM] SET bypass_cluster_limits TO true`.
1355 See https://docs.risingwave.com/operate/view-configure-runtime-parameters#how-to-configure-runtime-parameters
1356- Contact us via slack or https://risingwave.com/contact-us/ for further enquiry.
1357
1358DETAILS:
1359- hard limit: {hard_limit}
1360- soft limit: {soft_limit}
1361- worker_id_to_actor_count: {worker_id_to_actor_count:?}"#,
1362 )
1363 };
1364
1365 let limits = self.env().meta_client().get_cluster_limits().await?;
1366 for limit in limits {
1367 match limit {
1368 cluster_limit::ClusterLimit::ActorCount(l) => {
1369 if l.exceed_hard_limit() {
1370 return Err(RwError::from(ErrorCode::ProtocolError(gen_message(
1371 l, true,
1372 ))));
1373 } else if l.exceed_soft_limit() {
1374 self.notice_to_user(gen_message(l, false));
1375 }
1376 }
1377 }
1378 }
1379 Ok(())
1380 }
1381}
1382
/// Process-wide slot for the shared [`SessionManagerImpl`].
///
/// Being a `OnceLock`, it can be initialized at most once; the initialization site is not in
/// this part of the file.
pub static SESSION_MANAGER: std::sync::OnceLock<Arc<SessionManagerImpl>> =
    std::sync::OnceLock::new();
1385
/// Session manager of the frontend: creates sessions, registers them in the environment's
/// session map and forwards cancel requests to the environment.
pub struct SessionManagerImpl {
    /// Frontend environment; a clone of it is handed to every session created here.
    env: FrontendEnv,
    /// Join handles returned by `FrontendEnv::init`; stored here but never read in this file
    /// section.
    _join_handles: Vec<JoinHandle<()>>,
    /// Shutdown senders returned by `FrontendEnv::init`; stored here but never read in this
    /// file section.
    _shutdown_senders: Vec<Sender<()>>,
    /// Counter incremented for every new session; `connect_inner` derives the session id
    /// from it.
    number: AtomicI32,
}
1392
1393impl SessionManager for SessionManagerImpl {
1394 type Session = SessionImpl;
1395
1396 fn create_dummy_session(
1397 &self,
1398 database_id: u32,
1399 user_id: u32,
1400 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1401 let dummy_addr = Address::Tcp(SocketAddr::new(
1402 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
1403 5691, ));
1405 let user_reader = self.env.user_info_reader();
1406 let reader = user_reader.read_guard();
1407 if let Some(user_name) = reader.get_user_name_by_id(user_id) {
1408 self.connect_inner(database_id, user_name.as_str(), Arc::new(dummy_addr))
1409 } else {
1410 Err(Box::new(Error::new(
1411 ErrorKind::InvalidInput,
1412 format!("Role id {} does not exist", user_id),
1413 )))
1414 }
1415 }
1416
1417 fn connect(
1418 &self,
1419 database: &str,
1420 user_name: &str,
1421 peer_addr: AddressRef,
1422 ) -> std::result::Result<Arc<Self::Session>, BoxedError> {
1423 let catalog_reader = self.env.catalog_reader();
1424 let reader = catalog_reader.read_guard();
1425 let database_id = reader
1426 .get_database_by_name(database)
1427 .map_err(|_| {
1428 Box::new(Error::new(
1429 ErrorKind::InvalidInput,
1430 format!("database \"{}\" does not exist", database),
1431 ))
1432 })?
1433 .id();
1434
1435 self.connect_inner(database_id, user_name, peer_addr)
1436 }
1437
1438 fn cancel_queries_in_session(&self, session_id: SessionId) {
1440 self.env.cancel_queries_in_session(session_id);
1441 }
1442
1443 fn cancel_creating_jobs_in_session(&self, session_id: SessionId) {
1444 self.env.cancel_creating_jobs_in_session(session_id);
1445 }
1446
1447 fn end_session(&self, session: &Self::Session) {
1448 self.delete_session(&session.session_id());
1449 }
1450
1451 async fn shutdown(&self) {
1452 self.env.sessions_map().write().clear();
1454 self.env.meta_client().try_unregister().await;
1456 }
1457}
1458
1459impl SessionManagerImpl {
1460 pub async fn new(opts: FrontendOpts) -> Result<Self> {
1461 let (env, join_handles, shutdown_senders) = FrontendEnv::init(opts).await?;
1463 Ok(Self {
1464 env,
1465 _join_handles: join_handles,
1466 _shutdown_senders: shutdown_senders,
1467 number: AtomicI32::new(0),
1468 })
1469 }
1470
1471 pub(crate) fn env(&self) -> &FrontendEnv {
1472 &self.env
1473 }
1474
1475 fn insert_session(&self, session: Arc<SessionImpl>) {
1476 let active_sessions = {
1477 let mut write_guard = self.env.sessions_map.write();
1478 write_guard.insert(session.id(), session);
1479 write_guard.len()
1480 };
1481 self.env
1482 .frontend_metrics
1483 .active_sessions
1484 .set(active_sessions as i64);
1485 }
1486
1487 fn delete_session(&self, session_id: &SessionId) {
1488 let active_sessions = {
1489 let mut write_guard = self.env.sessions_map.write();
1490 write_guard.remove(session_id);
1491 write_guard.len()
1492 };
1493 self.env
1494 .frontend_metrics
1495 .active_sessions
1496 .set(active_sessions as i64);
1497 }
1498
1499 fn connect_inner(
1500 &self,
1501 database_id: u32,
1502 user_name: &str,
1503 peer_addr: AddressRef,
1504 ) -> std::result::Result<Arc<SessionImpl>, BoxedError> {
1505 let catalog_reader = self.env.catalog_reader();
1506 let reader = catalog_reader.read_guard();
1507 let (database_name, database_owner) = {
1508 let db = reader.get_database_by_id(&database_id).map_err(|_| {
1509 Box::new(Error::new(
1510 ErrorKind::InvalidInput,
1511 format!("database \"{}\" does not exist", database_id),
1512 ))
1513 })?;
1514 (db.name(), db.owner())
1515 };
1516
1517 let user_reader = self.env.user_info_reader();
1518 let reader = user_reader.read_guard();
1519 if let Some(user) = reader.get_user_by_name(user_name) {
1520 if !user.can_login {
1521 return Err(Box::new(Error::new(
1522 ErrorKind::InvalidInput,
1523 format!("User {} is not allowed to login", user_name),
1524 )));
1525 }
1526 let has_privilege =
1527 user.has_privilege(&Object::DatabaseId(database_id), AclMode::Connect);
1528 if !user.is_super && database_owner != user.id && !has_privilege {
1529 return Err(Box::new(Error::new(
1530 ErrorKind::PermissionDenied,
1531 "User does not have CONNECT privilege.",
1532 )));
1533 }
1534 let user_authenticator = match &user.auth_info {
1535 None => UserAuthenticator::None,
1536 Some(auth_info) => {
1537 if auth_info.encryption_type == EncryptionType::Plaintext as i32 {
1538 UserAuthenticator::ClearText(auth_info.encrypted_value.clone())
1539 } else if auth_info.encryption_type == EncryptionType::Md5 as i32 {
1540 let mut salt = [0; 4];
1541 let mut rng = rand::rng();
1542 rng.fill_bytes(&mut salt);
1543 UserAuthenticator::Md5WithSalt {
1544 encrypted_password: md5_hash_with_salt(
1545 &auth_info.encrypted_value,
1546 &salt,
1547 ),
1548 salt,
1549 }
1550 } else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
1551 UserAuthenticator::OAuth(auth_info.metadata.clone())
1552 } else {
1553 return Err(Box::new(Error::new(
1554 ErrorKind::Unsupported,
1555 format!("Unsupported auth type: {}", auth_info.encryption_type),
1556 )));
1557 }
1558 }
1559 };
1560
1561 let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
1563 let id = (secret_key, secret_key);
1565 let session_config = self.env.session_params_snapshot();
1567
1568 let session_impl: Arc<SessionImpl> = SessionImpl::new(
1569 self.env.clone(),
1570 AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id),
1571 user_authenticator,
1572 id,
1573 peer_addr,
1574 session_config,
1575 )
1576 .into();
1577 self.insert_session(session_impl.clone());
1578
1579 Ok(session_impl)
1580 } else {
1581 Err(Box::new(Error::new(
1582 ErrorKind::InvalidInput,
1583 format!("Role {} does not exist", user_name),
1584 )))
1585 }
1586 }
1587}
1588
1589impl Session for SessionImpl {
1590 type Portal = Portal;
1591 type PreparedStatement = PrepareStatement;
1592 type ValuesStream = PgResponseStream;
1593
1594 async fn run_one_query(
1597 self: Arc<Self>,
1598 stmt: Statement,
1599 format: Format,
1600 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1601 let string = stmt.to_string();
1602 let sql_str = string.as_str();
1603 let sql: Arc<str> = Arc::from(sql_str);
1604 drop(string);
1606 let rsp = handle(self, stmt, sql, vec![format]).await?;
1607 Ok(rsp)
1608 }
1609
1610 fn user_authenticator(&self) -> &UserAuthenticator {
1611 &self.user_authenticator
1612 }
1613
1614 fn id(&self) -> SessionId {
1615 self.id
1616 }
1617
1618 async fn parse(
1619 self: Arc<Self>,
1620 statement: Option<Statement>,
1621 params_types: Vec<Option<DataType>>,
1622 ) -> std::result::Result<PrepareStatement, BoxedError> {
1623 Ok(if let Some(statement) = statement {
1624 handle_parse(self, statement, params_types).await?
1625 } else {
1626 PrepareStatement::Empty
1627 })
1628 }
1629
1630 fn bind(
1631 self: Arc<Self>,
1632 prepare_statement: PrepareStatement,
1633 params: Vec<Option<Bytes>>,
1634 param_formats: Vec<Format>,
1635 result_formats: Vec<Format>,
1636 ) -> std::result::Result<Portal, BoxedError> {
1637 Ok(handle_bind(
1638 prepare_statement,
1639 params,
1640 param_formats,
1641 result_formats,
1642 )?)
1643 }
1644
1645 async fn execute(
1646 self: Arc<Self>,
1647 portal: Portal,
1648 ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1649 let rsp = handle_execute(self, portal).await?;
1650 Ok(rsp)
1651 }
1652
1653 fn describe_statement(
1654 self: Arc<Self>,
1655 prepare_statement: PrepareStatement,
1656 ) -> std::result::Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
1657 Ok(match prepare_statement {
1658 PrepareStatement::Empty => (vec![], vec![]),
1659 PrepareStatement::Prepared(prepare_statement) => (
1660 prepare_statement.bound_result.param_types,
1661 infer(
1662 Some(prepare_statement.bound_result.bound),
1663 prepare_statement.statement,
1664 )?,
1665 ),
1666 PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?),
1667 })
1668 }
1669
1670 fn describe_portal(
1671 self: Arc<Self>,
1672 portal: Portal,
1673 ) -> std::result::Result<Vec<PgFieldDescriptor>, BoxedError> {
1674 match portal {
1675 Portal::Empty => Ok(vec![]),
1676 Portal::Portal(portal) => {
1677 let mut columns = infer(Some(portal.bound_result.bound), portal.statement)?;
1678 let formats = FormatIterator::new(&portal.result_formats, columns.len())?;
1679 columns.iter_mut().zip_eq_fast(formats).for_each(|(c, f)| {
1680 if f == Format::Binary {
1681 c.set_to_binary()
1682 }
1683 });
1684 Ok(columns)
1685 }
1686 Portal::PureStatement(statement) => Ok(infer(None, statement)?),
1687 }
1688 }
1689
1690 fn get_config(&self, key: &str) -> std::result::Result<String, BoxedError> {
1691 self.config().get(key).map_err(Into::into)
1692 }
1693
1694 fn set_config(&self, key: &str, value: String) -> std::result::Result<String, BoxedError> {
1695 Self::set_config(self, key, value).map_err(Into::into)
1696 }
1697
1698 async fn next_notice(self: &Arc<Self>) -> String {
1699 std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
1700 .await
1701 .expect("notice channel should not be closed")
1702 }
1703
1704 fn transaction_status(&self) -> TransactionStatus {
1705 match &*self.txn.lock() {
1706 transaction::State::Initial | transaction::State::Implicit(_) => {
1707 TransactionStatus::Idle
1708 }
1709 transaction::State::Explicit(_) => TransactionStatus::InTransaction,
1710 }
1712 }
1713
1714 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
1716 let exec_context = Arc::new(ExecContext {
1717 running_sql: sql,
1718 last_instant: Instant::now(),
1719 last_idle_instant: self.last_idle_instant.clone(),
1720 });
1721 *self.exec_context.lock() = Some(Arc::downgrade(&exec_context));
1722 *self.last_idle_instant.lock() = None;
1724 ExecContextGuard::new(exec_context)
1725 }
1726
1727 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
1730 if matches!(self.transaction_status(), TransactionStatus::InTransaction) {
1732 let idle_in_transaction_session_timeout =
1733 self.config().idle_in_transaction_session_timeout() as u128;
1734 if idle_in_transaction_session_timeout != 0 {
1736 let guard = self.exec_context.lock();
1738 if guard.as_ref().and_then(|weak| weak.upgrade()).is_none() {
1740 if let Some(elapse_since_last_idle_instant) =
1742 self.elapse_since_last_idle_instant()
1743 && elapse_since_last_idle_instant > idle_in_transaction_session_timeout
1744 {
1745 return Err(PsqlError::IdleInTxnTimeout);
1746 }
1747 }
1748 }
1749 }
1750 Ok(())
1751 }
1752}
1753
1754fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
1756 match stmt {
1757 Statement::Query(_)
1758 | Statement::Insert { .. }
1759 | Statement::Delete { .. }
1760 | Statement::Update { .. }
1761 | Statement::FetchCursor { .. } => Ok(bound
1762 .unwrap()
1763 .output_fields()
1764 .iter()
1765 .map(to_pg_field)
1766 .collect()),
1767 Statement::ShowObjects {
1768 object: show_object,
1769 ..
1770 } => Ok(infer_show_object(&show_object)),
1771 Statement::ShowCreateObject { .. } => Ok(infer_show_create_object()),
1772 Statement::ShowTransactionIsolationLevel => {
1773 let name = "transaction_isolation";
1774 Ok(infer_show_variable(name))
1775 }
1776 Statement::ShowVariable { variable } => {
1777 let name = &variable[0].real_value().to_lowercase();
1778 Ok(infer_show_variable(name))
1779 }
1780 Statement::Describe { name: _, kind } => Ok(infer_describe(&kind)),
1781 Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new(
1782 "QUERY PLAN".to_owned(),
1783 DataType::Varchar.to_oid(),
1784 DataType::Varchar.type_len(),
1785 )]),
1786 _ => Ok(vec![]),
1787 }
1788}
1789
/// A `(worker id, process id)` pair whose textual form is `"<worker_id>:<process_id>"`.
pub struct WorkerProcessId {
    pub worker_id: u32,
    pub process_id: i32,
}

impl WorkerProcessId {
    /// Creates an id from its two components.
    pub fn new(worker_id: u32, process_id: i32) -> Self {
        Self {
            worker_id,
            process_id,
        }
    }
}

impl Display for WorkerProcessId {
    /// Writes the id as `"<worker_id>:<process_id>"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            worker_id,
            process_id,
        } = self;
        write!(f, "{}:{}", worker_id, process_id)
    }
}

impl TryFrom<String> for WorkerProcessId {
    type Error = String;

    /// Parses `"<worker_id>:<process_id>"`, where the first part must be a `u32` and the second
    /// an `i32`. Any other input yields the error string `"invalid WorkerProcessId"`.
    fn try_from(worker_process_id: String) -> std::result::Result<Self, Self::Error> {
        const INVALID: &str = "invalid WorkerProcessId";
        let invalid = || INVALID.to_owned();

        // Splitting at the first `:` is enough: a second `:` would end up in the process part
        // and make its integer parse fail.
        let (worker_part, process_part) = worker_process_id
            .split_once(':')
            .ok_or_else(|| invalid())?;
        let worker_id = worker_part.parse::<u32>().map_err(|_| invalid())?;
        let process_id = process_part.parse::<i32>().map_err(|_| invalid())?;
        Ok(WorkerProcessId::new(worker_id, process_id))
    }
}
1828
1829pub fn cancel_queries_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1830 let guard = sessions_map.read();
1831 if let Some(session) = guard.get(&session_id) {
1832 session.cancel_current_query();
1833 true
1834 } else {
1835 info!("Current session finished, ignoring cancel query request");
1836 false
1837 }
1838}
1839
1840pub fn cancel_creating_jobs_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1841 let guard = sessions_map.read();
1842 if let Some(session) = guard.get(&session_id) {
1843 session.cancel_current_creating_job();
1844 true
1845 } else {
1846 info!("Current session finished, ignoring cancel creating request");
1847 false
1848 }
1849}