risingwave_frontend/binder/mod.rs
1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57 BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58 BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59 ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, SecretId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
78pub type ShareId = usize;
79
/// The type of binding statement.
///
/// A value is chosen once by the `Binder::new_for_*` constructors and stored in the binder's
/// `bind_for` field; the `is_for_*` helpers on [`Binder`] test against it.
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
/// `Binder` binds the identifiers in AST to columns in relations
pub struct Binder {
    // TODO: maybe we can only lock the database, but not the whole catalog.
    /// Read guard on the catalog, obtained from the session's catalog reader in `Binder::new`.
    catalog: CatalogReadGuard,
    /// Read guard on the user info, obtained from the session's user info reader in
    /// `Binder::new`.
    user: UserInfoReadGuard,
    /// The database name reported by the session (`session.database()`).
    db_name: String,
    /// The database id reported by the session (`session.database_id()`).
    database_id: DatabaseId,
    /// The id of the session this binder was created for.
    session_id: SessionId,
    /// The bind context currently in use. It is swapped out for a fresh one by `push_context`
    /// and `push_lateral_context`, and restored by the matching pop methods.
    context: BindContext,
    /// The authentication context of the session.
    auth_context: Arc<AuthContext>,
    /// A stack holding contexts of outer queries when binding a subquery.
    /// It also holds all of the lateral contexts for each respective
    /// subquery.
    ///
    /// See [`Binder::bind_subquery_expr`] for details.
    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,

    /// A stack holding contexts of left-lateral `TableFactor`s.
    ///
    /// We need a separate stack as `CorrelatedInputRef` depth is
    /// determined by the upper subquery context depth, not the lateral context stack depth.
    lateral_contexts: Vec<LateralBindContext>,

    /// Counter backing `next_subquery_id()`; starts at 0.
    next_subquery_id: usize,
    /// Counter backing `next_values_id()`; starts at 0.
    next_values_id: usize,
    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
    /// and so on.
    next_share_id: ShareId,

    /// Shared handle to the session configuration (`session.shared_config()`).
    session_config: Arc<RwLock<SessionConfig>>,

    /// The search path read from the session config when the binder was created.
    search_path: SearchPath,
    /// The type of binding statement.
    bind_for: BindFor,

    /// `ShareId`s identifying shared views.
    shared_views: HashMap<ViewId, ShareId>,

    /// The included relations while binding a query.
    included_relations: HashSet<ObjectId>,

    /// The included user-defined functions while binding a query.
    included_udfs: HashSet<FunctionId>,

    /// The included secrets while binding a query (e.g., secret refs in UDF arguments).
    included_secrets: HashSet<SecretId>,

    /// The parameter types recorded while binding; see [`ParameterTypes`]. Empty unless set via
    /// `with_specified_params_types`.
    param_types: ParameterTypes,

    /// The temporary sources that will be used during binding phase
    temporary_source_manager: TemporarySourceManager,

    /// The staging catalogs that will be used during binding phase
    staging_catalog_manager: StagingCatalogManager,

