risingwave_frontend/
session.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::io::{Error, ErrorKind};
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::atomic::{AtomicI32, Ordering};
20use std::sync::{Arc, Weak};
21use std::time::{Duration, Instant};
22
23use anyhow::anyhow;
24use bytes::Bytes;
25use either::Either;
26use itertools::Itertools;
27use parking_lot::{Mutex, RwLock, RwLockReadGuard};
28use pgwire::error::{PsqlError, PsqlResult};
29use pgwire::net::{Address, AddressRef};
30use pgwire::pg_field_descriptor::PgFieldDescriptor;
31use pgwire::pg_message::TransactionStatus;
32use pgwire::pg_response::{PgResponse, StatementType};
33use pgwire::pg_server::{
34    BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
35    UserAuthenticator,
36};
37use pgwire::types::{Format, FormatIterator};
38use prometheus_http_query::Client as PrometheusClient;
39use rand::RngCore;
40use risingwave_batch::monitor::{BatchSpillMetrics, GLOBAL_BATCH_SPILL_METRICS};
41use risingwave_batch::spill::spill_op::SpillOp;
42use risingwave_batch::task::{ShutdownSender, ShutdownToken};
43use risingwave_batch::worker_manager::worker_node_manager::{
44    WorkerNodeManager, WorkerNodeManagerRef,
45};
46use risingwave_common::acl::AclMode;
47#[cfg(test)]
48use risingwave_common::catalog::{
49    DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
50};
51use risingwave_common::config::{
52    AuthMethod, BatchConfig, ConnectionType, FrontendConfig, MetaConfig, MetricLevel,
53    StreamingConfig, UdfConfig, load_config,
54};
55use risingwave_common::id::WorkerId;
56use risingwave_common::memory::MemoryContext;
57use risingwave_common::secret::LocalSecretManager;
58use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode};
59use risingwave_common::system_param::local_manager::{
60    LocalSystemParamsManager, LocalSystemParamsManagerRef,
61};
62use risingwave_common::telemetry::manager::TelemetryManager;
63use risingwave_common::telemetry::telemetry_env_enabled;
64use risingwave_common::types::DataType;
65use risingwave_common::util::addr::HostAddr;
66use risingwave_common::util::cluster_limit;
67use risingwave_common::util::cluster_limit::ActorCountPerParallelism;
68use risingwave_common::util::iter_util::ZipEqFast;
69use risingwave_common::util::pretty_bytes::convert;
70use risingwave_common::util::runtime::BackgroundShutdownRuntime;
71use risingwave_common::{GIT_SHA, RW_VERSION};
72use risingwave_common_heap_profiling::HeapProfiler;
73use risingwave_common_service::{MetricsManager, ObserverManager};
74use risingwave_connector::source::monitor::{GLOBAL_SOURCE_METRICS, SourceMetrics};
75use risingwave_pb::common::WorkerType;
76use risingwave_pb::common::worker_node::Property as AddWorkerNodeProperty;
77use risingwave_pb::frontend_service::frontend_service_server::FrontendServiceServer;
78use risingwave_pb::health::health_server::HealthServer;
79use risingwave_pb::user::auth_info::EncryptionType;
80use risingwave_rpc_client::{
81    ComputeClientPool, ComputeClientPoolRef, FrontendClientPool, FrontendClientPoolRef, MetaClient,
82};
83use risingwave_sqlparser::ast::{ObjectName, Statement};
84use risingwave_sqlparser::parser::Parser;
85use thiserror::Error;
86use thiserror_ext::AsReport;
87use tokio::runtime::Builder;
88use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
89use tokio::sync::oneshot::Sender;
90use tokio::sync::watch;
91use tokio::task::JoinHandle;
92use tracing::{error, info};
93
94use self::cursor_manager::CursorManager;
95use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError};
96use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl};
97use crate::catalog::connection_catalog::ConnectionCatalog;
98use crate::catalog::root_catalog::{Catalog, SchemaPath};
99use crate::catalog::secret_catalog::SecretCatalog;
100use crate::catalog::source_catalog::SourceCatalog;
101use crate::catalog::subscription_catalog::SubscriptionCatalog;
102use crate::catalog::{
103    CatalogError, CatalogErrorInner, DatabaseId, OwnedByUserCatalog, SchemaId, TableId,
104    check_schema_writable,
105};
106use crate::error::{
107    ErrorCode, Result, RwError, bail_catalog_error, bail_permission_denied, bail_protocol_error,
108};
109use crate::handler::describe::infer_describe;
110use crate::handler::extended_handle::{
111    Portal, PrepareStatement, handle_bind, handle_execute, handle_parse,
112};
113use crate::handler::privilege::ObjectCheckItem;
114use crate::handler::show::{infer_show_create_object, infer_show_object};
115use crate::handler::util::to_pg_field;
116use crate::handler::variable::infer_show_variable;
117use crate::handler::{RwPgResponse, handle};
118use crate::health_service::HealthServiceImpl;
119use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl};
120use crate::monitor::{CursorMetrics, FrontendMetrics, GLOBAL_FRONTEND_METRICS};
121use crate::observer::FrontendObserverNode;
122use crate::rpc::FrontendServiceImpl;
123use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef};
124use crate::scheduler::{
125    DistributedQueryMetrics, GLOBAL_DISTRIBUTED_QUERY_METRICS, HummockSnapshotManager,
126    HummockSnapshotManagerRef, QueryManager,
127};
128use crate::telemetry::FrontendTelemetryCreator;
129use crate::user::UserId;
130use crate::user::user_authentication::md5_hash_with_salt;
131use crate::user::user_manager::UserInfoManager;
132use crate::user::user_service::{UserInfoReader, UserInfoWriter, UserInfoWriterImpl};
133use crate::{FrontendOpts, PgResponseStream, TableCatalog};
134
135pub(crate) mod current;
136pub(crate) mod cursor_manager;
137pub(crate) mod transaction;
138
/// The global environment for the frontend server.
///
/// It bundles the clients, managers, metrics and configuration that sessions of this frontend
/// node need. It is built by `FrontendEnv::init` for a real node and by `FrontendEnv::mock`
/// for tests.
#[derive(Clone)]
pub(crate) struct FrontendEnv {
    /// Client used to talk to the meta service.
    meta_client: Arc<dyn FrontendMetaClient>,
    // Different sessions may access the catalog at the same time, so the catalog is protected
    // by a `RwLock`.
    /// Write side of the catalog; only reachable through a `transaction::WriteGuard`.
    catalog_writer: Arc<dyn CatalogWriter>,
    /// Read side of the catalog.
    catalog_reader: CatalogReader,
    /// Write side of the user info; only reachable through a `transaction::WriteGuard`.
    user_info_writer: Arc<dyn UserInfoWriter>,
    /// Read side of the user info.
    user_info_reader: UserInfoReader,
    worker_node_manager: WorkerNodeManagerRef,
    query_manager: QueryManager,
    hummock_snapshot_manager: HummockSnapshotManagerRef,
    system_params_manager: LocalSystemParamsManagerRef,
    /// Session parameters of this node. `init` starts from `SessionConfig::default()` and hands
    /// the same handle to the observer node.
    session_params: Arc<RwLock<SessionConfig>>,

    /// Address of this frontend node (`init` stores the advertise address here).
    server_addr: HostAddr,
    /// Pool of clients to compute nodes.
    client_pool: ComputeClientPoolRef,
    /// Pool of clients to frontend nodes.
    frontend_client_pool: FrontendClientPoolRef,

    /// Each session is identified by (`process_id`, `secret_key`). When a Cancel Request is
    /// received, the corresponding session is looked up here and all of its running queries
    /// are cancelled.
    sessions_map: SessionMapRef,

    pub frontend_metrics: Arc<FrontendMetrics>,

    pub cursor_metrics: Arc<CursorMetrics>,

    source_metrics: Arc<SourceMetrics>,

    /// Batch spill metrics.
    spill_metrics: Arc<BatchSpillMetrics>,

    batch_config: BatchConfig,
    frontend_config: FrontendConfig,
    // Stored but not read in this file, hence the `expect(dead_code)`.
    #[expect(dead_code)]
    meta_config: MetaConfig,
    streaming_config: StreamingConfig,
    udf_config: UdfConfig,

    /// Tracks streaming jobs that are being created, so that they can be cancelled when a
    /// cancel request is received.
    creating_streaming_job_tracker: StreamingJobTrackerRef,

    /// Runtime for compute intensive tasks in frontend, e.g. executors in local mode,
    /// root stage in mpp mode.
    compute_runtime: Arc<BackgroundShutdownRuntime>,

    /// Memory context used for batch executors in frontend.
    mem_context: MemoryContext,

    /// Address of the serverless backfill controller.
    serverless_backfill_controller_addr: String,

    /// Prometheus client for querying metrics. `None` when no endpoint is configured or the
    /// client could not be created.
    prometheus_client: Option<PrometheusClient>,

