risingwave_frontend/binder/mod.rs
1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57 BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58 BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59 ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, SecretId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
/// Identifier of a share relation (for example a CTE, a source or a view) allocated by the
/// [`Binder`] while binding.
pub type ShareId = usize;
79
/// The type of binding statement.
///
/// The variant is chosen by the `Binder::new_for_*` constructors and stored in the binder for
/// its whole lifetime.
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
/// `Binder` binds the identifiers in AST to columns in relations
pub struct Binder {
    // TODO: maybe we can only lock the database, but not the whole catalog.
    /// Read guard over the catalog, acquired from the session's catalog reader when the binder
    /// is created.
    catalog: CatalogReadGuard,
    /// Read guard over the user info, acquired from the session's user info reader when the
    /// binder is created.
    user: UserInfoReadGuard,
    /// Database name reported by the session (`session.database()`).
    db_name: String,
    /// Database id reported by the session (`session.database_id()`).
    database_id: DatabaseId,
    /// Id of the session this binder was created for.
    session_id: SessionId,
    /// The bind context of the query level currently being bound.
    context: BindContext,
    /// Auth context of the session; its user name is used when resolving schemas.
    auth_context: Arc<AuthContext>,
    /// A stack holding contexts of outer queries when binding a subquery.
    /// It also holds all of the lateral contexts for each respective
    /// subquery.
    ///
    /// See [`Binder::bind_subquery_expr`] for details.
    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,

    /// A stack holding contexts of left-lateral `TableFactor`s.
    ///
    /// We need a separate stack as `CorrelatedInputRef` depth is
    /// determined by the upper subquery context depth, not the lateral context stack depth.
    lateral_contexts: Vec<LateralBindContext>,

    /// Counter backing `next_subquery_id()`; starts at 0 and is incremented on every call.
    next_subquery_id: usize,
    /// Counter backing `next_values_id()`; starts at 0 and is incremented on every call.
    next_values_id: usize,
    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
    /// and so on.
    next_share_id: ShareId,

    /// Shared handle to the session configuration (`session.shared_config()`).
    session_config: Arc<RwLock<SessionConfig>>,

    /// Search path read from the session config when the binder is created.
    search_path: SearchPath,
    /// The type of binding statement.
    bind_for: BindFor,

    /// `ShareId`s identifying shared views.
    shared_views: HashMap<ViewId, ShareId>,

    /// The included relations while binding a query.
    included_relations: HashSet<ObjectId>,

    /// The included user-defined functions while binding a query.
    included_udfs: HashSet<FunctionId>,

    /// The included secrets while binding a query (e.g., secret refs in UDF arguments).
    included_secrets: HashSet<SecretId>,

    /// Types of the prepared-statement parameters, either specified up front via
    /// `with_specified_params_types` or recorded during binding.
    param_types: ParameterTypes,

    /// The temporary sources that will be used during binding phase
    temporary_source_manager: TemporarySourceManager,

    /// The staging catalogs that will be used during binding phase
    staging_catalog_manager: StagingCatalogManager,

