risingwave_frontend/binder/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57    BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58    BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59    ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, SecretId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
/// Identifier of a shared relation (e.g. a CTE, a source or a view), allocated by
/// `Binder::next_share_id`.
pub type ShareId = usize;

/// The kind of statement a [`Binder`] is binding; selected by the `Binder::new_for_*`
/// constructors.
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
/// `Binder` binds the identifiers in the AST to columns in relations.
pub struct Binder {
    // TODO: maybe we can only lock the database, but not the whole catalog.
    /// Read guard on the catalog, acquired when the binder is created.
    catalog: CatalogReadGuard,
    /// Read guard on user information, acquired when the binder is created.
    user: UserInfoReadGuard,
    /// Name of the session's current database.
    db_name: String,
    /// ID of the session's current database.
    database_id: DatabaseId,
    /// ID of the session this binder was created for.
    session_id: SessionId,
    /// Bind context of the query level currently being bound.
    context: BindContext,
    /// Authentication context of the session; its `user_name` is used when resolving schemas.
    auth_context: Arc<AuthContext>,
    /// A stack holding contexts of outer queries when binding a subquery.
    /// It also holds all of the lateral contexts for each respective
    /// subquery.
    ///
    /// See [`Binder::bind_subquery_expr`] for details.
    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,

    /// A stack holding contexts of left-lateral `TableFactor`s.
    ///
    /// We need a separate stack as `CorrelatedInputRef` depth is
    /// determined by the upper subquery context depth, not the lateral context stack depth.
    lateral_contexts: Vec<LateralBindContext>,

    /// Next id handed out by `Binder::next_subquery_id`.
    next_subquery_id: usize,
    /// Next id handed out by `Binder::next_values_id`.
    next_values_id: usize,
    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
    /// and so on.
    next_share_id: ShareId,

    /// Session configuration shared with the session.
    session_config: Arc<RwLock<SessionConfig>>,

    /// Schema search path, copied from the session configuration when the binder is created.
    search_path: SearchPath,
    /// The type of binding statement.
    bind_for: BindFor,

    /// `ShareId`s identifying shared views.
    shared_views: HashMap<ViewId, ShareId>,

    /// The included relations while binding a query.
    included_relations: HashSet<ObjectId>,

    /// The included user-defined functions while binding a query.
    included_udfs: HashSet<FunctionId>,

    /// The included secrets while binding a query (e.g., secret refs in UDF arguments).
    included_secrets: HashSet<SecretId>,

    /// Types of the prepared-statement parameters, either specified up front or inferred
    /// during binding.
    param_types: ParameterTypes,

    /// The temporary sources that will be used during binding phase
    temporary_source_manager: TemporarySourceManager,

    /// The staging catalogs that will be used during binding phase
    staging_catalog_manager: StagingCatalogManager,

    /// Information for `secure_compare` function. It's ONLY available when binding the
    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
    secure_compare_context: Option<SecureCompareContext>,
}
151
/// Reserved name that refers to the raw webhook payload (of type `BYTEA`) in webhook
/// validation expressions.
pub const WEBHOOK_PAYLOAD_FIELD_NAME: &str = "payload";