    /// The additional selector used when querying Prometheus.
    prometheus_selector: String,
}
200
/// Shared, `RwLock`-protected map of sessions, keyed by `(process_id, secret_key)`.
pub type SessionMapRef = Arc<RwLock<HashMap<(i32, i32), Arc<SessionImpl>>>>;
203
/// The proportion of frontend memory used for batch processing.
///
/// `FrontendEnv::init` multiplies the configured total frontend memory by this factor to get
/// the limit of the root batch memory context.
const FRONTEND_BATCH_MEMORY_PROPORTION: f64 = 0.5;
206
207impl FrontendEnv {
208    pub fn mock() -> Self {
209        use crate::test_utils::{MockCatalogWriter, MockFrontendMetaClient, MockUserInfoWriter};
210
211        let catalog = Arc::new(RwLock::new(Catalog::default()));
212        let meta_client = Arc::new(MockFrontendMetaClient {});
213        let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone()));
214        let catalog_writer = Arc::new(MockCatalogWriter::new(
215            catalog.clone(),
216            hummock_snapshot_manager.clone(),
217        ));
218        let catalog_reader = CatalogReader::new(catalog);
219        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
220        let user_info_writer = Arc::new(MockUserInfoWriter::new(user_info_manager.clone()));
221        let user_info_reader = UserInfoReader::new(user_info_manager);
222        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
223        let system_params_manager = Arc::new(LocalSystemParamsManager::for_test());
224        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
225        let frontend_client_pool = Arc::new(FrontendClientPool::for_test());
226        let query_manager = QueryManager::new(
227            worker_node_manager.clone(),
228            compute_client_pool,
229            catalog_reader.clone(),
230            Arc::new(DistributedQueryMetrics::for_test()),
231            None,
232            None,
233        );
234        let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap();
235        let client_pool = Arc::new(ComputeClientPool::for_test());
236        let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone());
237        let runtime = {
238            let mut builder = Builder::new_multi_thread();
239            if let Some(frontend_compute_runtime_worker_threads) =
240                load_config("", FrontendOpts::default())
241                    .batch
242                    .frontend_compute_runtime_worker_threads
243            {
244                builder.worker_threads(frontend_compute_runtime_worker_threads);
245            }
246            builder
247                .thread_name("rw-batch-local")
248                .enable_all()
249                .build()
250                .unwrap()
251        };
252        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));
253        let sessions_map = Arc::new(RwLock::new(HashMap::new()));
254        Self {
255            meta_client,
256            catalog_writer,
257            catalog_reader,
258            user_info_writer,
259            user_info_reader,
260            worker_node_manager,
261            query_manager,
262            hummock_snapshot_manager,
263            system_params_manager,
264            session_params: Default::default(),
265            server_addr,
266            client_pool,
267            frontend_client_pool,
268            sessions_map,
269            frontend_metrics: Arc::new(FrontendMetrics::for_test()),
270            cursor_metrics: Arc::new(CursorMetrics::for_test()),
271            batch_config: BatchConfig::default(),
272            frontend_config: FrontendConfig::default(),
273            meta_config: MetaConfig::default(),
274            streaming_config: StreamingConfig::default(),
275            udf_config: UdfConfig::default(),
276            source_metrics: Arc::new(SourceMetrics::default()),
277            spill_metrics: BatchSpillMetrics::for_test(),
278            creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
279            compute_runtime,
280            mem_context: MemoryContext::none(),
281            serverless_backfill_controller_addr: Default::default(),
282            prometheus_client: None,
283            prometheus_selector: String::new(),
284        }
285    }
286
287    pub async fn init(opts: FrontendOpts) -> Result<(Self, Vec<JoinHandle<()>>, Vec<Sender<()>>)> {
288        let config = load_config(&opts.config_path, &opts);
289        info!("Starting frontend node");
290        info!("> config: {:?}", config);
291        info!(
292            "> debug assertions: {}",
293            if cfg!(debug_assertions) { "on" } else { "off" }
294        );
295        info!("> version: {} ({})", RW_VERSION, GIT_SHA);
296
297        let frontend_address: HostAddr = opts
298            .advertise_addr
299            .as_ref()
300            .unwrap_or_else(|| {
301                tracing::warn!("advertise addr is not specified, defaulting to listen_addr");
302                &opts.listen_addr
303            })
304            .parse()
305            .unwrap();
306        info!("advertise addr is {}", frontend_address);
307
308        let rpc_addr: HostAddr = opts.frontend_rpc_listener_addr.parse().unwrap();
309        let internal_rpc_host_addr = HostAddr {
310            // Use the host of advertise address for the frontend rpc address.
311            host: frontend_address.host.clone(),
312            port: rpc_addr.port,
313        };
314        // Register in meta by calling `AddWorkerNode` RPC.
315        let (meta_client, system_params_reader) = MetaClient::register_new(
316            opts.meta_addr,
317            WorkerType::Frontend,
318            &frontend_address,
319            AddWorkerNodeProperty {
320                internal_rpc_host_addr: internal_rpc_host_addr.to_string(),
321                ..Default::default()
322            },
323            Arc::new(config.meta.clone()),
324        )
325        .await;
326
327        let worker_id = meta_client.worker_id();
328        info!("Assigned worker node id {}", worker_id);
329
330        let (heartbeat_join_handle, heartbeat_shutdown_sender) = MetaClient::start_heartbeat_loop(
331            meta_client.clone(),
332            Duration::from_millis(config.server.heartbeat_interval_ms as u64),
333        );
334        let mut join_handles = vec![heartbeat_join_handle];
335        let mut shutdown_senders = vec![heartbeat_shutdown_sender];
336
337        let frontend_meta_client = Arc::new(FrontendMetaClientImpl(meta_client.clone()));
338        let hummock_snapshot_manager =
339            Arc::new(HummockSnapshotManager::new(frontend_meta_client.clone()));
340
341        let (catalog_updated_tx, catalog_updated_rx) = watch::channel(0);
342        let catalog = Arc::new(RwLock::new(Catalog::default()));
343        let catalog_writer = Arc::new(CatalogWriterImpl::new(
344            meta_client.clone(),
345            catalog_updated_rx.clone(),
346            hummock_snapshot_manager.clone(),
347        ));
348        let catalog_reader = CatalogReader::new(catalog.clone());
349
350        let worker_node_manager = Arc::new(WorkerNodeManager::new());
351
352        let compute_client_pool = Arc::new(ComputeClientPool::new(
353            config.batch_exchange_connection_pool_size(),
354            config.batch.developer.compute_client_config.clone(),
355        ));
356        let frontend_client_pool = Arc::new(FrontendClientPool::new(
357            1,
358            config.batch.developer.frontend_client_config.clone(),
359        ));
360        let query_manager = QueryManager::new(
361            worker_node_manager.clone(),
362            compute_client_pool.clone(),
363            catalog_reader.clone(),
364            Arc::new(GLOBAL_DISTRIBUTED_QUERY_METRICS.clone()),
365            config.batch.distributed_query_limit,
366            config.batch.max_batch_queries_per_frontend_node,
367        );
368
369        let user_info_manager = Arc::new(RwLock::new(UserInfoManager::default()));
370        let user_info_reader = UserInfoReader::new(user_info_manager.clone());
371        let user_info_writer = Arc::new(UserInfoWriterImpl::new(
372            meta_client.clone(),
373            catalog_updated_rx,
374        ));
375
376        let system_params_manager =
377            Arc::new(LocalSystemParamsManager::new(system_params_reader.clone()));
378
379        LocalSecretManager::init(
380            opts.temp_secret_file_dir,
381            meta_client.cluster_id().to_owned(),
382            worker_id,
383        );
384
385        // This `session_params` should be initialized during the initial notification in `observer_manager`
386        let session_params = Arc::new(RwLock::new(SessionConfig::default()));
387        let sessions_map: SessionMapRef = Arc::new(RwLock::new(HashMap::new()));
388        let cursor_metrics = Arc::new(CursorMetrics::init(sessions_map.clone()));
389
390        let frontend_observer_node = FrontendObserverNode::new(
391            worker_node_manager.clone(),
392            catalog,
393            catalog_updated_tx,
394            user_info_manager,
395            hummock_snapshot_manager.clone(),
396            system_params_manager.clone(),
397            session_params.clone(),
398            compute_client_pool.clone(),
399        );
400        let observer_manager =
401            ObserverManager::new_with_meta_client(meta_client.clone(), frontend_observer_node)
402                .await;
403        let observer_join_handle = observer_manager.start().await;
404        join_handles.push(observer_join_handle);
405
406        meta_client.activate(&frontend_address).await?;
407
408        let frontend_metrics = Arc::new(GLOBAL_FRONTEND_METRICS.clone());
409        let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
410        let spill_metrics = Arc::new(GLOBAL_BATCH_SPILL_METRICS.clone());
411
412        if config.server.metrics_level > MetricLevel::Disabled {
413            MetricsManager::boot_metrics_service(opts.prometheus_listener_addr.clone());
414        }
415
416        let health_srv = HealthServiceImpl::new();
417        let frontend_srv = FrontendServiceImpl::new(sessions_map.clone());
418        let frontend_rpc_addr = opts.frontend_rpc_listener_addr.parse().unwrap();
419
420        let telemetry_manager = TelemetryManager::new(
421            Arc::new(meta_client.clone()),
422            Arc::new(FrontendTelemetryCreator::new()),
423        );
424
425        // if the toml config file or env variable disables telemetry, do not watch system params
426        // change because if any of configs disable telemetry, we should never start it
427        if config.server.telemetry_enabled && telemetry_env_enabled() {
428            let (join_handle, shutdown_sender) = telemetry_manager.start().await;
429            join_handles.push(join_handle);
430            shutdown_senders.push(shutdown_sender);
431        } else {
432            tracing::info!("Telemetry didn't start due to config");
433        }
434
435        tokio::spawn(async move {
436            tonic::transport::Server::builder()
437                .add_service(HealthServer::new(health_srv))
438                .add_service(FrontendServiceServer::new(frontend_srv))
439                .serve(frontend_rpc_addr)
440                .await
441                .unwrap();
442        });
443        info!(
444            "Health Check RPC Listener is set up on {}",
445            opts.frontend_rpc_listener_addr.clone()
446        );
447
448        let creating_streaming_job_tracker =
449            Arc::new(StreamingJobTracker::new(frontend_meta_client.clone()));
450
451        let runtime = {
452            let mut builder = Builder::new_multi_thread();
453            if let Some(frontend_compute_runtime_worker_threads) =
454                config.batch.frontend_compute_runtime_worker_threads
455            {
456                builder.worker_threads(frontend_compute_runtime_worker_threads);
457            }
458            builder
459                .thread_name("rw-batch-local")
460                .enable_all()
461                .build()
462                .unwrap()
463        };
464        let compute_runtime = Arc::new(BackgroundShutdownRuntime::from(runtime));
465
466        let sessions = sessions_map.clone();
467        // Idle transaction background monitor
468        let join_handle = tokio::spawn(async move {
469            let mut check_idle_txn_interval =
470                tokio::time::interval(core::time::Duration::from_secs(10));
471            check_idle_txn_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
472            check_idle_txn_interval.reset();
473            loop {
474                check_idle_txn_interval.tick().await;
475                sessions.read().values().for_each(|session| {
476                    let _ = session.check_idle_in_transaction_timeout();
477                })
478            }
479        });
480        join_handles.push(join_handle);
481
482        // Clean up the spill directory.
483        #[cfg(not(madsim))]
484        if config.batch.enable_spill {
485            SpillOp::clean_spill_directory()
486                .await
487                .map_err(|err| anyhow!(err))?;
488        }
489
490        let total_memory_bytes = opts.frontend_total_memory_bytes;
491        let heap_profiler =
492            HeapProfiler::new(total_memory_bytes, config.server.heap_profiling.clone());
493        // Run a background heap profiler
494        heap_profiler.start();
495
496        let batch_memory_limit = total_memory_bytes as f64 * FRONTEND_BATCH_MEMORY_PROPORTION;
497        let mem_context = MemoryContext::root(
498            frontend_metrics.batch_total_mem.clone(),
499            batch_memory_limit as u64,
500        );
501
502        // Initialize Prometheus client if endpoint is provided
503        let prometheus_client = if let Some(ref endpoint) = opts.prometheus_endpoint {
504            match PrometheusClient::try_from(endpoint.as_str()) {
505                Ok(client) => {
506                    info!("Prometheus client initialized with endpoint: {}", endpoint);
507                    Some(client)
508                }
509                Err(e) => {
510                    error!(
511                        error = %e.as_report(),
512                        "failed to initialize Prometheus client",
513                    );
514                    None
515                }
516            }
517        } else {
518            None
519        };
520
521        // Initialize Prometheus selector
522        let prometheus_selector = opts.prometheus_selector.unwrap_or_default();
523
524        info!(
525            "Frontend  total_memory: {} batch_memory: {}",
526            convert(total_memory_bytes as _),
527            convert(batch_memory_limit as _),
528        );
529
530        Ok((
531            Self {
532                catalog_reader,
533                catalog_writer,
534                user_info_reader,
535                user_info_writer,
536                worker_node_manager,
537                meta_client: frontend_meta_client,
538                query_manager,
539                hummock_snapshot_manager,
540                system_params_manager,
541                session_params,
542                server_addr: frontend_address,
543                client_pool: compute_client_pool,
544                frontend_client_pool,
545                frontend_metrics,
546                cursor_metrics,
547                spill_metrics,
548                sessions_map,
549                batch_config: config.batch,
550                frontend_config: config.frontend,
551                meta_config: config.meta,
552                streaming_config: config.streaming,
553                serverless_backfill_controller_addr: opts.serverless_backfill_controller_addr,
554                udf_config: config.udf,
555                source_metrics,
556                creating_streaming_job_tracker,
557                compute_runtime,
558                mem_context,
559                prometheus_client,
560                prometheus_selector,
561            },
562            join_handles,
563            shutdown_senders,
564        ))
565    }
566
567    /// Get a reference to the frontend env's catalog writer.
568    ///
569    /// This method is intentionally private, and a write guard is required for the caller to
570    /// prove that the write operations are permitted in the current transaction.
571    fn catalog_writer(&self, _guard: transaction::WriteGuard) -> &dyn CatalogWriter {
572        &*self.catalog_writer
573    }
574
    /// Returns a reference to the catalog reader of this frontend env.
    pub fn catalog_reader(&self) -> &CatalogReader {
        &self.catalog_reader
    }
