risingwave_frontend/binder/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57    BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58    BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59    ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
/// Identifier of a share relation (e.g. a CTE, a source or a view) within one [`Binder`].
/// Ids are handed out sequentially, starting from 0, by `Binder::next_share_id`.
pub type ShareId = usize;
79
/// The type of binding statement.
///
/// Chosen by the `Binder::new_for_*` constructor and queried through the binder's
/// `is_for_stream` / `is_for_batch` / `is_for_ddl` helpers.
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
/// `Binder` binds the identifiers in an AST to columns in relations, producing a
/// [`BoundStatement`].
///
/// It is constructed from a [`SessionImpl`] through one of the `new_for_*` constructors, which
/// fix the [`BindFor`] mode for the whole binding.
pub struct Binder {
    // TODO: maybe we can only lock the database, but not the whole catalog.
    /// Read guard over the catalog; held for as long as the binder lives.
    catalog: CatalogReadGuard,
    /// Read guard over user information; held for as long as the binder lives.
    user: UserInfoReadGuard,
    /// Name of the session's current database.
    db_name: String,
    /// Id of the session's current database.
    database_id: DatabaseId,
    /// Id of the session this binder was created from.
    session_id: SessionId,
    /// The context of the (sub)query currently being bound.
    context: BindContext,
    /// Authentication context of the session; its user name is used when resolving schemas.
    auth_context: Arc<AuthContext>,
    /// A stack holding contexts of outer queries when binding a subquery.
    /// It also holds all of the lateral contexts for each respective
    /// subquery.
    ///
    /// See [`Binder::bind_subquery_expr`] for details.
    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,

    /// A stack holding contexts of left-lateral `TableFactor`s.
    ///
    /// We need a separate stack as `CorrelatedInputRef` depth is
    /// determined by the upper subquery context depth, not the lateral context stack depth.
    lateral_contexts: Vec<LateralBindContext>,

    /// Next id returned by `next_subquery_id`.
    next_subquery_id: usize,
    /// Next id returned by `next_values_id`.
    next_values_id: usize,
    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
    /// and so on.
    next_share_id: ShareId,

    /// Shared handle to the session configuration.
    session_config: Arc<RwLock<SessionConfig>>,

    /// Schema search path, read from the session config when the binder is created.
    search_path: SearchPath,
    /// The type of binding statement.
    bind_for: BindFor,

    /// `ShareId`s identifying shared views.
    shared_views: HashMap<ViewId, ShareId>,

    /// The included relations while binding a query.
    included_relations: HashSet<ObjectId>,

    /// The included user-defined functions while binding a query.
    included_udfs: HashSet<FunctionId>,

    /// Types of the parameters (`$1`, `$2`, ...) of a prepared statement.
    param_types: ParameterTypes,

    /// The temporary sources that will be used during binding phase
    temporary_source_manager: TemporarySourceManager,

    /// The staging catalogs that will be used during binding phase
    staging_catalog_manager: StagingCatalogManager,

