// risingwave_frontend/handler/extended_handle.rs
use std::fmt;
use std::fmt::Formatter;
use std::sync::Arc;
use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{CreateSink, DeclareCursor, Query, Statement};
use super::query::BoundResult;
use super::{fetch_cursor, handle, query, HandlerArgs, RwPgResponse};
use crate::error::Result;
use crate::session::SessionImpl;
/// Result of the extended-query Parse step: returned by [`handle_parse`] and consumed by
/// [`handle_bind`].
#[derive(Clone)]
pub enum PrepareStatement {
    /// Nothing to prepare; [`handle_bind`] maps it to [`Portal::Empty`]. Not constructed by
    /// [`handle_parse`] in this file (presumably used for an empty query string — TODO confirm).
    Empty,
    /// A statement together with its [`BoundResult`] (presumably built by `query::handle_parse`
    /// or `fetch_cursor::handle_parse`); [`handle_bind`] substitutes parameter values into it.
    Prepared(PreparedResult),
    /// A statement kept verbatim; after Bind it is executed through the generic `handle` path
    /// with no parameters.
    PureStatement(Statement),
}
/// A parsed statement paired with its bound form, still waiting for parameter values.
#[derive(Clone)]
pub struct PreparedResult {
    /// The original statement; carried unchanged into the resulting [`PortalResult`].
    pub statement: Statement,
    /// Bound form of `statement`; [`handle_bind`] calls `bind_parameter` on its `bound` field.
    pub bound_result: BoundResult,
}
/// Result of the extended-query Bind step: returned by [`handle_bind`] and executed by
/// [`handle_execute`].
// NOTE(review): presumably the lint is expected because the `Portal` variant repeats the enum
// name — TODO confirm.
#[expect(clippy::enum_variant_names)]
#[derive(Clone)]
pub enum Portal {
    /// Executes to an empty result with `StatementType::EMPTY`.
    Empty,
    /// A statement whose parameters have been bound by [`handle_bind`].
    Portal(PortalResult),
    /// A statement executed verbatim through the generic `handle` path.
    PureStatement(Statement),
}
impl std::fmt::Display for Portal {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match &self {
Portal::Empty => write!(f, "Empty"),
Portal::Portal(portal) => portal.fmt(f),
Portal::PureStatement(stmt) => write!(f, "{}", stmt),
}
}
}
/// A statement whose parameters have been bound, ready for [`handle_execute`].
#[derive(Clone)]
pub struct PortalResult {
    /// The statement to execute; its text is also used to build the handler arguments.
    pub statement: Statement,
    /// Bound form of `statement`, with `parsed_params` filled in by [`handle_bind`].
    pub bound_result: BoundResult,
    /// Result format codes that were passed to [`handle_bind`].
    pub result_formats: Vec<Format>,
}
impl std::fmt::Display for PortalResult {
    /// Renders the statement followed by the `Debug` form of its bound parameters.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            statement,
            bound_result,
            ..
        } = self;
        let params = &bound_result.parsed_params;
        write!(f, "{statement}, params = {params:?}")
    }
}
/// Handles the extended-query Parse step for `statement`.
///
/// `SELECT` / `INSERT` / `DELETE` / `UPDATE`, query cursors (`DeclareCursor::Query`), `FETCH`
/// and `CREATE MATERIALIZED VIEW` are delegated, together with `specific_param_types`, to the
/// query or fetch-cursor parse handlers. Every other statement is returned unchanged as
/// [`PrepareStatement::PureStatement`].
///
/// # Errors
///
/// Returns a "not implemented" error for subscription cursors, and for non-materialized
/// `CREATE VIEW`, `CREATE TABLE ... AS` and `CREATE SINK ... AS` whose query contains a
/// parameter. Also propagates errors from `HandlerArgs::new` and from the delegated handlers.
pub async fn handle_parse(
    session: Arc<SessionImpl>,
    statement: Statement,
    specific_param_types: Vec<Option<DataType>>,
) -> Result<PrepareStatement> {
    session.clear_cancel_query_flag();
    let sql: Arc<str> = statement.to_string().into();
    let handler_args = HandlerArgs::new(session, &statement, sql)?;

    match &statement {
        Statement::Query(_)
        | Statement::Insert { .. }
        | Statement::Delete { .. }
        | Statement::Update { .. } => {
            query::handle_parse(handler_args, statement, specific_param_types)
        }
        Statement::FetchCursor { .. } => {
            fetch_cursor::handle_parse(handler_args, statement, specific_param_types).await
        }
        Statement::DeclareCursor { stmt } => match stmt.declare_cursor {
            DeclareCursor::Query(_) => {
                query::handle_parse(handler_args, statement, specific_param_types)
            }
            _ => {
                bail_not_implemented!("DECLARE SUBSCRIPTION CURSOR with parameters");
            }
        },
        Statement::CreateView {
            query: view_query,
            materialized,
            ..
        } => {
            if *materialized {
                // Materialized views go through the same parse path as plain queries.
                query::handle_parse(handler_args, statement, specific_param_types)
            } else if have_parameter_in_query(view_query) {
                bail_not_implemented!("CREATE VIEW with parameters");
            } else {
                Ok(PrepareStatement::PureStatement(statement))
            }
        }
        Statement::CreateTable {
            query: as_query, ..
        } => {
            let has_params = as_query
                .as_ref()
                .is_some_and(|select| have_parameter_in_query(select));
            if has_params {
                bail_not_implemented!("CREATE TABLE AS SELECT with parameters");
            }
            Ok(PrepareStatement::PureStatement(statement))
        }
        Statement::CreateSink { stmt } => {
            let has_params = matches!(
                &stmt.sink_from,
                CreateSink::AsQuery(select) if have_parameter_in_query(select)
            );
            if has_params {
                bail_not_implemented!("CREATE SINK AS SELECT with parameters");
            }
            Ok(PrepareStatement::PureStatement(statement))
        }
        _ => Ok(PrepareStatement::PureStatement(statement)),
    }
}
/// Handles the extended-query Bind step, turning a [`PrepareStatement`] into a [`Portal`].
///
/// For [`PrepareStatement::Prepared`], `params` (encoded per `param_formats`) are passed to
/// `bind_parameter` on the bound statement. The second value it returns is stored as
/// `parsed_params`, and `result_formats` is attached to the resulting [`PortalResult`]. The
/// `Empty` and `PureStatement` variants are passed through unchanged.
///
/// # Errors
///
/// Propagates any error returned by `bind_parameter`.
pub fn handle_bind(
    prepare_statement: PrepareStatement,
    params: Vec<Option<Bytes>>,
    param_formats: Vec<Format>,
    result_formats: Vec<Format>,
) -> Result<Portal> {
    let PreparedResult {
        statement,
        bound_result,
    } = match prepare_statement {
        PrepareStatement::Empty => return Ok(Portal::Empty),
        PrepareStatement::PureStatement(stmt) => return Ok(Portal::PureStatement(stmt)),
        PrepareStatement::Prepared(prepared) => prepared,
    };

    // Only the bound tree and the parsed parameters change; every other field of the bound
    // result is carried over as-is.
    let (bound, parsed_params) = bound_result.bound.bind_parameter(params, param_formats)?;
    let bound_result = BoundResult {
        bound,
        parsed_params: Some(parsed_params),
        ..bound_result
    };

    Ok(Portal::Portal(PortalResult {
        statement,
        bound_result,
        result_formats,
    }))
}
/// Handles the extended-query Execute step for a bound `portal`.
///
/// * [`Portal::Empty`] yields an empty response with `StatementType::EMPTY`.
/// * [`Portal::PureStatement`] is run through the generic `handle` entry point with no
///   parameters.
/// * [`Portal::Portal`] clears the cancel flag and holds the guard returned by
///   `txn_begin_implicit` for the rest of the call. It then dispatches `FETCH` statements to the
///   fetch-cursor executor and everything else to the query executor.
pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result<RwPgResponse> {
    let portal_result = match portal {
        Portal::Empty => {
            return Ok(RwPgResponse::empty_result(
                pgwire::pg_response::StatementType::EMPTY,
            ));
        }
        Portal::PureStatement(stmt) => {
            let sql: Arc<str> = stmt.to_string().into();
            return handle(session, stmt, sql, vec![]).await;
        }
        Portal::Portal(portal_result) => portal_result,
    };

    session.clear_cancel_query_flag();
    // Bound to a name (not `_`) so the guard is only dropped when this function returns.
    let _guard = session.txn_begin_implicit();
    let sql: Arc<str> = portal_result.statement.to_string().into();
    let handler_args = HandlerArgs::new(session, &portal_result.statement, sql)?;

    if matches!(portal_result.statement, Statement::FetchCursor { .. }) {
        fetch_cursor::handle_fetch_cursor_execute(handler_args, portal_result).await
    } else {
        query::handle_execute(handler_args, portal_result).await
    }
}
/// Returns `true` if the SQL text of `query` references a positional parameter (`$1`, `$2`, ...).
///
/// The check runs on the query's `Display` output. It looks for `$` immediately followed by a
/// digit, so queries that only use higher-numbered parameters (e.g. just `$2`) are detected.
/// Occurrences are ignored when they:
/// * continue an identifier (e.g. `col$1` — identifiers may contain `$`),
/// * sit inside a single-quoted string literal (including `E'...'` with backslash escapes),
/// * sit inside a double-quoted identifier, or
/// * sit inside a dollar-quoted string literal (`$$...$$`, `$tag$...$tag$`).
fn have_parameter_in_query(query: &Query) -> bool {
    sql_has_positional_parameter(&query.to_string())
}