579
580    /// Get a reference to the frontend env's user info writer.
581    ///
582    /// This method is intentionally private, and a write guard is required for the caller to
583    /// prove that the write operations are permitted in the current transaction.
584    fn user_info_writer(&self, _guard: transaction::WriteGuard) -> &dyn UserInfoWriter {
585        &*self.user_info_writer
586    }
587
    /// Returns a reference to the user info reader of this frontend env.
    pub fn user_info_reader(&self) -> &UserInfoReader {
        &self.user_info_reader
    }
592
593    pub fn worker_node_manager_ref(&self) -> WorkerNodeManagerRef {
594        self.worker_node_manager.clone()
595    }
596
597    pub fn meta_client(&self) -> &dyn FrontendMetaClient {
598        &*self.meta_client
599    }
600
601    pub fn meta_client_ref(&self) -> Arc<dyn FrontendMetaClient> {
602        self.meta_client.clone()
603    }
604
    /// Returns a reference to the query manager.
    pub fn query_manager(&self) -> &QueryManager {
        &self.query_manager
    }
608
    /// Returns a reference to the shared handle of the Hummock snapshot manager.
    pub fn hummock_snapshot_manager(&self) -> &HummockSnapshotManagerRef {
        &self.hummock_snapshot_manager
    }
612
    /// Returns a reference to the shared handle of the local system params manager.
    pub fn system_params_manager(&self) -> &LocalSystemParamsManagerRef {
        &self.system_params_manager
    }
616
617    pub fn session_params_snapshot(&self) -> SessionConfig {
618        self.session_params.read_recursive().clone()
619    }
620
    /// Returns the address of the serverless backfill controller.
    pub fn sbc_address(&self) -> &String {
        &self.serverless_backfill_controller_addr
    }
624
    /// Returns a reference to the Prometheus client, or `None` if this env has no client.
    pub fn prometheus_client(&self) -> Option<&PrometheusClient> {
        self.prometheus_client.as_ref()
    }
629
630    /// Get the Prometheus selector string.
631    pub fn prometheus_selector(&self) -> &str {
632        &self.prometheus_selector
633    }
634
    /// Returns the address of this frontend server.
    pub fn server_address(&self) -> &HostAddr {
        &self.server_addr
    }
638
639    pub fn client_pool(&self) -> ComputeClientPoolRef {
640        self.client_pool.clone()
641    }
642
643    pub fn frontend_client_pool(&self) -> FrontendClientPoolRef {
644        self.frontend_client_pool.clone()
645    }
646
    /// Returns a reference to the batch configuration.
    pub fn batch_config(&self) -> &BatchConfig {
        &self.batch_config
    }
650
    /// Returns a reference to the frontend configuration.
    pub fn frontend_config(&self) -> &FrontendConfig {
        &self.frontend_config
    }
654
    /// Returns a reference to the streaming configuration.
    pub fn streaming_config(&self) -> &StreamingConfig {
        &self.streaming_config
    }
658
    /// Returns a reference to the UDF configuration.
    pub fn udf_config(&self) -> &UdfConfig {
        &self.udf_config
    }
662
663    pub fn source_metrics(&self) -> Arc<SourceMetrics> {
664        self.source_metrics.clone()
665    }
666
667    pub fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
668        self.spill_metrics.clone()
669    }
670
    /// Returns a reference to the shared handle of the tracker for streaming jobs that are
    /// being created.
    pub fn creating_streaming_job_tracker(&self) -> &StreamingJobTrackerRef {
        &self.creating_streaming_job_tracker
    }
674
    /// Returns a reference to the shared session map, keyed by `(process_id, secret_key)`.
    pub fn sessions_map(&self) -> &SessionMapRef {
        &self.sessions_map
    }
678
679    pub fn compute_runtime(&self) -> Arc<BackgroundShutdownRuntime> {
680        self.compute_runtime.clone()
681    }
682
683    /// Cancel queries (i.e. batch queries) in session.
684    /// If the session exists return true, otherwise, return false.
685    pub fn cancel_queries_in_session(&self, session_id: SessionId) -> bool {
686        cancel_queries_in_session(session_id, self.sessions_map.clone())
687    }
688
689    /// Cancel creating jobs (i.e. streaming queries) in session.
690    /// If the session exists return true, otherwise, return false.
691    pub fn cancel_creating_jobs_in_session(&self, session_id: SessionId) -> bool {
692        cancel_creating_jobs_in_session(session_id, self.sessions_map.clone())
693    }
694
    /// Returns a clone of the memory context used for batch executors in the frontend.
    pub fn mem_context(&self) -> MemoryContext {
        self.mem_context.clone()
    }