    /// Information for `secure_compare` function. It's ONLY available when binding the
    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
    secure_compare_context: Option<SecureCompareContext>,
}
151
/// The name `payload`, one of the hidden names reserved for webhook validation expressions.
pub const WEBHOOK_PAYLOAD_FIELD_NAME: &str = "payload";
153
// There are hidden names reserved for webhook validation expressions:
// - `headers`, whose type is `JSONB`
// - `payload`, whose type is `BYTEA`
/// Information for binding the `secure_compare` function in the `VALIDATE` clause of a webhook
/// source. Installed on a [`Binder`] via `Binder::with_secure_compare`.
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// The identifier used to reference the raw webhook payload during validation.
    pub payload_name: String,
    /// The secret (usually a token provided by the webhook source user) to validate the calls
    pub secret_name: Option<String>,
}
164
/// `ParameterTypes` is used to record the types of the parameters while binding prepared
/// statements. It works by following the rules:
/// 1. At the beginning, it contains the user specified parameter types.
/// 2. When the binder encounters a parameter, it will record it as unknown (call
///    `record_new_param`) if it didn't exist in `ParameterTypes`.
/// 3. When the binder encounters a cast on a parameter, if it's an unknown type, the cast
///    function will record the target type as the inferred type for that parameter (call
///    `record_infer_type`). If the parameter has been inferred, the cast function will act as a
///    normal cast.
/// 4. After bind finished:
///    (a) parameter not in `ParameterTypes` means that the user didn't specify it and it didn't
///    occur in the query. `export` will return an error if such a parameter leaves a gap in
///    the recorded indices `1..=N`. This rule is compatible with PostgreSQL.
///    (b) parameter is None means that it's an unknown type. The user didn't specify it
///    and we can't infer it in the query. We will treat it as VARCHAR type finally. This rule is
///    compatible with PostgreSQL.
///    (c) parameter is Some means that it's a known type.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
183
184impl ParameterTypes {
185 pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
186 let map = specified_param_types
187 .into_iter()
188 .enumerate()
189 .map(|(index, data_type)| ((index + 1) as u64, data_type))
190 .collect::<HashMap<u64, Option<DataType>>>();
191 Self(Arc::new(RwLock::new(map)))
192 }
193
194 pub fn has_infer(&self, index: u64) -> bool {
195 self.0.read().get(&index).unwrap().is_some()
196 }
197
198 pub fn read_type(&self, index: u64) -> Option<DataType> {
199 self.0.read().get(&index).unwrap().clone()
200 }
201
202 pub fn record_new_param(&mut self, index: u64) {
203 self.0.write().entry(index).or_insert(None);
204 }
205
206 pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
207 assert!(
208 !self.has_infer(index),
209 "The parameter has been inferred, should not be inferred again."
210 );
211 self.0
212 .write()
213 .get_mut(&index)
214 .unwrap()
215 .replace(data_type.clone());
216 }
217
218 pub fn export(&self) -> Result<Vec<DataType>> {
219 let types = self
220 .0
221 .read()
222 .clone()
223 .into_iter()
224 .sorted_by_key(|(index, _)| *index)
225 .collect::<Vec<_>>();
226
227 // Check if all the parameters have been inferred.
228 for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
229 if *index != expect_index {
230 return Err(ErrorCode::InvalidInputSyntax(format!(
231 "Cannot infer the type of the parameter {}.",
232 expect_index
233 ))
234 .into());
235 }
236 }
237
238 Ok(types
239 .into_iter()
240 .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
241 .collect::<Vec<_>>())
242 }
243}
244
245impl Binder {
246 fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
247 Binder {
248 catalog: session.env().catalog_reader().read_guard(),
249 user: session.env().user_info_reader().read_guard(),
250 db_name: session.database(),
251 database_id: session.database_id(),
252 session_id: session.id(),
253 context: BindContext::new(),
254 auth_context: session.auth_context(),
255 upper_subquery_contexts: vec![],
256 lateral_contexts: vec![],
257 next_subquery_id: 0,
258 next_values_id: 0,
259 next_share_id: 0,
260 session_config: session.shared_config(),
261 search_path: session.config().search_path(),
262 bind_for,
263 shared_views: HashMap::new(),
264 included_relations: HashSet::new(),
265 included_udfs: HashSet::new(),
266 included_secrets: HashSet::new(),
267 param_types: ParameterTypes::new(vec![]),
268 temporary_source_manager: session.temporary_source_manager(),
269 staging_catalog_manager: session.staging_catalog_manager(),
270 secure_compare_context: None,
271 }
272 }
273
274 pub fn new_for_batch(session: &SessionImpl) -> Binder {
275 Self::new(session, BindFor::Batch)
276 }
277
278 pub fn new_for_stream(session: &SessionImpl) -> Binder {
279 Self::new(session, BindFor::Stream)
280 }
281
282 pub fn new_for_ddl(session: &SessionImpl) -> Binder {
283 Self::new(session, BindFor::Ddl)
284 }
285
286 pub fn new_for_system(session: &SessionImpl) -> Binder {
287 Self::new(session, BindFor::System)
288 }
289
290 /// Set the specified parameter types.
291 pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
292 self.param_types = ParameterTypes::new(param_types);
293 self
294 }
295
296 /// Set the secure compare context.
297 pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
298 self.secure_compare_context = Some(ctx);
299 self
300 }
301
302 fn is_for_stream(&self) -> bool {
303 matches!(self.bind_for, BindFor::Stream)
304 }
305
306 #[expect(dead_code)]
307 fn is_for_batch(&self) -> bool {
308 matches!(self.bind_for, BindFor::Batch)
309 }
310
311 fn is_for_ddl(&self) -> bool {
312 matches!(self.bind_for, BindFor::Ddl)
313 }
314
315 /// Bind a [`Statement`].
316 pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
317 self.bind_statement(stmt)
318 }
319
320 pub fn export_param_types(&self) -> Result<Vec<DataType>> {
321 self.param_types.export()
322 }
323
324 /// Get included relations in the query after binding. This is used for resolving relation
325 /// dependencies. Note that it only contains referenced relations discovered during binding.
326 /// After the plan is built, the referenced relations may be changed. We cannot rely on the
327 /// collection result of plan, because we still need to record the dependencies that have been
328 /// optimised away.
329 pub fn included_relations(&self) -> &HashSet<ObjectId> {
330 &self.included_relations
331 }
332
333 /// Get included user-defined functions in the query after binding.
334 pub fn included_udfs(&self) -> &HashSet<FunctionId> {
335 &self.included_udfs
336 }
337
338 /// Get included secrets in the query after binding (e.g., secret refs in UDF arguments).
339 pub fn included_secrets(&self) -> &HashSet<SecretId> {
340 &self.included_secrets
341 }
342
343 fn push_context(&mut self) {
344 let new_context = std::mem::take(&mut self.context);
345 self.context
346 .cte_to_relation
347 .clone_from(&new_context.cte_to_relation);
348 self.context.disable_security_invoker = new_context.disable_security_invoker;
349 let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
350 self.upper_subquery_contexts
351 .push((new_context, new_lateral_contexts));
352 }
353
354 fn pop_context(&mut self) -> Result<()> {
355 let (old_context, old_lateral_contexts) = self
356 .upper_subquery_contexts
357 .pop()
358 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
359 self.context = old_context;
360 self.lateral_contexts = old_lateral_contexts;
361 Ok(())
362 }
363
364 fn push_lateral_context(&mut self) {
365 let new_context = std::mem::take(&mut self.context);
366 self.context
367 .cte_to_relation
368 .clone_from(&new_context.cte_to_relation);
369 self.context.disable_security_invoker = new_context.disable_security_invoker;
370 self.lateral_contexts.push(LateralBindContext {
371 is_visible: false,
372 context: new_context,
373 });
374 }
375
376 fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
377 let mut old_context = self
378 .lateral_contexts
379 .pop()
380 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
381 .context;
382 old_context.merge_context(self.context.clone())?;
383 self.context = old_context;
384 Ok(())
385 }
386
387 fn try_mark_lateral_as_visible(&mut self) {
388 if let Some(mut ctx) = self.lateral_contexts.pop() {
389 ctx.is_visible = true;
390 self.lateral_contexts.push(ctx);
391 }
392 }
393
394 fn try_mark_lateral_as_invisible(&mut self) {
395 if let Some(mut ctx) = self.lateral_contexts.pop() {
396 ctx.is_visible = false;
397 self.lateral_contexts.push(ctx);
398 }
399 }
400
401 /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
402 /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
403 ///
404 /// In most cases, this should include all the upper subquery contexts. However, when binding
405 /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
406 fn visible_upper_subquery_contexts_rev(
407 &self,
408 ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
409 self.upper_subquery_contexts
410 .iter()
411 .rev()
412 .take_while(|(context, _)| context.sql_udf_arguments.is_none())
413 }
414
415 fn next_subquery_id(&mut self) -> usize {
416 let id = self.next_subquery_id;
417 self.next_subquery_id += 1;
418 id
419 }
420
421 fn next_values_id(&mut self) -> usize {
422 let id = self.next_values_id;
423 self.next_values_id += 1;
424 id
425 }
426
427 fn next_share_id(&mut self) -> ShareId {
428 let id = self.next_share_id;
429 self.next_share_id += 1;
430 id
431 }
432
433 fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
434 self.catalog.first_valid_schema(
435 &self.db_name,
436 &self.search_path,
437 &self.auth_context.user_name,
438 )
439 }
440
441 fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
442 SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
443 }
444
445 pub fn set_clause(&mut self, clause: Option<Clause>) {
446 self.context.clause = clause;
447 }
448}
449
/// The column name stored in [`BindContext`] for a column without an alias.
pub const UNNAMED_COLUMN: &str = "?column?";
/// The table name stored in [`BindContext`] for a subquery without an alias.
const UNNAMED_SUBQUERY: &str = "?subquery?";
/// The table name stored in [`BindContext`] for a column group.
// NOTE(review): the constant is named as a prefix — presumably a group id is appended at the use
// sites; confirm there.
const COLUMN_GROUP_PREFIX: &str = "?column_group_id?";
456
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch [`Binder`] on a mock session with no user-specified parameter types.
    pub fn mock_binder() -> Binder {
        mock_binder_with_param_types(Vec::new())
    }

    /// Builds a batch [`Binder`] on a mock session, pre-populated with `param_types` as the
    /// user-specified parameter types.
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        let binder = Binder::new_for_batch(&session);
        binder.with_specified_params_types(param_types)
    }
}
472
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::test_utils::*;

    /// Parses an `approx_percentile(...) WITHIN GROUP (ORDER BY ...)` query and checks two
    /// `{:#?}` snapshots: first the parsed AST, then the statement produced by binding it with
    /// a mock batch binder.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
            "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
        ).unwrap().into_iter().next().unwrap();
        // Snapshot of the parser output: the two direct arguments stay in `arg_list`, and the
        // ordering expression is carried in `within_group`.
        let parse_expected = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        parse_expected.assert_eq(&format!("{:#?}", stmt));

        let mut binder = mock_binder();
        let bound = binder.bind(stmt).unwrap();

        // Snapshot of the bound statement: the WITHIN GROUP expression becomes the aggregate's
        // argument (cast to Float64) and its `order_by`, while the two literals become
        // `direct_args`.
        let expected = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];

        expected.assert_eq(&format!("{:#?}", bound));
    }
}