/// Context for binding the `secure_compare` function inside the `VALIDATE` clause of a
/// webhook source.
///
/// There are hidden names reserved for webhook validation expressions:
/// - `headers`, whose type is `JSONB`
/// - `payload`, whose type is `BYTEA`
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// The identifier used to reference the raw webhook payload during validation.
    pub payload_name: String,
    /// Name of the secret (usually a token provided by the webhook source user) used to
    /// validate the calls.
    pub secret_name: Option<String>,
}
164
165/// `ParameterTypes` is used to record the types of the parameters during binding prepared stataments.
166/// It works by following the rules:
167/// 1. At the beginning, it contains the user specified parameters type.
168/// 2. When the binder encounters a parameter, it will record it as unknown(call `record_new_param`)
169///    if it didn't exist in `ParameterTypes`.
170/// 3. When the binder encounters a cast on parameter, if it's a unknown type, the cast function
171///    will record the target type as infer type for that parameter(call `record_infer_type`). If the
172///    parameter has been inferred, the cast function will act as a normal cast.
173/// 4. After bind finished:
174///    (a) parameter not in `ParameterTypes` means that the user didn't specify it and it didn't
175///    occur in the query. `export` will return error if there is a kind of
176///    parameter. This rule is compatible with PostgreSQL
177///    (b) parameter is None means that it's a unknown type. The user didn't specify it
178///    and we can't infer it in the query. We will treat it as VARCHAR type finally. This rule is
179///    compatible with PostgreSQL.
180///    (c) parameter is Some means that it's a known type.
181#[derive(Clone, Debug)]
182pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
183
184impl ParameterTypes {
185    pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
186        let map = specified_param_types
187            .into_iter()
188            .enumerate()
189            .map(|(index, data_type)| ((index + 1) as u64, data_type))
190            .collect::<HashMap<u64, Option<DataType>>>();
191        Self(Arc::new(RwLock::new(map)))
192    }
193
194    pub fn has_infer(&self, index: u64) -> bool {
195        self.0.read().get(&index).unwrap().is_some()
196    }
197
198    pub fn read_type(&self, index: u64) -> Option<DataType> {
199        self.0.read().get(&index).unwrap().clone()
200    }
201
202    pub fn record_new_param(&mut self, index: u64) {
203        self.0.write().entry(index).or_insert(None);
204    }
205
206    pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
207        assert!(
208            !self.has_infer(index),
209            "The parameter has been inferred, should not be inferred again."
210        );
211        self.0
212            .write()
213            .get_mut(&index)
214            .unwrap()
215            .replace(data_type.clone());
216    }
217
218    pub fn export(&self) -> Result<Vec<DataType>> {
219        let types = self
220            .0
221            .read()
222            .clone()
223            .into_iter()
224            .sorted_by_key(|(index, _)| *index)
225            .collect::<Vec<_>>();
226
227        // Check if all the parameters have been inferred.
228        for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
229            if *index != expect_index {
230                return Err(ErrorCode::InvalidInputSyntax(format!(
231                    "Cannot infer the type of the parameter {}.",
232                    expect_index
233                ))
234                .into());
235            }
236        }
237
238        Ok(types
239            .into_iter()
240            .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
241            .collect::<Vec<_>>())
242    }
243}
244
245impl Binder {
246    fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
247        Binder {
248            catalog: session.env().catalog_reader().read_guard(),
249            user: session.env().user_info_reader().read_guard(),
250            db_name: session.database(),
251            database_id: session.database_id(),
252            session_id: session.id(),
253            context: BindContext::new(),
254            auth_context: session.auth_context(),
255            upper_subquery_contexts: vec![],
256            lateral_contexts: vec![],
257            next_subquery_id: 0,
258            next_values_id: 0,
259            next_share_id: 0,
260            session_config: session.shared_config(),
261            search_path: session.config().search_path(),
262            bind_for,
263            shared_views: HashMap::new(),
264            included_relations: HashSet::new(),
265            included_udfs: HashSet::new(),
266            included_secrets: HashSet::new(),
267            param_types: ParameterTypes::new(vec![]),
268            temporary_source_manager: session.temporary_source_manager(),
269            staging_catalog_manager: session.staging_catalog_manager(),
270            secure_compare_context: None,
271        }
272    }
273
274    pub fn new_for_batch(session: &SessionImpl) -> Binder {
275        Self::new(session, BindFor::Batch)
276    }
277
278    pub fn new_for_stream(session: &SessionImpl) -> Binder {
279        Self::new(session, BindFor::Stream)
280    }
281
282    pub fn new_for_ddl(session: &SessionImpl) -> Binder {
283        Self::new(session, BindFor::Ddl)
284    }
285
286    pub fn new_for_system(session: &SessionImpl) -> Binder {
287        Self::new(session, BindFor::System)
288    }
289
290    /// Set the specified parameter types.
291    pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
292        self.param_types = ParameterTypes::new(param_types);
293        self
294    }
295
296    /// Set the secure compare context.
297    pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
298        self.secure_compare_context = Some(ctx);
299        self
300    }
301
302    fn is_for_stream(&self) -> bool {
303        matches!(self.bind_for, BindFor::Stream)
304    }
305
306    #[expect(dead_code)]
307    fn is_for_batch(&self) -> bool {
308        matches!(self.bind_for, BindFor::Batch)
309    }
310
311    fn is_for_ddl(&self) -> bool {
312        matches!(self.bind_for, BindFor::Ddl)
313    }
314
315    /// Bind a [`Statement`].
316    pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
317        self.bind_statement(stmt)
318    }
319
320    pub fn export_param_types(&self) -> Result<Vec<DataType>> {
321        self.param_types.export()
322    }
323
324    /// Get included relations in the query after binding. This is used for resolving relation
325    /// dependencies. Note that it only contains referenced relations discovered during binding.
326    /// After the plan is built, the referenced relations may be changed. We cannot rely on the
327    /// collection result of plan, because we still need to record the dependencies that have been
328    /// optimised away.
329    pub fn included_relations(&self) -> &HashSet<ObjectId> {
330        &self.included_relations
331    }
332
333    /// Get included user-defined functions in the query after binding.
334    pub fn included_udfs(&self) -> &HashSet<FunctionId> {
335        &self.included_udfs
336    }
337
338    /// Get included secrets in the query after binding (e.g., secret refs in UDF arguments).
339    pub fn included_secrets(&self) -> &HashSet<SecretId> {
340        &self.included_secrets
341    }
342
343    fn push_context(&mut self) {
344        let new_context = std::mem::take(&mut self.context);
345        self.context
346            .cte_to_relation
347            .clone_from(&new_context.cte_to_relation);
348        self.context.disable_security_invoker = new_context.disable_security_invoker;
349        let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
350        self.upper_subquery_contexts
351            .push((new_context, new_lateral_contexts));
352    }
353
354    fn pop_context(&mut self) -> Result<()> {
355        let (old_context, old_lateral_contexts) = self
356            .upper_subquery_contexts
357            .pop()
358            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
359        self.context = old_context;
360        self.lateral_contexts = old_lateral_contexts;
361        Ok(())
362    }
363
364    fn push_lateral_context(&mut self) {
365        let new_context = std::mem::take(&mut self.context);
366        self.context
367            .cte_to_relation
368            .clone_from(&new_context.cte_to_relation);
369        self.context.disable_security_invoker = new_context.disable_security_invoker;
370        self.lateral_contexts.push(LateralBindContext {
371            is_visible: false,
372            context: new_context,
373        });
374    }
375
376    fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
377        let mut old_context = self
378            .lateral_contexts
379            .pop()
380            .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
381            .context;
382        old_context.merge_context(self.context.clone())?;
383        self.context = old_context;
384        Ok(())
385    }
386
387    fn try_mark_lateral_as_visible(&mut self) {
388        if let Some(mut ctx) = self.lateral_contexts.pop() {
389            ctx.is_visible = true;
390            self.lateral_contexts.push(ctx);
391        }
392    }
393
394    fn try_mark_lateral_as_invisible(&mut self) {
395        if let Some(mut ctx) = self.lateral_contexts.pop() {
396            ctx.is_visible = false;
397            self.lateral_contexts.push(ctx);
398        }
399    }
400
401    /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
402    /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
403    ///
404    /// In most cases, this should include all the upper subquery contexts. However, when binding
405    /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
406    fn visible_upper_subquery_contexts_rev(
407        &self,
408    ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
409        self.upper_subquery_contexts
410            .iter()
411            .rev()
412            .take_while(|(context, _)| context.sql_udf_arguments.is_none())
413    }
414
415    fn next_subquery_id(&mut self) -> usize {
416        let id = self.next_subquery_id;
417        self.next_subquery_id += 1;
418        id
419    }
420
421    fn next_values_id(&mut self) -> usize {
422        let id = self.next_values_id;
423        self.next_values_id += 1;
424        id
425    }
426
427    fn next_share_id(&mut self) -> ShareId {
428        let id = self.next_share_id;
429        self.next_share_id += 1;
430        id
431    }
432
433    fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
434        self.catalog.first_valid_schema(
435            &self.db_name,
436            &self.search_path,
437            &self.auth_context.user_name,
438        )
439    }
440
441    fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
442        SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
443    }
444
445    pub fn set_clause(&mut self, clause: Option<Clause>) {
446        self.context.clause = clause;
447    }
448}
449
/// The column name stored in [`BindContext`] for a column without an alias.
pub const UNNAMED_COLUMN: &str = "?column?";
/// The table name stored in [`BindContext`] for a subquery without an alias.
const UNNAMED_SUBQUERY: &str = "?subquery?";
/// The table name stored in [`BindContext`] for a column group.
// NOTE(review): the constant's name suggests it is used as a prefix of that table name —
// confirm at the use sites.
const COLUMN_GROUP_PREFIX: &str = "?column_group_id?";
456
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch [`Binder`] over a mock session, with no pre-specified parameter types.
    pub fn mock_binder() -> Binder {
        let no_param_types: Vec<Option<DataType>> = Vec::new();
        mock_binder_with_param_types(no_param_types)
    }

    /// Builds a batch [`Binder`] over a mock session whose parameter `i` (1-based) is
    /// pre-typed by `param_types[i - 1]`.
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        let binder = Binder::new_for_batch(&session);
        binder.with_specified_params_types(param_types)
    }
}
472
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::test_utils::*;

    /// Parses and binds an `approx_percentile(...) WITHIN GROUP (ORDER BY ...)` query and
    /// snapshots both the AST and the bound statement.
    ///
    /// The bound snapshot pins that the `Int32` `ORDER BY` column is cast to `Float64` as the
    /// aggregate argument, and that the direct arguments `0.5` and `0.01` become `Float64`
    /// literals.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        // Parse the single statement under test.
        let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
            "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
        ).unwrap().into_iter().next().unwrap();
        // Snapshot of the parsed AST.
        let parse_expected = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        parse_expected.assert_eq(&format!("{:#?}", stmt));

        // Bind with a mock batch binder (no pre-specified parameter types).
        let mut binder = mock_binder();
        let bound = binder.bind(stmt).unwrap();

        // Snapshot of the bound statement.
        let expected = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];

        expected.assert_eq(&format!("{:#?}", bound));
    }
}