/// Scans `sql` for a `$<digit>` placeholder outside identifiers and quoted literals.
fn sql_has_positional_parameter(sql: &str) -> bool {
    // Byte scanning is safe for UTF-8 input: the ASCII delimiters looked for here never occur
    // inside a multi-byte sequence.
    let bytes = sql.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        let after_ident = i > 0 && is_sql_ident_byte(bytes[i - 1]);
        match bytes[i] {
            b'\'' => {
                // A standalone `E`/`e` prefix marks an escape string, where `\` escapes.
                let backslash_escapes = after_ident
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i < 2 || !is_sql_ident_byte(bytes[i - 2]));
                i = skip_quoted(bytes, i, b'\'', backslash_escapes);
            }
            b'"' => i = skip_quoted(bytes, i, b'"', false),
            b'$' if !after_ident => {
                if bytes.get(i + 1).is_some_and(u8::is_ascii_digit) {
                    return true;
                }
                i = skip_dollar_quoted(bytes, i).unwrap_or(i + 1);
            }
            _ => i += 1,
        }
    }
    false
}

/// Returns `true` for bytes that may continue an SQL identifier: ASCII alphanumerics, `_`, `$`,
/// and any byte of a non-ASCII UTF-8 character.
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || !b.is_ascii()
}

/// Given `open` pointing at an opening `quote`, returns the index just past the matching
/// closing quote, or `bytes.len()` if it is unterminated. A doubled quote is an escaped quote;
/// when `backslash_escapes` is set, a backslash also escapes the following byte.
fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = open + 1;
    while i < bytes.len() {
        let b = bytes[i];
        if backslash_escapes && b == b'\\' {
            i += 2;
        } else if b != quote {
            i += 1;
        } else if bytes.get(i + 1) == Some(&quote) {
            i += 2;
        } else {
            return i + 1;
        }
    }
    bytes.len()
}

/// If the `$` at `start` opens a dollar-quoted literal (`$$` or `$tag$`), returns the index just
/// past its closing delimiter (or `bytes.len()` if unterminated); otherwise returns `None`.
fn skip_dollar_quoted(bytes: &[u8], start: usize) -> Option<usize> {
    let tag_len = bytes[start + 1..]
        .iter()
        .take_while(|&&b| b != b'$' && is_sql_ident_byte(b))
        .count();
    let delim_end = start + 1 + tag_len;
    if bytes.get(delim_end) != Some(&b'$') {
        return None;
    }
    let delimiter = &bytes[start..=delim_end];
    let body_start = delim_end + 1;
    let end = bytes[body_start..]
        .windows(delimiter.len())
        .position(|window| window == delimiter)
        .map_or(bytes.len(), |pos| body_start + pos + delimiter.len());
    Some(end)
}