    /// Information for `secure_compare` function. It's ONLY available when binding the
    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
    secure_compare_context: Option<SecureCompareContext>,
}
151
/// Names made available to the `secure_compare` function; installed on a [`Binder`] through
/// [`Binder::with_secure_compare`].
//
// There's one more hidden name, `HEADERS`, which is a reserved identifier for HTTP headers. Its type is `JSONB`.
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// The column name to store the whole payload in `JSONB`, but during validation it will be used as `bytea`
    pub column_name: String,
    /// The secret (usually a token provided by the webhook source user) to validate the calls
    pub secret_name: Option<String>,
}
160
/// `ParameterTypes` is used to record the types of the parameters while binding prepared statements.
/// It works by following these rules:
/// 1. At the beginning, it contains the user-specified parameter types.
/// 2. When the binder encounters a parameter, it will record it as unknown (call `record_new_param`)
///    if it didn't exist in `ParameterTypes`.
/// 3. When the binder encounters a cast on a parameter, if it's an unknown type, the cast function
///    will record the target type as the inferred type for that parameter (call
///    `record_infer_type`). If the parameter has already been inferred, the cast function will
///    act as a normal cast.
/// 4. After binding has finished:
///    (a) a parameter not in `ParameterTypes` means that the user didn't specify it and it didn't
///    occur in the query. `export` will return an error if such a parameter leaves a gap in
///    the recorded indices. This rule is compatible with PostgreSQL.
///    (b) a parameter that is `None` has an unknown type: the user didn't specify it
///    and we couldn't infer it from the query. We will treat it as VARCHAR type finally. This
///    rule is compatible with PostgreSQL.
///    (c) a parameter that is `Some` has a known type.
///
/// Parameter indices are 1-based. Clones share the same underlying map.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
179
180impl ParameterTypes {
181 pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
182 let map = specified_param_types
183 .into_iter()
184 .enumerate()
185 .map(|(index, data_type)| ((index + 1) as u64, data_type))
186 .collect::<HashMap<u64, Option<DataType>>>();
187 Self(Arc::new(RwLock::new(map)))
188 }
189
190 pub fn has_infer(&self, index: u64) -> bool {
191 self.0.read().get(&index).unwrap().is_some()
192 }
193
194 pub fn read_type(&self, index: u64) -> Option<DataType> {
195 self.0.read().get(&index).unwrap().clone()
196 }
197
198 pub fn record_new_param(&mut self, index: u64) {
199 self.0.write().entry(index).or_insert(None);
200 }
201
202 pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
203 assert!(
204 !self.has_infer(index),
205 "The parameter has been inferred, should not be inferred again."
206 );
207 self.0
208 .write()
209 .get_mut(&index)
210 .unwrap()
211 .replace(data_type.clone());
212 }
213
214 pub fn export(&self) -> Result<Vec<DataType>> {
215 let types = self
216 .0
217 .read()
218 .clone()
219 .into_iter()
220 .sorted_by_key(|(index, _)| *index)
221 .collect::<Vec<_>>();
222
223 // Check if all the parameters have been inferred.
224 for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
225 if *index != expect_index {
226 return Err(ErrorCode::InvalidInputSyntax(format!(
227 "Cannot infer the type of the parameter {}.",
228 expect_index
229 ))
230 .into());
231 }
232 }
233
234 Ok(types
235 .into_iter()
236 .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
237 .collect::<Vec<_>>())
238 }
239}
240
241impl Binder {
242 fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
243 Binder {
244 catalog: session.env().catalog_reader().read_guard(),
245 user: session.env().user_info_reader().read_guard(),
246 db_name: session.database(),
247 database_id: session.database_id(),
248 session_id: session.id(),
249 context: BindContext::new(),
250 auth_context: session.auth_context(),
251 upper_subquery_contexts: vec![],
252 lateral_contexts: vec![],
253 next_subquery_id: 0,
254 next_values_id: 0,
255 next_share_id: 0,
256 session_config: session.shared_config(),
257 search_path: session.config().search_path(),
258 bind_for,
259 shared_views: HashMap::new(),
260 included_relations: HashSet::new(),
261 included_udfs: HashSet::new(),
262 included_secrets: HashSet::new(),
263 param_types: ParameterTypes::new(vec![]),
264 temporary_source_manager: session.temporary_source_manager(),
265 staging_catalog_manager: session.staging_catalog_manager(),
266 secure_compare_context: None,
267 }
268 }
269
270 pub fn new_for_batch(session: &SessionImpl) -> Binder {
271 Self::new(session, BindFor::Batch)
272 }
273
274 pub fn new_for_stream(session: &SessionImpl) -> Binder {
275 Self::new(session, BindFor::Stream)
276 }
277
278 pub fn new_for_ddl(session: &SessionImpl) -> Binder {
279 Self::new(session, BindFor::Ddl)
280 }
281
282 pub fn new_for_system(session: &SessionImpl) -> Binder {
283 Self::new(session, BindFor::System)
284 }
285
286 /// Set the specified parameter types.
287 pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
288 self.param_types = ParameterTypes::new(param_types);
289 self
290 }
291
292 /// Set the secure compare context.
293 pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
294 self.secure_compare_context = Some(ctx);
295 self
296 }
297
298 fn is_for_stream(&self) -> bool {
299 matches!(self.bind_for, BindFor::Stream)
300 }
301
302 #[allow(dead_code)]
303 fn is_for_batch(&self) -> bool {
304 matches!(self.bind_for, BindFor::Batch)
305 }
306
307 fn is_for_ddl(&self) -> bool {
308 matches!(self.bind_for, BindFor::Ddl)
309 }
310
311 /// Bind a [`Statement`].
312 pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
313 self.bind_statement(stmt)
314 }
315
316 pub fn export_param_types(&self) -> Result<Vec<DataType>> {
317 self.param_types.export()
318 }
319
320 /// Get included relations in the query after binding. This is used for resolving relation
321 /// dependencies. Note that it only contains referenced relations discovered during binding.
322 /// After the plan is built, the referenced relations may be changed. We cannot rely on the
323 /// collection result of plan, because we still need to record the dependencies that have been
324 /// optimised away.
325 pub fn included_relations(&self) -> &HashSet<ObjectId> {
326 &self.included_relations
327 }
328
329 /// Get included user-defined functions in the query after binding.
330 pub fn included_udfs(&self) -> &HashSet<FunctionId> {
331 &self.included_udfs
332 }
333
334 /// Get included secrets in the query after binding (e.g., secret refs in UDF arguments).
335 pub fn included_secrets(&self) -> &HashSet<SecretId> {
336 &self.included_secrets
337 }
338
339 fn push_context(&mut self) {
340 let new_context = std::mem::take(&mut self.context);
341 self.context
342 .cte_to_relation
343 .clone_from(&new_context.cte_to_relation);
344 self.context.disable_security_invoker = new_context.disable_security_invoker;
345 let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
346 self.upper_subquery_contexts
347 .push((new_context, new_lateral_contexts));
348 }
349
350 fn pop_context(&mut self) -> Result<()> {
351 let (old_context, old_lateral_contexts) = self
352 .upper_subquery_contexts
353 .pop()
354 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
355 self.context = old_context;
356 self.lateral_contexts = old_lateral_contexts;
357 Ok(())
358 }
359
360 fn push_lateral_context(&mut self) {
361 let new_context = std::mem::take(&mut self.context);
362 self.context
363 .cte_to_relation
364 .clone_from(&new_context.cte_to_relation);
365 self.context.disable_security_invoker = new_context.disable_security_invoker;
366 self.lateral_contexts.push(LateralBindContext {
367 is_visible: false,
368 context: new_context,
369 });
370 }
371
372 fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
373 let mut old_context = self
374 .lateral_contexts
375 .pop()
376 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
377 .context;
378 old_context.merge_context(self.context.clone())?;
379 self.context = old_context;
380 Ok(())
381 }
382
383 fn try_mark_lateral_as_visible(&mut self) {
384 if let Some(mut ctx) = self.lateral_contexts.pop() {
385 ctx.is_visible = true;
386 self.lateral_contexts.push(ctx);
387 }
388 }
389
390 fn try_mark_lateral_as_invisible(&mut self) {
391 if let Some(mut ctx) = self.lateral_contexts.pop() {
392 ctx.is_visible = false;
393 self.lateral_contexts.push(ctx);
394 }
395 }
396
397 /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
398 /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
399 ///
400 /// In most cases, this should include all the upper subquery contexts. However, when binding
401 /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
402 fn visible_upper_subquery_contexts_rev(
403 &self,
404 ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
405 self.upper_subquery_contexts
406 .iter()
407 .rev()
408 .take_while(|(context, _)| context.sql_udf_arguments.is_none())
409 }
410
411 fn next_subquery_id(&mut self) -> usize {
412 let id = self.next_subquery_id;
413 self.next_subquery_id += 1;
414 id
415 }
416
417 fn next_values_id(&mut self) -> usize {
418 let id = self.next_values_id;
419 self.next_values_id += 1;
420 id
421 }
422
423 fn next_share_id(&mut self) -> ShareId {
424 let id = self.next_share_id;
425 self.next_share_id += 1;
426 id
427 }
428
429 fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
430 self.catalog.first_valid_schema(
431 &self.db_name,
432 &self.search_path,
433 &self.auth_context.user_name,
434 )
435 }
436
437 fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
438 SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
439 }
440
441 pub fn set_clause(&mut self, clause: Option<Clause>) {
442 self.context.clause = clause;
443 }
444}
445
/// The column name stored in [`BindContext`] for a column without an alias.
pub const UNNAMED_COLUMN: &str = "?column?";
/// The table name stored in [`BindContext`] for a subquery without an alias.
const UNNAMED_SUBQUERY: &str = "?subquery?";
/// The table name stored in [`BindContext`] for a column group.
///
/// NOTE(review): judging by the `_PREFIX` suffix this is only the leading part of that name;
/// confirm at the use sites how the rest is built.
const COLUMN_GROUP_PREFIX: &str = "?column_group_id?";
452
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch [`Binder`] on top of a mock session, with no specified parameter types.
    pub fn mock_binder() -> Binder {
        mock_binder_with_param_types(Vec::new())
    }

    /// Builds a batch [`Binder`] on top of a mock session, pre-populated with the given
    /// parameter types.
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        let binder = Binder::new_for_batch(&session);
        binder.with_specified_params_types(param_types)
    }
}
468
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::test_utils::*;

    /// Snapshot test for an `approx_percentile(...) WITHIN GROUP (ORDER BY ...)` query over
    /// `generate_series`: first pins the parsed AST, then the bound statement produced by a
    /// mock binder. Both are compared through their pretty `Debug` output.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        // Parse the SQL text and keep its first (and only) statement.
        let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
            "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
        ).unwrap().into_iter().next().unwrap();
        // Expected AST: the two literals are ordinary function arguments, and the ORDER BY
        // expression sits in `within_group`.
        let parse_expected = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        parse_expected.assert_eq(&format!("{:#?}", stmt));

        // Bind the parsed statement with a batch binder over a mock session.
        let mut binder = mock_binder();
        let bound = binder.bind(stmt).unwrap();

        // Expected bound form: the two literals become `direct_args`, while the ORDER BY column
        // appears both as the (cast) aggregate argument and as the sort expression.
        let expected = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];

        expected.assert_eq(&format!("{:#?}", bound));
    }
}
738}