risingwave_frontend/binder/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57    BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58    BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59    ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, SecretId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
/// Identifies a shared relation (for example a CTE, a source or a view) within a bound query.
pub type ShareId = usize;

/// The kind of statement a [`Binder`] is binding.
///
/// A plain fieldless enum, so it derives the cheap standard traits: it can be copied, compared
/// with `==` and printed with `{:?}` in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
92/// `Binder` binds the identifiers in AST to columns in relations
93pub struct Binder {
94    // TODO: maybe we can only lock the database, but not the whole catalog.
95    catalog: CatalogReadGuard,
96    user: UserInfoReadGuard,
97    db_name: String,
98    database_id: DatabaseId,
99    session_id: SessionId,
100    context: BindContext,
101    auth_context: Arc<AuthContext>,
102    /// A stack holding contexts of outer queries when binding a subquery.
103    /// It also holds all of the lateral contexts for each respective
104    /// subquery.
105    ///
106    /// See [`Binder::bind_subquery_expr`] for details.
107    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,
108
109    /// A stack holding contexts of left-lateral `TableFactor`s.
110    ///
111    /// We need a separate stack as `CorrelatedInputRef` depth is
112    /// determined by the upper subquery context depth, not the lateral context stack depth.
113    lateral_contexts: Vec<LateralBindContext>,
114
115    next_subquery_id: usize,
116    next_values_id: usize,
117    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
118    /// and so on.
119    next_share_id: ShareId,
120
121    session_config: Arc<RwLock<SessionConfig>>,
122
123    search_path: SearchPath,
124    /// The type of binding statement.
125    bind_for: BindFor,
126
127    /// `ShareId`s identifying shared views.
128    shared_views: HashMap<ViewId, ShareId>,
129
130    /// The included relations while binding a query.
131    included_relations: HashSet<ObjectId>,
132
133    /// The included user-defined functions while binding a query.
134    included_udfs: HashSet<FunctionId>,
135
136    /// The included secrets while binding a query (e.g., secret refs in UDF arguments).
137    included_secrets: HashSet<SecretId>,
138
139    param_types: ParameterTypes,
140
141    /// The temporary sources that will be used during binding phase
142    temporary_source_manager: TemporarySourceManager,
143
144    /// The staging catalogs that will be used during binding phase
145    staging_catalog_manager: StagingCatalogManager,
146
147    /// Information for `secure_compare` function. It's ONLY available when binding the
148    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
149    secure_compare_context: Option<SecureCompareContext>,
150}
151
/// Names visible to the `secure_compare` function while binding a webhook source's
/// `VALIDATE` clause.
///
/// Apart from the names stored here, one more hidden name exists: `HEADERS`, a reserved
/// identifier for the HTTP headers, whose type is `JSONB`.
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// Column that stores the whole payload as `JSONB`; during validation the payload is
    /// used as `bytea` instead.
    pub column_name: String,
    /// Secret used to validate incoming calls, usually a token supplied by the user of the
    /// webhook source.
    pub secret_name: Option<String>,
}
160
161/// `ParameterTypes` is used to record the types of the parameters during binding prepared stataments.
162/// It works by following the rules:
163/// 1. At the beginning, it contains the user specified parameters type.
164/// 2. When the binder encounters a parameter, it will record it as unknown(call `record_new_param`)
165///    if it didn't exist in `ParameterTypes`.
166/// 3. When the binder encounters a cast on parameter, if it's a unknown type, the cast function
167///    will record the target type as infer type for that parameter(call `record_infer_type`). If the
168///    parameter has been inferred, the cast function will act as a normal cast.
169/// 4. After bind finished:
170///    (a) parameter not in `ParameterTypes` means that the user didn't specify it and it didn't
171///    occur in the query. `export` will return error if there is a kind of
172///    parameter. This rule is compatible with PostgreSQL
173///    (b) parameter is None means that it's a unknown type. The user didn't specify it
174///    and we can't infer it in the query. We will treat it as VARCHAR type finally. This rule is
175///    compatible with PostgreSQL.
176///    (c) parameter is Some means that it's a known type.
177#[derive(Clone, Debug)]
178pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
179
180impl ParameterTypes {
181    pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
182        let map = specified_param_types
183            .into_iter()
184            .enumerate()
185            .map(|(index, data_type)| ((index + 1) as u64, data_type))
186            .collect::<HashMap<u64, Option<DataType>>>();
187        Self(Arc::new(RwLock::new(map)))
188    }
189
190    pub fn has_infer(&self, index: u64) -> bool {
191        self.0.read().get(&index).unwrap().is_some()
192    }
193
194    pub fn read_type(&self, index: u64) -> Option<DataType> {
195        self.0.read().get(&index).unwrap().clone()
196    }
197
198    pub fn record_new_param(&mut self, index: u64) {
199        self.0.write().entry(index).or_insert(None);
200    }
201
202    pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
203        assert!(
204            !self.has_infer(index),
205            "The parameter has been inferred, should not be inferred again."
206        );
207        self.0
208            .write()
209            .get_mut(&index)
210            .unwrap()
211            .replace(data_type.clone());
212    }
213
214    pub fn export(&self) -> Result<Vec<DataType>> {
215        let types = self
216            .0
217            .read()
218            .clone()
219            .into_iter()
220            .sorted_by_key(|(index, _)| *index)
221            .collect::<Vec<_>>();
222
223        // Check if all the parameters have been inferred.
224        for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
225            if *index != expect_index {
226                return Err(ErrorCode::InvalidInputSyntax(format!(
227                    "Cannot infer the type of the parameter {}.",
228                    expect_index
229                ))
230                .into());
231            }
232        }
233
234        Ok(types
235            .into_iter()
236            .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
237            .collect::<Vec<_>>())
238    }
239}
240
241impl Binder {
242    fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
243        Binder {
244            catalog: session.env().catalog_reader().read_guard(),
245            user: session.env().user_info_reader().read_guard(),
246            db_name: session.database(),
247            database_id: session.database_id(),
248            session_id: session.id(),
249            context: BindContext::new(),
250            auth_context: session.auth_context(),
251            upper_subquery_contexts: vec![],
252            lateral_contexts: vec![],
253            next_subquery_id: 0,
254            next_values_id: 0,
255            next_share_id: 0,
256            session_config: session.shared_config(),
257            search_path: session.config().search_path(),
258            bind_for,
259            shared_views: HashMap::new(),
260            included_relations: HashSet::new(),
261            included_udfs: HashSet::new(),
262            included_secrets: HashSet::new(),
263            param_types: ParameterTypes::new(vec![]),
264            temporary_source_manager: session.temporary_source_manager(),
265            staging_catalog_manager: session.staging_catalog_manager(),
266            secure_compare_context: None,
267        }
268    }
269
270    pub fn new_for_batch(session: &SessionImpl) -> Binder {
271        Self::new(session, BindFor::Batch)
272    }
273
274    pub fn new_for_stream(session: &SessionImpl) -> Binder {
275        Self::new(session, BindFor::Stream)
276    }
277
278    pub fn new_for_ddl(session: &SessionImpl) -> Binder {
279        Self::new(session, BindFor::Ddl)
280    }
281
282    pub fn new_for_system(session: &SessionImpl) -> Binder {
283        Self::new(session, BindFor::System)
284    }
285
286    /// Set the specified parameter types.
287    pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
288        self.param_types = ParameterTypes::new(param_types);
289        self
290    }
291
292    /// Set the secure compare context.
293    pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
294        self.secure_compare_context = Some(ctx);
295        self
296    }
297
298    fn is_for_stream(&self) -> bool {
299        matches!(self.bind_for, BindFor::Stream)
300    }
301
302    #[allow(dead_code)]
303    fn is_for_batch(&self) -> bool {
304        matches!(self.bind_for, BindFor::Batch)
305    }
306
307    fn is_for_ddl(&self) -> bool {
308        matches!(self.bind_for, BindFor::Ddl)
309    }
310
311    /// Bind a [`Statement`].
312    pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
313        self.bind_statement(stmt)
314    }
315
316    pub fn export_param_types(&self) -> Result<Vec<DataType>> {
317        self.param_types.export()
318    }
319
320    /// Get included relations in the query after binding. This is used for resolving relation
321    /// dependencies. Note that it only contains referenced relations discovered during binding.
322    /// After the plan is built, the referenced relations may be changed. We cannot rely on the
323    /// collection result of plan, because we still need to record the dependencies that have been
324    /// optimised away.
325    pub fn included_relations(&self) -> &HashSet<ObjectId> {
326        &self.included_relations
327    }
328
329    /// Get included user-defined functions in the query after binding.
330    pub fn included_udfs(&self) -> &HashSet<FunctionId> {
331        &self.included_udfs
332    }
333
334    /// Get included secrets in the query after binding (e.g., secret refs in UDF arguments).
335    pub fn included_secrets(&self) -> &HashSet<SecretId> {
336        &self.included_secrets
337    }
338
339    fn push_context(&mut self) {
340        let new_context = std::mem::take(&mut self.context);
341        self.context
342            .cte_to_relation
343            .clone_from(&new_context.cte_to_relation);
344        self.context.disable_security_invoker = new_context.disable_security_invoker;
345        let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
346        self.upper_subquery_contexts
347            .push((new_context, new_lateral_contexts));
348    }
349
350    fn pop_context(&mut self) -> Result<()> {
351        let (old_context, old_lateral_contexts) = self
352            .upper_subquery_contexts
353            .pop()
354            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
355        self.context = old_context;
356        self.lateral_contexts = old_lateral_contexts;
357        Ok(())
358    }
359
360    fn push_lateral_context(&mut self) {
361        let new_context = std::mem::take(&mut self.context);
362        self.context
363            .cte_to_relation
364            .clone_from(&new_context.cte_to_relation);
365        self.context.disable_security_invoker = new_context.disable_security_invoker;
366        self.lateral_contexts.push(LateralBindContext {
367            is_visible: false,
368            context: new_context,
369        });
370    }
371
372    fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
373        let mut old_context = self
374            .lateral_contexts
375            .pop()
376            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
377            .context;
378        old_context.merge_context(self.context.clone())?;
379        self.context = old_context;
380        Ok(())
381    }
382
383    fn try_mark_lateral_as_visible(&mut self) {
384        if let Some(mut ctx) = self.lateral_contexts.pop() {
385            ctx.is_visible = true;
386            self.lateral_contexts.push(ctx);
387        }
388    }
389
390    fn try_mark_lateral_as_invisible(&mut self) {
391        if let Some(mut ctx) = self.lateral_contexts.pop() {
392            ctx.is_visible = false;
393            self.lateral_contexts.push(ctx);
394        }
395    }
396
397    /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
398    /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
399    ///
400    /// In most cases, this should include all the upper subquery contexts. However, when binding
401    /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
402    fn visible_upper_subquery_contexts_rev(
403        &self,
404    ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
405        self.upper_subquery_contexts
406            .iter()
407            .rev()
408            .take_while(|(context, _)| context.sql_udf_arguments.is_none())
409    }
410
411    fn next_subquery_id(&mut self) -> usize {
412        let id = self.next_subquery_id;
413        self.next_subquery_id += 1;
414        id
415    }
416
417    fn next_values_id(&mut self) -> usize {
418        let id = self.next_values_id;
419        self.next_values_id += 1;
420        id
421    }
422
423    fn next_share_id(&mut self) -> ShareId {
424        let id = self.next_share_id;
425        self.next_share_id += 1;
426        id
427    }
428
429    fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
430        self.catalog.first_valid_schema(
431            &self.db_name,
432            &self.search_path,
433            &self.auth_context.user_name,
434        )
435    }
436
437    fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
438        SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
439    }
440
441    pub fn set_clause(&mut self, clause: Option<Clause>) {
442        self.context.clause = clause;
443    }
444}
445
/// Name recorded in [`BindContext`] for a column that has no alias.
pub const UNNAMED_COLUMN: &'static str = "?column?";
/// Table name recorded in [`BindContext`] for a subquery that has no alias.
const UNNAMED_SUBQUERY: &'static str = "?subquery?";
/// Table name recorded in [`BindContext`] for a column group.
const COLUMN_GROUP_PREFIX: &'static str = "?column_group_id?";
452
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch binder over a mock session, with no parameter types specified.
    pub fn mock_binder() -> Binder {
        mock_binder_with_param_types(Vec::new())
    }

    /// Builds a batch binder over a mock session, pre-seeded with `param_types`.
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        Binder::new_for_batch(&session).with_specified_params_types(param_types)
    }
}
468
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use risingwave_sqlparser::parser::Parser;

    use super::test_utils::*;

    /// Parses and binds an `approx_percentile ... WITHIN GROUP` query, snapshotting both the
    /// AST and the bound statement.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        let sql = "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)";
        let statements = Parser::parse_sql(sql).unwrap();
        let stmt = statements.into_iter().next().unwrap();

        let expected_ast = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        expected_ast.assert_eq(&format!("{stmt:#?}"));

        let bound = mock_binder().bind(stmt).unwrap();

        let expected_bound = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];
        expected_bound.assert_eq(&format!("{bound:#?}"));
    }
}