698}
699
/// The database, user name and user id associated with a session.
#[derive(Clone)]
pub struct AuthContext {
    /// Name of the database.
    pub database: String,
    /// Name of the user.
    pub user_name: String,
    /// Id of the user.
    pub user_id: UserId,
}
706
707impl AuthContext {
708    pub fn new(database: String, user_name: String, user_id: UserId) -> Self {
709        Self {
710            database,
711            user_name,
712            user_id,
713        }
714    }
715}
/// State of a single client session of the frontend.
pub struct SessionImpl {
    /// The global frontend environment this session runs in.
    env: FrontendEnv,
    /// Database / user the session is associated with.
    auth_context: Arc<RwLock<AuthContext>>,
    /// Used for user authentication.
    user_authenticator: UserAuthenticator,
    /// Stores the value of configurations.
    config_map: Arc<RwLock<SessionConfig>>,

    /// Channel sender for frontend handler to send notices.
    notice_tx: UnboundedSender<String>,
    /// Channel receiver for pgwire to take notices and send to clients.
    notice_rx: Mutex<UnboundedReceiver<String>>,

    /// Identified by `process_id`, `secret_key`. Corresponds to `SessionManager`.
    id: (i32, i32),

    /// Client address
    peer_addr: AddressRef,

    /// Transaction state.
    /// TODO: get rid of the `Mutex` here. It only exists as a workaround for the `Send`
    /// requirement of async functions; there should actually be no contention.
    txn: Arc<Mutex<transaction::State>>,

    /// Query cancel flag.
    /// This flag is set only when current query is executed in local mode, and used to cancel
    /// local query.
    current_query_cancel_flag: Mutex<Option<ShutdownSender>>,

    /// Execution context; represents the lifetime of a running SQL in the current session.
    exec_context: Mutex<Option<Weak<ExecContext>>>,

    /// Last idle instant
    last_idle_instant: Arc<Mutex<Option<Instant>>>,

    /// Cursor manager of the session.
    cursor_manager: Arc<CursorManager>,

    /// Temporary sources for the current session.
    temporary_source_manager: Arc<Mutex<TemporarySourceManager>>,