    /// Information for `secure_compare` function. It's ONLY available when binding the
    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
    secure_compare_context: Option<SecureCompareContext>,
}
148
/// Names available to the `secure_compare` function while binding the `VALIDATE` clause of a
/// webhook source, i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
///
/// Besides the names stored here, there's one more hidden name, `HEADERS`, which is a reserved
/// identifier for HTTP headers. Its type is `JSONB`.
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// The column name to store the whole payload in `JSONB`, but during validation it will be used as `bytea`
    pub column_name: String,
    /// The secret (usually a token provided by the webhook source user) to validate the calls
    pub secret_name: Option<String>,
}
157
/// `ParameterTypes` records the types of the parameters while binding prepared statements.
/// It works according to the following rules:
/// 1. Initially, it contains the parameter types specified by the user.
/// 2. When the binder encounters a parameter that is not yet in `ParameterTypes`, it records it
///    with an unknown type (via `record_new_param`).
/// 3. When the binder encounters a cast on a parameter whose type is still unknown, the cast
///    records its target type as the inferred type of that parameter (via `record_infer_type`).
///    If the parameter's type is already known, the cast acts as a normal cast.
/// 4. After binding finishes:
///    (a) A parameter index missing from `ParameterTypes` means the user didn't specify it and
///    it didn't occur in the query. `export` returns an error if any index below the highest
///    recorded one is missing. This rule is compatible with PostgreSQL.
///    (b) A parameter mapped to `None` has an unknown type: the user didn't specify it and it
///    couldn't be inferred from the query. It is finally treated as `VARCHAR`. This rule is
///    compatible with PostgreSQL.
///    (c) A parameter mapped to `Some` has a known type.
///
/// The map is shared behind `Arc<RwLock<..>>`, so clones of a `ParameterTypes` observe each
/// other's updates.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
176
177impl ParameterTypes {
178    pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
179        let map = specified_param_types
180            .into_iter()
181            .enumerate()
182            .map(|(index, data_type)| ((index + 1) as u64, data_type))
183            .collect::<HashMap<u64, Option<DataType>>>();
184        Self(Arc::new(RwLock::new(map)))
185    }
186
187    pub fn has_infer(&self, index: u64) -> bool {
188        self.0.read().get(&index).unwrap().is_some()
189    }
190
191    pub fn read_type(&self, index: u64) -> Option<DataType> {
192        self.0.read().get(&index).unwrap().clone()
193    }
194
195    pub fn record_new_param(&mut self, index: u64) {
196        self.0.write().entry(index).or_insert(None);
197    }
198
199    pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
200        assert!(
201            !self.has_infer(index),
202            "The parameter has been inferred, should not be inferred again."
203        );
204        self.0
205            .write()
206            .get_mut(&index)
207            .unwrap()
208            .replace(data_type.clone());
209    }
210
211    pub fn export(&self) -> Result<Vec<DataType>> {
212        let types = self
213            .0
214            .read()
215            .clone()
216            .into_iter()
217            .sorted_by_key(|(index, _)| *index)
218            .collect::<Vec<_>>();
219
220        // Check if all the parameters have been inferred.
221        for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
222            if *index != expect_index {
223                return Err(ErrorCode::InvalidInputSyntax(format!(
224                    "Cannot infer the type of the parameter {}.",
225                    expect_index
226                ))
227                .into());
228            }
229        }
230
231        Ok(types
232            .into_iter()
233            .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
234            .collect::<Vec<_>>())
235    }
236}
237
238impl Binder {
239    fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
240        Binder {
241            catalog: session.env().catalog_reader().read_guard(),
242            user: session.env().user_info_reader().read_guard(),
243            db_name: session.database(),
244            database_id: session.database_id(),
245            session_id: session.id(),
246            context: BindContext::new(),
247            auth_context: session.auth_context(),
248            upper_subquery_contexts: vec![],
249            lateral_contexts: vec![],
250            next_subquery_id: 0,
251            next_values_id: 0,
252            next_share_id: 0,
253            session_config: session.shared_config(),
254            search_path: session.config().search_path(),
255            bind_for,
256            shared_views: HashMap::new(),
257            included_relations: HashSet::new(),
258            included_udfs: HashSet::new(),
259            param_types: ParameterTypes::new(vec![]),
260            temporary_source_manager: session.temporary_source_manager(),
261            staging_catalog_manager: session.staging_catalog_manager(),
262            secure_compare_context: None,
263        }
264    }
265
266    pub fn new_for_batch(session: &SessionImpl) -> Binder {
267        Self::new(session, BindFor::Batch)
268    }
269
270    pub fn new_for_stream(session: &SessionImpl) -> Binder {
271        Self::new(session, BindFor::Stream)
272    }
273
274    pub fn new_for_ddl(session: &SessionImpl) -> Binder {
275        Self::new(session, BindFor::Ddl)
276    }
277
278    pub fn new_for_system(session: &SessionImpl) -> Binder {
279        Self::new(session, BindFor::System)
280    }
281
282    /// Set the specified parameter types.
283    pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
284        self.param_types = ParameterTypes::new(param_types);
285        self
286    }
287
288    /// Set the secure compare context.
289    pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
290        self.secure_compare_context = Some(ctx);
291        self
292    }
293
294    fn is_for_stream(&self) -> bool {
295        matches!(self.bind_for, BindFor::Stream)
296    }
297
298    #[allow(dead_code)]
299    fn is_for_batch(&self) -> bool {
300        matches!(self.bind_for, BindFor::Batch)
301    }
302
303    fn is_for_ddl(&self) -> bool {
304        matches!(self.bind_for, BindFor::Ddl)
305    }
306
307    /// Bind a [`Statement`].
308    pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
309        self.bind_statement(stmt)
310    }
311
312    pub fn export_param_types(&self) -> Result<Vec<DataType>> {
313        self.param_types.export()
314    }
315
316    /// Get included relations in the query after binding. This is used for resolving relation
317    /// dependencies. Note that it only contains referenced relations discovered during binding.
318    /// After the plan is built, the referenced relations may be changed. We cannot rely on the
319    /// collection result of plan, because we still need to record the dependencies that have been
320    /// optimised away.
321    pub fn included_relations(&self) -> &HashSet<ObjectId> {
322        &self.included_relations
323    }
324
325    /// Get included user-defined functions in the query after binding.
326    pub fn included_udfs(&self) -> &HashSet<FunctionId> {
327        &self.included_udfs
328    }
329
330    fn push_context(&mut self) {
331        let new_context = std::mem::take(&mut self.context);
332        self.context
333            .cte_to_relation
334            .clone_from(&new_context.cte_to_relation);
335        self.context.disable_security_invoker = new_context.disable_security_invoker;
336        let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
337        self.upper_subquery_contexts
338            .push((new_context, new_lateral_contexts));
339    }
340
341    fn pop_context(&mut self) -> Result<()> {
342        let (old_context, old_lateral_contexts) = self
343            .upper_subquery_contexts
344            .pop()
345            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
346        self.context = old_context;
347        self.lateral_contexts = old_lateral_contexts;
348        Ok(())
349    }
350
351    fn push_lateral_context(&mut self) {
352        let new_context = std::mem::take(&mut self.context);
353        self.context
354            .cte_to_relation
355            .clone_from(&new_context.cte_to_relation);
356        self.context.disable_security_invoker = new_context.disable_security_invoker;
357        self.lateral_contexts.push(LateralBindContext {
358            is_visible: false,
359            context: new_context,
360        });
361    }
362
363    fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
364        let mut old_context = self
365            .lateral_contexts
366            .pop()
367            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
368            .context;
369        old_context.merge_context(self.context.clone())?;
370        self.context = old_context;
371        Ok(())
372    }
373
374    fn try_mark_lateral_as_visible(&mut self) {
375        if let Some(mut ctx) = self.lateral_contexts.pop() {
376            ctx.is_visible = true;
377            self.lateral_contexts.push(ctx);
378        }
379    }
380
381    fn try_mark_lateral_as_invisible(&mut self) {
382        if let Some(mut ctx) = self.lateral_contexts.pop() {
383            ctx.is_visible = false;
384            self.lateral_contexts.push(ctx);
385        }
386    }
387
388    /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
389    /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
390    ///
391    /// In most cases, this should include all the upper subquery contexts. However, when binding
392    /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
393    fn visible_upper_subquery_contexts_rev(
394        &self,
395    ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
396        self.upper_subquery_contexts
397            .iter()
398            .rev()
399            .take_while(|(context, _)| context.sql_udf_arguments.is_none())
400    }
401
402    fn next_subquery_id(&mut self) -> usize {
403        let id = self.next_subquery_id;
404        self.next_subquery_id += 1;
405        id
406    }
407
408    fn next_values_id(&mut self) -> usize {
409        let id = self.next_values_id;
410        self.next_values_id += 1;
411        id
412    }
413
414    fn next_share_id(&mut self) -> ShareId {
415        let id = self.next_share_id;
416        self.next_share_id += 1;
417        id
418    }
419
420    fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
421        self.catalog.first_valid_schema(
422            &self.db_name,
423            &self.search_path,
424            &self.auth_context.user_name,
425        )
426    }
427
428    fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
429        SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
430    }
431
432    pub fn set_clause(&mut self, clause: Option<Clause>) {
433        self.context.clause = clause;
434    }
435}
436
/// The column name stored in [`BindContext`] for a column without an alias.
pub const UNNAMED_COLUMN: &str = "?column?";
/// The table name stored in [`BindContext`] for a subquery without an alias.
const UNNAMED_SUBQUERY: &str = "?subquery?";
/// The table name stored in [`BindContext`] for a column group.
// NOTE(review): named as a prefix — presumably combined with a group id where it is used; confirm.
const COLUMN_GROUP_PREFIX: &str = "?column_group_id?";
443
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch-mode [`Binder`] over a mock session without any pre-specified
    /// parameter types.
    pub fn mock_binder() -> Binder {
        let no_param_types: Vec<Option<DataType>> = Vec::new();
        mock_binder_with_param_types(no_param_types)
    }

    /// Builds a batch-mode [`Binder`] over a mock session, seeded with `param_types`
    /// (entry `i` describes parameter `$i+1`; `None` leaves it to be inferred).
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        let binder = Binder::new_for_batch(&session);
        binder.with_specified_params_types(param_types)
    }
}
459
/// Snapshot tests for the binder.
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::test_utils::*;

    /// Parses and binds an `approx_percentile ... WITHIN GROUP` aggregate over `generate_series`,
    /// pinning both the parser AST and the bound statement via `expect!` snapshots.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
            "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
        ).unwrap().into_iter().next().unwrap();
        // Snapshot of the parsed AST.
        let parse_expected = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        parse_expected.assert_eq(&format!("{:#?}", stmt));

        // Bind with a batch-mode mock binder (no parameter types).
        let mut binder = mock_binder();
        let bound = binder.bind(stmt).unwrap();

        // Snapshot of the bound statement: the direct args become Float64 literals and the
        // ordered input is cast from Int32 to Float64.
        let expected = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];

        expected.assert_eq(&format!("{:#?}", bound));
    }
}