    /// Staging catalogs for the current session.
    staging_catalog_manager: Arc<Mutex<StagingCatalogManager>>,
}
759
/// Holds the temporary sources of a session, keyed by source name.
///
/// If TEMPORARY or TEMP is specified, the source is created as a temporary source.
/// Temporary sources are automatically dropped at the end of a session.
/// Temporary sources are expected to be selected by batch queries, not streaming queries.
/// Temporary sources are currently only used by the cloud portal to preview the data during
/// table and source creation, so this is an internal feature and not exposed to users.
/// Temporary sources are supported with minimum effort: the database name and schema name are
/// ignored, only the source name matters.
/// Temporary sources can only be shown via the `show sources` command, not via other system
/// tables.
#[derive(Default, Clone)]
pub struct TemporarySourceManager {
    /// Temporary sources keyed by source name.
    sources: HashMap<String, SourceCatalog>,
}
772
773impl TemporarySourceManager {
774    pub fn new() -> Self {
775        Self {
776            sources: HashMap::new(),
777        }
778    }
779
780    pub fn create_source(&mut self, name: String, source: SourceCatalog) {
781        self.sources.insert(name, source);
782    }
783
784    pub fn drop_source(&mut self, name: &str) {
785        self.sources.remove(name);
786    }
787
788    pub fn get_source(&self, name: &str) -> Option<&SourceCatalog> {
789        self.sources.get(name)
790    }
791
792    pub fn keys(&self) -> Vec<String> {
793        self.sources.keys().cloned().collect()
794    }
795}
796
/// Staging catalog manager tracks the tables that are being created in the current session.
#[derive(Default, Clone)]
pub struct StagingCatalogManager {
    /// Staging tables being created in the current session, keyed by the name they were
    /// registered under.
    tables: HashMap<String, TableCatalog>,
}
803
804impl StagingCatalogManager {
805    pub fn new() -> Self {
806        Self {
807            tables: HashMap::new(),
808        }
809    }
810
811    pub fn create_table(&mut self, name: String, table: TableCatalog) {
812        self.tables.insert(name, table);
813    }
814
815    pub fn drop_table(&mut self, name: &str) {
816        self.tables.remove(name);
817    }
818
819    pub fn get_table(&self, name: &str) -> Option<&TableCatalog> {
820        self.tables.get(name)
821    }
822}
823
/// Error returned by [`SessionImpl::check_relation_name_duplicated`].
///
/// Both variants display exactly as the error they wrap.
#[derive(Error, Debug)]
pub enum CheckRelationError {
    /// The object name could not be resolved into a schema-qualified name.
    #[error("{0}")]
    Resolve(#[from] ResolveQualifiedNameError),
    /// The catalog lookup or duplication check failed.
    #[error("{0}")]
    Catalog(#[from] CatalogError),
}
831
832impl From<CheckRelationError> for RwError {
833    fn from(e: CheckRelationError) -> Self {
834        match e {
835            CheckRelationError::Resolve(e) => e.into(),
836            CheckRelationError::Catalog(e) => e.into(),
837        }
838    }
839}
840
841impl SessionImpl {
842    pub(crate) fn new(
843        env: FrontendEnv,
844        auth_context: AuthContext,
845        user_authenticator: UserAuthenticator,
846        id: SessionId,
847        peer_addr: AddressRef,
848        session_config: SessionConfig,
849    ) -> Self {
850        let cursor_metrics = env.cursor_metrics.clone();
851        let (notice_tx, notice_rx) = mpsc::unbounded_channel();
852
853        Self {
854            env,
855            auth_context: Arc::new(RwLock::new(auth_context)),
856            user_authenticator,
857            config_map: Arc::new(RwLock::new(session_config)),
858            id,
859            peer_addr,
860            txn: Default::default(),
861            current_query_cancel_flag: Mutex::new(None),
862            notice_tx,
863            notice_rx: Mutex::new(notice_rx),
864            exec_context: Mutex::new(None),
865            last_idle_instant: Default::default(),
866            cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
867            temporary_source_manager: Default::default(),
868            staging_catalog_manager: Default::default(),
869        }
870    }
871
872    #[cfg(test)]
873    pub fn mock() -> Self {
874        let env = FrontendEnv::mock();
875        let (notice_tx, notice_rx) = mpsc::unbounded_channel();
876
877        Self {
878            env: FrontendEnv::mock(),
879            auth_context: Arc::new(RwLock::new(AuthContext::new(
880                DEFAULT_DATABASE_NAME.to_owned(),
881                DEFAULT_SUPER_USER.to_owned(),
882                DEFAULT_SUPER_USER_ID,
883            ))),
884            user_authenticator: UserAuthenticator::None,
885            config_map: Default::default(),
886            // Mock session use non-sense id.
887            id: (0, 0),
888            txn: Default::default(),
889            current_query_cancel_flag: Mutex::new(None),
890            notice_tx,
891            notice_rx: Mutex::new(notice_rx),
892            exec_context: Mutex::new(None),
893            peer_addr: Address::Tcp(SocketAddr::new(
894                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
895                8080,
896            ))
897            .into(),
898            last_idle_instant: Default::default(),
899            cursor_manager: Arc::new(CursorManager::new(env.cursor_metrics)),
900            temporary_source_manager: Default::default(),
901            staging_catalog_manager: Default::default(),
902        }
903    }
904
905    pub(crate) fn env(&self) -> &FrontendEnv {
906        &self.env
907    }
908
909    pub fn auth_context(&self) -> Arc<AuthContext> {
910        let ctx = self.auth_context.read();
911        Arc::new(ctx.clone())
912    }
913
914    pub fn database(&self) -> String {
915        self.auth_context.read().database.clone()
916    }
917
918    pub fn database_id(&self) -> DatabaseId {
919        let db_name = self.database();
920        self.env
921            .catalog_reader()
922            .read_guard()
923            .get_database_by_name(&db_name)
924            .map(|db| db.id())
925            .expect("session database not found")
926    }
927
928    pub fn user_name(&self) -> String {
929        self.auth_context.read().user_name.clone()
930    }
931
932    pub fn user_id(&self) -> UserId {
933        self.auth_context.read().user_id
934    }
935
936    pub fn update_database(&self, database: String) {
937        self.auth_context.write().database = database;
938    }
939
940    pub fn shared_config(&self) -> Arc<RwLock<SessionConfig>> {
941        Arc::clone(&self.config_map)
942    }
943
944    pub fn config(&self) -> RwLockReadGuard<'_, SessionConfig> {
945        self.config_map.read()
946    }
947
948    pub fn set_config(&self, key: &str, value: String) -> Result<String> {
949        self.config_map
950            .write()
951            .set(key, value, &mut ())
952            .map_err(Into::into)
953    }
954
955    pub fn reset_config(&self, key: &str) -> Result<String> {
956        self.config_map
957            .write()
958            .reset(key, &mut ())
959            .map_err(Into::into)
960    }
961
962    pub fn set_config_report(
963        &self,
964        key: &str,
965        value: Option<String>,
966        mut reporter: impl ConfigReporter,
967    ) -> Result<String> {
968        if let Some(value) = value {
969            self.config_map
970                .write()
971                .set(key, value, &mut reporter)
972                .map_err(Into::into)
973        } else {
974            self.config_map
975                .write()
976                .reset(key, &mut reporter)
977                .map_err(Into::into)
978        }
979    }
980
981    pub fn session_id(&self) -> SessionId {
982        self.id
983    }
984
985    pub fn running_sql(&self) -> Option<Arc<str>> {
986        self.exec_context
987            .lock()
988            .as_ref()
989            .and_then(|weak| weak.upgrade())
990            .map(|context| context.running_sql.clone())
991    }
992
993    pub fn get_cursor_manager(&self) -> Arc<CursorManager> {
994        self.cursor_manager.clone()
995    }
996
997    pub fn peer_addr(&self) -> &Address {
998        &self.peer_addr
999    }
1000
1001    pub fn elapse_since_running_sql(&self) -> Option<u128> {
1002        self.exec_context
1003            .lock()
1004            .as_ref()
1005            .and_then(|weak| weak.upgrade())
1006            .map(|context| context.last_instant.elapsed().as_millis())
1007    }
1008
1009    pub fn elapse_since_last_idle_instant(&self) -> Option<u128> {
1010        self.last_idle_instant
1011            .lock()
1012            .as_ref()
1013            .map(|x| x.elapsed().as_millis())
1014    }
1015
1016    pub fn check_relation_name_duplicated(
1017        &self,
1018        name: ObjectName,
1019        stmt_type: StatementType,
1020        if_not_exists: bool,
1021    ) -> std::result::Result<Either<(), RwPgResponse>, CheckRelationError> {
1022        let db_name = &self.database();
1023        let catalog_reader = self.env().catalog_reader().read_guard();
1024        let (schema_name, relation_name) = {
1025            let (schema_name, relation_name) =
1026                Binder::resolve_schema_qualified_name(db_name, &name)?;
1027            let search_path = self.config().search_path();
1028            let user_name = &self.user_name();
1029            let schema_name = match schema_name {
1030                Some(schema_name) => schema_name,
1031                None => catalog_reader
1032                    .first_valid_schema(db_name, &search_path, user_name)?
1033                    .name(),
1034            };
1035            (schema_name, relation_name)
1036        };
1037        match catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &relation_name) {
1038            Err(e) if if_not_exists => {
1039                if let CatalogErrorInner::Duplicated {
1040                    name,
1041                    under_creation,
1042                    ..
1043                } = e.inner()
1044                {
1045                    // If relation is created, return directly.
1046                    // Otherwise, the job status is `is_creating`. Since frontend receives the catalog asynchronously, we can't
1047                    // determine the real status of the meta at this time. We regard it as `not_exists` and delay the check to meta.
1048                    // Only the type in StreamingJob (defined in streaming_job.rs) and Subscription may be `is_creating`.
1049                    if !*under_creation {
1050                        Ok(Either::Right(
1051                            PgResponse::builder(stmt_type)
1052                                .notice(format!("relation \"{}\" already exists, skipping", name))
1053                                .into(),
1054                        ))
1055                    } else if stmt_type == StatementType::CREATE_SUBSCRIPTION {
1056                        // For now, when a Subscription is creating, we return directly with an additional message.
1057                        // TODO: Subscription should also be processed in the same way as StreamingJob.
1058                        Ok(Either::Right(
1059                            PgResponse::builder(stmt_type)
1060                                .notice(format!(
1061                                    "relation \"{}\" already exists but still creating, skipping",
1062                                    name
1063                                ))
1064                                .into(),
1065                        ))
1066                    } else {
1067                        Ok(Either::Left(()))
1068                    }
1069                } else {
1070                    Err(e.into())
1071                }
1072            }
1073            Err(e) => Err(e.into()),
1074            Ok(_) => Ok(Either::Left(())),
1075        }
1076    }
1077
1078    pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> {
1079        let db_name = &self.database();
1080        let catalog_reader = self.env().catalog_reader().read_guard();
1081        let (schema_name, secret_name) = {
1082            let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
1083            let search_path = self.config().search_path();
1084            let user_name = &self.user_name();
1085            let schema_name = match schema_name {
1086                Some(schema_name) => schema_name,
1087                None => catalog_reader
1088                    .first_valid_schema(db_name, &search_path, user_name)?
1089                    .name(),
1090            };
1091            (schema_name, secret_name)
1092        };
1093        catalog_reader
1094            .check_secret_name_duplicated(db_name, &schema_name, &secret_name)
1095            .map_err(RwError::from)
1096    }
1097
1098    pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> {
1099        let db_name = &self.database();
1100        let catalog_reader = self.env().catalog_reader().read_guard();
1101        let (schema_name, connection_name) = {
1102            let (schema_name, connection_name) =
1103                Binder::resolve_schema_qualified_name(db_name, &name)?;
1104            let search_path = self.config().search_path();
1105            let user_name = &self.user_name();
1106            let schema_name = match schema_name {
1107                Some(schema_name) => schema_name,
1108                None => catalog_reader
1109                    .first_valid_schema(db_name, &search_path, user_name)?
1110                    .name(),
1111            };
1112            (schema_name, connection_name)
1113        };
1114        catalog_reader
1115            .check_connection_name_duplicated(db_name, &schema_name, &connection_name)
1116            .map_err(RwError::from)
1117    }
1118
1119    pub fn check_function_name_duplicated(
1120        &self,
1121        stmt_type: StatementType,
1122        name: ObjectName,
1123        arg_types: &[DataType],
1124        if_not_exists: bool,
1125    ) -> Result<Either<(), RwPgResponse>> {
1126        let db_name = &self.database();
1127        let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
1128        let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?;
1129
1130        let catalog_reader = self.env().catalog_reader().read_guard();
1131        if catalog_reader
1132            .get_schema_by_id(database_id, schema_id)?
1133            .get_function_by_name_args(&function_name, arg_types)
1134            .is_some()
1135        {
1136            let full_name = format!(
1137                "{function_name}({})",
1138                arg_types.iter().map(|t| t.to_string()).join(",")
1139            );
1140            if if_not_exists {
1141                Ok(Either::Right(
1142                    PgResponse::builder(stmt_type)
1143                        .notice(format!(
1144                            "function \"{}\" already exists, skipping",
1145                            full_name
1146                        ))
1147                        .into(),
1148                ))
1149            } else {
1150                Err(CatalogError::duplicated("function", full_name).into())
1151            }
1152        } else {
1153            Ok(Either::Left(()))
1154        }
1155    }
1156
1157    /// Also check if the user has the privilege to create in the schema.
1158    pub fn get_database_and_schema_id_for_create(
1159        &self,
1160        schema_name: Option<String>,
1161    ) -> Result<(DatabaseId, SchemaId)> {
1162        let db_name = &self.database();
1163
1164        let search_path = self.config().search_path();
1165        let user_name = &self.user_name();
1166
1167        let catalog_reader = self.env().catalog_reader().read_guard();
1168        let schema = match schema_name {
1169            Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?,
1170            None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?,
1171        };
1172        let schema_name = schema.name();
1173
1174        check_schema_writable(&schema_name)?;
1175        self.check_privileges(&[ObjectCheckItem::new(
1176            schema.owner(),
1177            AclMode::Create,
1178            schema_name,
1179            schema.id(),
1180        )])?;
1181
1182        let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1183        Ok((db_id, schema.id()))
1184    }
1185
1186    pub fn get_connection_by_name(
1187        &self,
1188        schema_name: Option<String>,
1189        connection_name: &str,
1190    ) -> Result<Arc<ConnectionCatalog>> {
1191        let db_name = &self.database();
1192        let search_path = self.config().search_path();
1193        let user_name = &self.user_name();
1194
1195        let catalog_reader = self.env().catalog_reader().read_guard();
1196        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1197        let (connection, _) =
1198            catalog_reader.get_connection_by_name(db_name, schema_path, connection_name)?;
1199
1200        self.check_privileges(&[ObjectCheckItem::new(
1201            connection.owner(),
1202            AclMode::Usage,
1203            connection.name.clone(),
1204            connection.id,
1205        )])?;
1206
1207        Ok(connection.clone())
1208    }
1209
1210    pub fn get_subscription_by_schema_id_name(
1211        &self,
1212        schema_id: SchemaId,
1213        subscription_name: &str,
1214    ) -> Result<Arc<SubscriptionCatalog>> {
1215        let db_name = &self.database();
1216
1217        let catalog_reader = self.env().catalog_reader().read_guard();
1218        let db_id = catalog_reader.get_database_by_name(db_name)?.id();
1219        let schema = catalog_reader.get_schema_by_id(db_id, schema_id)?;
1220        let subscription = schema
1221            .get_subscription_by_name(subscription_name)
1222            .ok_or_else(|| {
1223                RwError::from(ErrorCode::ItemNotFound(format!(
1224                    "subscription {} not found",
1225                    subscription_name
1226                )))
1227            })?;
1228        Ok(subscription.clone())
1229    }
1230
1231    pub fn get_subscription_by_name(
1232        &self,
1233        schema_name: Option<String>,
1234        subscription_name: &str,
1235    ) -> Result<Arc<SubscriptionCatalog>> {
1236        let db_name = &self.database();
1237        let search_path = self.config().search_path();
1238        let user_name = &self.user_name();
1239
1240        let catalog_reader = self.env().catalog_reader().read_guard();
1241        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1242        let (subscription, _) =
1243            catalog_reader.get_subscription_by_name(db_name, schema_path, subscription_name)?;
1244        Ok(subscription.clone())
1245    }
1246
1247    pub fn get_table_by_id(&self, table_id: TableId) -> Result<Arc<TableCatalog>> {
1248        let catalog_reader = self.env().catalog_reader().read_guard();
1249        Ok(catalog_reader.get_any_table_by_id(table_id)?.clone())
1250    }
1251
1252    pub fn get_table_by_name(
1253        &self,
1254        table_name: &str,
1255        db_id: DatabaseId,
1256        schema_id: SchemaId,
1257    ) -> Result<Arc<TableCatalog>> {
1258        let catalog_reader = self.env().catalog_reader().read_guard();
1259        let table = catalog_reader
1260            .get_schema_by_id(db_id, schema_id)?
1261            .get_created_table_by_name(table_name)
1262            .ok_or_else(|| {
1263                Error::new(
1264                    ErrorKind::InvalidInput,
1265                    format!("table \"{}\" does not exist", table_name),
1266                )
1267            })?;
1268
1269        self.check_privileges(&[ObjectCheckItem::new(
1270            table.owner(),
1271            AclMode::Select,
1272            table_name.to_owned(),
1273            table.id,
1274        )])?;
1275
1276        Ok(table.clone())
1277    }
1278
1279    pub fn get_secret_by_name(
1280        &self,
1281        schema_name: Option<String>,
1282        secret_name: &str,
1283    ) -> Result<Arc<SecretCatalog>> {
1284        let db_name = &self.database();
1285        let search_path = self.config().search_path();
1286        let user_name = &self.user_name();
1287
1288        let catalog_reader = self.env().catalog_reader().read_guard();
1289        let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
1290        let (secret, _) = catalog_reader.get_secret_by_name(db_name, schema_path, secret_name)?;
1291
1292        self.check_privileges(&[ObjectCheckItem::new(
1293            secret.owner(),
1294            AclMode::Usage,
1295            secret.name.clone(),
1296            secret.id,
1297        )])?;
1298
1299        Ok(secret.clone())
1300    }
1301
1302    pub fn list_change_log_epochs(
1303        &self,
1304        table_id: TableId,
1305        min_epoch: u64,
1306        max_count: u32,
1307    ) -> Result<Vec<u64>> {
1308        Ok(self
1309            .env
1310            .hummock_snapshot_manager()
1311            .acquire()
1312            .list_change_log_epochs(table_id, min_epoch, max_count))
1313    }
1314
1315    pub fn clear_cancel_query_flag(&self) {
1316        let mut flag = self.current_query_cancel_flag.lock();
1317        *flag = None;
1318    }
1319
1320    pub fn reset_cancel_query_flag(&self) -> ShutdownToken {
1321        let mut flag = self.current_query_cancel_flag.lock();
1322        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
1323        *flag = Some(shutdown_tx);
1324        shutdown_rx
1325    }
1326
1327    pub fn cancel_current_query(&self) {
1328        let mut flag_guard = self.current_query_cancel_flag.lock();
1329        if let Some(sender) = flag_guard.take() {
1330            info!("Trying to cancel query in local mode.");
1331            // Current running query is in local mode
1332            sender.cancel();
1333            info!("Cancel query request sent.");
1334        } else {
1335            info!("Trying to cancel query in distributed mode.");
1336            self.env.query_manager().cancel_queries_in_session(self.id)
1337        }
1338    }
1339
1340    pub fn cancel_current_creating_job(&self) {
1341        self.env.creating_streaming_job_tracker.abort_jobs(self.id);
1342    }
1343
1344    /// This function only used for test now.
1345    /// Maybe we can remove it in the future.
1346    pub async fn run_statement(
1347        self: Arc<Self>,
1348        sql: Arc<str>,
1349        formats: Vec<Format>,
1350    ) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
1351        // Parse sql.
1352        let mut stmts = Parser::parse_sql(&sql)?;
1353        if stmts.is_empty() {
1354            return Ok(PgResponse::empty_result(
1355                pgwire::pg_response::StatementType::EMPTY,
1356            ));
1357        }
1358        if stmts.len() > 1 {
1359            return Ok(
1360                PgResponse::builder(pgwire::pg_response::StatementType::EMPTY)
1361                    .notice("cannot insert multiple commands into statement")
1362                    .into(),
1363            );
1364        }
1365        let stmt = stmts.swap_remove(0);
1366        let rsp = handle(self, stmt, sql.clone(), formats).await?;
1367        Ok(rsp)
1368    }
1369
1370    pub fn notice_to_user(&self, str: impl Into<String>) {
1371        let notice = str.into();
1372        tracing::trace!(notice, "notice to user");
1373        self.notice_tx
1374            .send(notice)
1375            .expect("notice channel should not be closed");
1376    }
1377
1378    pub fn is_barrier_read(&self) -> bool {
1379        match self.config().visibility_mode() {
1380            VisibilityMode::Default => self.env.batch_config.enable_barrier_read,
1381            VisibilityMode::All => true,
1382            VisibilityMode::Checkpoint => false,
1383        }
1384    }
1385
1386    pub fn statement_timeout(&self) -> Duration {
1387        if self.config().statement_timeout() == 0 {
1388            Duration::from_secs(self.env.batch_config.statement_timeout_in_sec as u64)
1389        } else {
1390            Duration::from_secs(self.config().statement_timeout() as u64)
1391        }
1392    }
1393
1394    pub fn create_temporary_source(&self, source: SourceCatalog) {
1395        self.temporary_source_manager
1396            .lock()
1397            .create_source(source.name.clone(), source);
1398    }
1399
1400    pub fn get_temporary_source(&self, name: &str) -> Option<SourceCatalog> {
1401        self.temporary_source_manager
1402            .lock()
1403            .get_source(name)
1404            .cloned()
1405    }
1406
1407    pub fn drop_temporary_source(&self, name: &str) {
1408        self.temporary_source_manager.lock().drop_source(name);
1409    }
1410
1411    pub fn temporary_source_manager(&self) -> TemporarySourceManager {
1412        self.temporary_source_manager.lock().clone()
1413    }
1414
1415    pub fn create_staging_table(&self, table: TableCatalog) {
1416        self.staging_catalog_manager
1417            .lock()
1418            .create_table(table.name.clone(), table);
1419    }
1420
1421    pub fn drop_staging_table(&self, name: &str) {
1422        self.staging_catalog_manager.lock().drop_table(name);
1423    }
1424
1425    pub fn staging_catalog_manager(&self) -> StagingCatalogManager {
1426        self.staging_catalog_manager.lock().clone()
1427    }
1428
1429    pub async fn check_cluster_limits(&self) -> Result<()> {
1430        if self.config().bypass_cluster_limits() {
1431            return Ok(());
1432        }
1433
1434        let gen_message = |ActorCountPerParallelism {
1435                               worker_id_to_actor_count,
1436                               hard_limit,
1437                               soft_limit,
1438                           }: ActorCountPerParallelism,
1439                           exceed_hard_limit: bool|
1440         -> String {
1441            let (limit_type, action) = if exceed_hard_limit {
1442                ("critical", "Scale the cluster immediately to proceed.")
1443            } else {
1444                (
1445                    "recommended",
1446                    "Consider scaling the cluster for optimal performance.",
1447                )
1448            };
1449            format!(
1450                r#"Actor count per parallelism exceeds the {limit_type} limit.
1451
1452Depending on your workload, this may overload the cluster and cause performance/stability issues. {action}
1453
1454HINT:
1455- For best practices on managing streaming jobs: https://docs.risingwave.com/operate/manage-a-large-number-of-streaming-jobs
1456- To bypass the check (if the cluster load is acceptable): `[ALTER SYSTEM] SET bypass_cluster_limits TO true`.
1457  See https://docs.risingwave.com/operate/view-configure-runtime-parameters#how-to-configure-runtime-parameters
1458- Contact us via slack or https://risingwave.com/contact-us/ for further enquiry.
1459
1460DETAILS:
1461- hard limit: {hard_limit}
1462- soft limit: {soft_limit}
1463- worker_id_to_actor_count: {worker_id_to_actor_count:?}"#,
1464            )
1465        };
1466
1467        let limits = self.env().meta_client().get_cluster_limits().await?;
1468        for limit in limits {
1469            match limit {
1470                cluster_limit::ClusterLimit::ActorCount(l) => {
1471                    if l.exceed_hard_limit() {
1472                        return Err(RwError::from(ErrorCode::ProtocolError(gen_message(
1473                            l, true,
1474                        ))));
1475                    } else if l.exceed_soft_limit() {
1476                        self.notice_to_user(gen_message(l, false));
1477                    }
1478                }
1479            }
1480        }
1481        Ok(())
1482    }
1483}
1484
/// Process-wide slot for the session manager. It starts empty and can be set at most once.
pub static SESSION_MANAGER: std::sync::OnceLock<Arc<SessionManagerImpl>> =
    std::sync::OnceLock::new();
1487
/// Frontend implementation of the pgwire `SessionManager`, producing [`SessionImpl`]
/// sessions.
pub struct SessionManagerImpl {
    /// Frontend environment that sessions are created from.
    env: FrontendEnv,
    /// Join handles returned by `FrontendEnv::init`, kept for the manager's lifetime.
    _join_handles: Vec<JoinHandle<()>>,
    /// Shutdown senders returned by `FrontendEnv::init`, kept for the manager's lifetime.
    _shutdown_senders: Vec<Sender<()>>,
    /// Counter initialized to 0.
    /// NOTE(review): presumably used to allocate session ids — confirm against
    /// `connect_inner`.
    number: AtomicI32,
}
1494
1495impl SessionManager for SessionManagerImpl {
1496    type Error = RwError;
1497    type Session = SessionImpl;
1498
1499    fn create_dummy_session(
1500        &self,
1501        database_id: DatabaseId,
1502        user_id: u32,
1503    ) -> Result<Arc<Self::Session>> {
1504        let dummy_addr = Address::Tcp(SocketAddr::new(
1505            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
1506            5691, // port of meta
1507        ));
1508        let user_reader = self.env.user_info_reader();
1509        let reader = user_reader.read_guard();
1510        if let Some(user_name) = reader.get_user_name_by_id(user_id) {
1511            self.connect_inner(database_id, user_name.as_str(), Arc::new(dummy_addr))
1512        } else {
1513            bail_catalog_error!("Role id {} does not exist", user_id)
1514        }
1515    }
1516
1517    fn connect(
1518        &self,
1519        database: &str,
1520        user_name: &str,
1521        peer_addr: AddressRef,
1522    ) -> Result<Arc<Self::Session>> {
1523        let catalog_reader = self.env.catalog_reader();
1524        let reader = catalog_reader.read_guard();
1525        let database_id = reader.get_database_by_name(database)?.id();
1526
1527        self.connect_inner(database_id, user_name, peer_addr)
1528    }
1529
1530    /// Used when cancel request happened.
1531    fn cancel_queries_in_session(&self, session_id: SessionId) {
1532        self.env.cancel_queries_in_session(session_id);
1533    }
1534
1535    fn cancel_creating_jobs_in_session(&self, session_id: SessionId) {
1536        self.env.cancel_creating_jobs_in_session(session_id);
1537    }
1538
1539    fn end_session(&self, session: &Self::Session) {
1540        self.delete_session(&session.session_id());
1541    }
1542
1543    async fn shutdown(&self) {
1544        // Clean up the session map.
1545        self.env.sessions_map().write().clear();
1546        // Unregister from the meta service.
1547        self.env.meta_client().try_unregister().await;
1548    }
1549}
1550
1551impl SessionManagerImpl {
1552    pub async fn new(opts: FrontendOpts) -> Result<Self> {
1553        // TODO(shutdown): only save join handles that **need** to be shutdown
1554        let (env, join_handles, shutdown_senders) = FrontendEnv::init(opts).await?;
1555        Ok(Self {
1556            env,
1557            _join_handles: join_handles,
1558            _shutdown_senders: shutdown_senders,
1559            number: AtomicI32::new(0),
1560        })
1561    }
1562
1563    pub(crate) fn env(&self) -> &FrontendEnv {
1564        &self.env
1565    }
1566
1567    fn insert_session(&self, session: Arc<SessionImpl>) {
1568        let active_sessions = {
1569            let mut write_guard = self.env.sessions_map.write();
1570            write_guard.insert(session.id(), session);
1571            write_guard.len()
1572        };
1573        self.env
1574            .frontend_metrics
1575            .active_sessions
1576            .set(active_sessions as i64);
1577    }
1578
1579    fn delete_session(&self, session_id: &SessionId) {
1580        let active_sessions = {
1581            let mut write_guard = self.env.sessions_map.write();
1582            write_guard.remove(session_id);
1583            write_guard.len()
1584        };
1585        self.env
1586            .frontend_metrics
1587            .active_sessions
1588            .set(active_sessions as i64);
1589    }
1590
    /// Authorizes a connection attempt and, on success, creates and registers a new session.
    ///
    /// Steps, in order:
    /// 1. Resolve `database_id` in the catalog to obtain the database name and owner.
    /// 2. Look up `user_name`; a missing user yields a catalog error ("Role ... does not exist").
    /// 3. Reject users that are not allowed to log in, and users that are neither superuser,
    ///    owner of the database, nor holders of the `CONNECT` privilege on it.
    /// 4. Match the connection (type, database name, user name, client IP) against the HBA
    ///    configuration; no matching entry yields a permission-denied error.
    /// 5. Choose the `UserAuthenticator` from the matched HBA auth method combined with the
    ///    user's stored auth info.
    /// 6. Allocate a session id, build the `SessionImpl`, and insert it into the sessions map.
    ///
    /// This function only *selects* the authenticator and stores it in the session; no
    /// credential is verified in this function body.
    fn connect_inner(
        &self,
        database_id: DatabaseId,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<SessionImpl>> {
        let catalog_reader = self.env.catalog_reader();
        // This catalog guard stays alive for the rest of the function: the later `reader`
        // binding only shadows it, and `database_name` is used further down.
        let reader = catalog_reader.read_guard();
        let (database_name, database_owner) = {
            let db = reader.get_database_by_id(database_id)?;
            (db.name(), db.owner())
        };

        let user_reader = self.env.user_info_reader();
        let reader = user_reader.read_guard();
        if let Some(user) = reader.get_user_by_name(user_name) {
            if !user.can_login {
                bail_permission_denied!("User {} is not allowed to login", user_name);
            }
            // Superusers and the database owner may connect without an explicit CONNECT grant.
            let has_privilege = user.has_privilege(database_id, AclMode::Connect);
            if !user.is_super && database_owner != user.id && !has_privilege {
                bail_permission_denied!("User does not have CONNECT privilege.");
            }

            // Classify the connection for HBA matching: TCP peers are `Host` connections and
            // carry a client IP, Unix-socket peers are `Local` connections without one.
            let (connection_type, client_addr) = match peer_addr.as_ref() {
                Address::Tcp(socket_addr) => (ConnectionType::Host, Some(&socket_addr.ip())),
                Address::Unix(_) => (ConnectionType::Local, None),
            };
            tracing::debug!(
                "receive connection: type={:?}, client_addr={:?}",
                connection_type,
                client_addr
            );

            let hba_entry_opt = self.env.frontend_config().hba_config.find_matching_entry(
                &connection_type,
                database_name,
                user_name,
                client_addr,
            );

            // TODO: adding `FATAL` message support for no matching HBA entry.
            // From here on `hba_entry_opt` is the matched entry itself (the name is kept from
            // the `Option` it shadows).
            let Some(hba_entry_opt) = hba_entry_opt else {
                bail_permission_denied!(
                    "no pg_hba.conf entry for host \"{peer_addr}\", user \"{user_name}\", database \"{database_name}\""
                );
            };

            // Determine the user authenticator based on the user's auth info.
            // Evaluated lazily: only the HBA methods that defer to the stored auth info call it.
            let authenticator_by_info = || -> Result<UserAuthenticator> {
                let authenticator = match &user.auth_info {
                    // No stored auth info: nothing to authenticate against.
                    None => UserAuthenticator::None,
                    Some(auth_info) => match auth_info.encryption_type() {
                        EncryptionType::Plaintext => {
                            UserAuthenticator::ClearText(auth_info.encrypted_value.clone())
                        }
                        EncryptionType::Md5 => {
                            // Generate a fresh random 4-byte salt for this connection and
                            // hash the stored value with it.
                            let mut salt = [0; 4];
                            let mut rng = rand::rng();
                            rng.fill_bytes(&mut salt);
                            UserAuthenticator::Md5WithSalt {
                                encrypted_password: md5_hash_with_salt(
                                    &auth_info.encrypted_value,
                                    &salt,
                                ),
                                salt,
                            }
                        }
                        EncryptionType::Oauth => UserAuthenticator::OAuth {
                            metadata: auth_info.metadata.clone(),
                            cluster_id: self.env.meta_client().cluster_id().to_owned(),
                        },
                        _ => {
                            bail_protocol_error!(
                                "Unsupported auth type: {}",
                                auth_info.encryption_type().as_str_name()
                            );
                        }
                    },
                };
                Ok(authenticator)
            };

            // Combine the HBA auth method with the user's stored auth info. `Md5` and `OAuth`
            // require stored auth info of the same kind; any combination not listed below is
            // rejected as a failed password authentication.
            let user_authenticator = match (&hba_entry_opt.auth_method, &user.auth_info) {
                (AuthMethod::Trust, _) => UserAuthenticator::None,
                // For backward compatibility, we allow password auth method to work with any auth info.
                (AuthMethod::Password, _) => authenticator_by_info()?,
                (AuthMethod::Md5, Some(auth_info))
                    if auth_info.encryption_type() == EncryptionType::Md5 =>
                {
                    authenticator_by_info()?
                }
                (AuthMethod::OAuth, Some(auth_info))
                    if auth_info.encryption_type() == EncryptionType::Oauth =>
                {
                    authenticator_by_info()?
                }
                // LDAP ignores the stored auth info; the authenticator carries the user name
                // and a copy of the matched HBA entry.
                (AuthMethod::Ldap, _) => {
                    UserAuthenticator::Ldap(user_name.to_owned(), hba_entry_opt.clone())
                }
                _ => {
                    bail_permission_denied!(
                        "password authentication failed for user \"{user_name}\""
                    );
                }
            };

            // Assign a session id and insert into sessions map (for cancel request).
            // `fetch_add` returns the counter value from before the increment.
            let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
            // Use a trivial strategy: process_id and secret_key are equal.
            let id = (secret_key, secret_key);
            // Read session params snapshot from frontend env.
            let session_config = self.env.session_params_snapshot();

            let session_impl: Arc<SessionImpl> = SessionImpl::new(
                self.env.clone(),
                AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id),
                user_authenticator,
                id,
                peer_addr,
                session_config,
            )
            .into();
            self.insert_session(session_impl.clone());

            Ok(session_impl)
        } else {
            bail_catalog_error!("Role {} does not exist", user_name);
        }
    }
1722}
1723
1724impl Session for SessionImpl {
1725    type Error = RwError;
1726    type Portal = Portal;
1727    type PreparedStatement = PrepareStatement;
1728    type ValuesStream = PgResponseStream;
1729
1730    /// A copy of `run_statement` but exclude the parser part so each run must be at most one
1731    /// statement. The str sql use the `to_string` of AST. Consider Reuse later.
1732    async fn run_one_query(
1733        self: Arc<Self>,
1734        stmt: Statement,
1735        format: Format,
1736    ) -> Result<PgResponse<PgResponseStream>> {
1737        let string = stmt.to_string();
1738        let sql_str = string.as_str();
1739        let sql: Arc<str> = Arc::from(sql_str);
1740        // The handle can be slow. Release potential large String early.
1741        drop(string);
1742        let rsp = handle(self, stmt, sql, vec![format]).await?;
1743        Ok(rsp)
1744    }
1745
1746    fn user_authenticator(&self) -> &UserAuthenticator {
1747        &self.user_authenticator
1748    }
1749
1750    fn id(&self) -> SessionId {
1751        self.id
1752    }
1753
1754    async fn parse(
1755        self: Arc<Self>,
1756        statement: Option<Statement>,
1757        params_types: Vec<Option<DataType>>,
1758    ) -> Result<PrepareStatement> {
1759        Ok(if let Some(statement) = statement {
1760            handle_parse(self, statement, params_types).await?
1761        } else {
1762            PrepareStatement::Empty
1763        })
1764    }
1765
1766    fn bind(
1767        self: Arc<Self>,
1768        prepare_statement: PrepareStatement,
1769        params: Vec<Option<Bytes>>,
1770        param_formats: Vec<Format>,
1771        result_formats: Vec<Format>,
1772    ) -> Result<Portal> {
1773        handle_bind(prepare_statement, params, param_formats, result_formats)
1774    }
1775
1776    async fn execute(self: Arc<Self>, portal: Portal) -> Result<PgResponse<PgResponseStream>> {
1777        let rsp = handle_execute(self, portal).await?;
1778        Ok(rsp)
1779    }
1780
1781    fn describe_statement(
1782        self: Arc<Self>,
1783        prepare_statement: PrepareStatement,
1784    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>)> {
1785        Ok(match prepare_statement {
1786            PrepareStatement::Empty => (vec![], vec![]),
1787            PrepareStatement::Prepared(prepare_statement) => (
1788                prepare_statement.bound_result.param_types,
1789                infer(
1790                    Some(prepare_statement.bound_result.bound),
1791                    prepare_statement.statement,
1792                )?,
1793            ),
1794            PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?),
1795        })
1796    }
1797
1798    fn describe_portal(self: Arc<Self>, portal: Portal) -> Result<Vec<PgFieldDescriptor>> {
1799        match portal {
1800            Portal::Empty => Ok(vec![]),
1801            Portal::Portal(portal) => {
1802                let mut columns = infer(Some(portal.bound_result.bound), portal.statement)?;
1803                let formats = FormatIterator::new(&portal.result_formats, columns.len())
1804                    .map_err(|e| RwError::from(ErrorCode::ProtocolError(e)))?;
1805                columns.iter_mut().zip_eq_fast(formats).for_each(|(c, f)| {
1806                    if f == Format::Binary {
1807                        c.set_to_binary()
1808                    }
1809                });
1810                Ok(columns)
1811            }
1812            Portal::PureStatement(statement) => Ok(infer(None, statement)?),
1813        }
1814    }
1815
1816    fn get_config(&self, key: &str) -> Result<String> {
1817        self.config().get(key).map_err(Into::into)
1818    }
1819
1820    fn set_config(&self, key: &str, value: String) -> Result<String> {
1821        Self::set_config(self, key, value)
1822    }
1823
1824    async fn next_notice(self: &Arc<Self>) -> String {
1825        std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
1826            .await
1827            .expect("notice channel should not be closed")
1828    }
1829
1830    fn transaction_status(&self) -> TransactionStatus {
1831        match &*self.txn.lock() {
1832            transaction::State::Initial | transaction::State::Implicit(_) => {
1833                TransactionStatus::Idle
1834            }
1835            transaction::State::Explicit(_) => TransactionStatus::InTransaction,
1836            // TODO: failed transaction
1837        }
1838    }
1839
1840    /// Init and return an `ExecContextGuard` which could be used as a guard to represent the execution flow.
1841    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
1842        let exec_context = Arc::new(ExecContext {
1843            running_sql: sql,
1844            last_instant: Instant::now(),
1845            last_idle_instant: self.last_idle_instant.clone(),
1846        });
1847        *self.exec_context.lock() = Some(Arc::downgrade(&exec_context));
1848        // unset idle state, since there is a sql running
1849        *self.last_idle_instant.lock() = None;
1850        ExecContextGuard::new(exec_context)
1851    }
1852
1853    /// Check whether idle transaction timeout.
1854    /// If yes, unpin snapshot and return an `IdleInTxnTimeout` error.
1855    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
1856        // In transaction.
1857        if matches!(self.transaction_status(), TransactionStatus::InTransaction) {
1858            let idle_in_transaction_session_timeout =
1859                self.config().idle_in_transaction_session_timeout() as u128;
1860            // Idle transaction timeout has been enabled.
1861            if idle_in_transaction_session_timeout != 0 {
1862                // Hold the `exec_context` lock to ensure no new sql coming when unpin_snapshot.
1863                let guard = self.exec_context.lock();
1864                // No running sql i.e. idle
1865                if guard.as_ref().and_then(|weak| weak.upgrade()).is_none() {
1866                    // Idle timeout.
1867                    if let Some(elapse_since_last_idle_instant) =
1868                        self.elapse_since_last_idle_instant()
1869                        && elapse_since_last_idle_instant > idle_in_transaction_session_timeout
1870                    {
1871                        return Err(PsqlError::IdleInTxnTimeout);
1872                    }
1873                }
1874            }
1875        }
1876        Ok(())
1877    }
1878}
1879
1880/// Returns row description of the statement
1881fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
1882    match stmt {
1883        Statement::Query(_)
1884        | Statement::Insert { .. }
1885        | Statement::Delete { .. }
1886        | Statement::Update { .. }
1887        | Statement::FetchCursor { .. } => Ok(bound
1888            .unwrap()
1889            .output_fields()
1890            .iter()
1891            .map(to_pg_field)
1892            .collect()),
1893        Statement::ShowObjects {
1894            object: show_object,
1895            ..
1896        } => Ok(infer_show_object(&show_object)),
1897        Statement::ShowCreateObject { .. } => Ok(infer_show_create_object()),
1898        Statement::ShowTransactionIsolationLevel => {
1899            let name = "transaction_isolation";
1900            Ok(infer_show_variable(name))
1901        }
1902        Statement::ShowVariable { variable } => {
1903            let name = &variable[0].real_value().to_lowercase();
1904            Ok(infer_show_variable(name))
1905        }
1906        Statement::Describe { name: _, kind } => Ok(infer_describe(&kind)),
1907        Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new(
1908            "QUERY PLAN".to_owned(),
1909            DataType::Varchar.to_oid(),
1910            DataType::Varchar.type_len(),
1911        )]),
1912        _ => Ok(vec![]),
1913    }
1914}
1915
1916pub struct WorkerProcessId {
1917    pub worker_id: WorkerId,
1918    pub process_id: i32,
1919}
1920
1921impl WorkerProcessId {
1922    pub fn new(worker_id: WorkerId, process_id: i32) -> Self {
1923        Self {
1924            worker_id,
1925            process_id,
1926        }
1927    }
1928}
1929
1930impl Display for WorkerProcessId {
1931    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1932        write!(f, "{}:{}", self.worker_id, self.process_id)
1933    }
1934}
1935
1936impl TryFrom<String> for WorkerProcessId {
1937    type Error = String;
1938
1939    fn try_from(worker_process_id: String) -> std::result::Result<Self, Self::Error> {
1940        const INVALID: &str = "invalid WorkerProcessId";
1941        let splits: Vec<&str> = worker_process_id.split(":").collect();
1942        if splits.len() != 2 {
1943            return Err(INVALID.to_owned());
1944        }
1945        let Ok(worker_id) = splits[0].parse::<u32>() else {
1946            return Err(INVALID.to_owned());
1947        };
1948        let Ok(process_id) = splits[1].parse::<i32>() else {
1949            return Err(INVALID.to_owned());
1950        };
1951        Ok(WorkerProcessId::new(worker_id.into(), process_id))
1952    }
1953}
1954
1955pub fn cancel_queries_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1956    let guard = sessions_map.read();
1957    if let Some(session) = guard.get(&session_id) {
1958        session.cancel_current_query();
1959        true
1960    } else {
1961        info!("Current session finished, ignoring cancel query request");
1962        false
1963    }
1964}
1965
1966pub fn cancel_creating_jobs_in_session(session_id: SessionId, sessions_map: SessionMapRef) -> bool {
1967    let guard = sessions_map.read();
1968    if let Some(session) = guard.get(&session_id) {
1969        session.cancel_current_creating_job();
1970        true
1971    } else {
1972        info!("Current session finished, ignoring cancel creating request");
1973        false
1974    }
1975}