risingwave_frontend/binder/mod.rs
1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use parking_lot::RwLock;
20use risingwave_common::catalog::FunctionId;
21use risingwave_common::session_config::{SearchPath, SessionConfig};
22use risingwave_common::types::DataType;
23use risingwave_common::util::iter_util::ZipEqDebug;
24use risingwave_sqlparser::ast::Statement;
25
26use crate::error::Result;
27
28mod bind_context;
29mod bind_param;
30mod create;
31mod create_view;
32mod declare_cursor;
33mod delete;
34mod expr;
35pub mod fetch_cursor;
36mod for_system;
37mod gap_fill_binder;
38mod insert;
39mod query;
40mod relation;
41mod select;
42mod set_expr;
43mod statement;
44mod struct_field;
45mod update;
46mod values;
47
48pub use bind_context::{BindContext, Clause, LateralBindContext};
49pub use create_view::BoundCreateView;
50pub use delete::BoundDelete;
51pub use expr::bind_data_type;
52pub use gap_fill_binder::BoundFillStrategy;
53pub use insert::BoundInsert;
54use pgwire::pg_server::{Session, SessionId};
55pub use query::BoundQuery;
56pub use relation::{
57 BoundBaseTable, BoundGapFill, BoundJoin, BoundShare, BoundShareInput, BoundSource,
58 BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation,
59 ResolveQualifiedNameError, WindowTableFunctionKind,
60};
61// Re-export common types
62pub use risingwave_common::gap_fill::FillStrategy;
63use risingwave_common::id::ObjectId;
64pub use select::{BoundDistinct, BoundSelect};
65pub use set_expr::*;
66pub use statement::BoundStatement;
67pub use update::{BoundUpdate, UpdateProject};
68pub use values::BoundValues;
69
70use crate::catalog::catalog_service::CatalogReadGuard;
71use crate::catalog::root_catalog::SchemaPath;
72use crate::catalog::schema_catalog::SchemaCatalog;
73use crate::catalog::{CatalogResult, DatabaseId, ViewId};
74use crate::error::ErrorCode;
75use crate::session::{AuthContext, SessionImpl, StagingCatalogManager, TemporarySourceManager};
76use crate::user::user_service::UserInfoReadGuard;
77
/// Identifier of a shared relation (e.g. a CTE, a source or a view) within one binding.
///
/// Allocated sequentially, starting at 0, by `Binder::next_share_id`.
pub type ShareId = usize;
79
/// The type of binding statement.
///
/// Stored in `Binder::bind_for` and queried through `Binder::is_for_stream`,
/// `Binder::is_for_batch` and `Binder::is_for_ddl`.
enum BindFor {
    /// Binding MV/SINK
    Stream,
    /// Binding a batch query
    Batch,
    /// Binding a DDL (e.g. CREATE TABLE/SOURCE)
    Ddl,
    /// Binding a system query (e.g. SHOW)
    System,
}
91
/// `Binder` binds the identifiers in AST to columns in relations.
///
/// Construct one with [`Binder::new_for_batch`], [`Binder::new_for_stream`],
/// [`Binder::new_for_ddl`] or [`Binder::new_for_system`], then call [`Binder::bind`].
pub struct Binder {
    // TODO: maybe we can only lock the database, but not the whole catalog.
    /// Catalog read guard taken from the session's catalog reader when the binder is created.
    catalog: CatalogReadGuard,
    /// User-info read guard taken from the session's user info reader when the binder is
    /// created.
    user: UserInfoReadGuard,
    /// Name of the session's current database.
    db_name: String,
    /// Id of the session's current database.
    database_id: DatabaseId,
    /// Id of the session this binder was created for.
    session_id: SessionId,
    /// Bind context of the scope currently being bound. Nested scopes swap it out (see
    /// `push_context` / `push_lateral_context`).
    context: BindContext,
    /// Authentication context of the session; its `user_name` is used for schema resolution.
    auth_context: Arc<AuthContext>,
    /// A stack holding contexts of outer queries when binding a subquery.
    /// It also holds all of the lateral contexts for each respective
    /// subquery.
    ///
    /// See [`Binder::bind_subquery_expr`] for details.
    upper_subquery_contexts: Vec<(BindContext, Vec<LateralBindContext>)>,

    /// A stack holding contexts of left-lateral `TableFactor`s.
    ///
    /// We need a separate stack as `CorrelatedInputRef` depth is
    /// determined by the upper subquery context depth, not the lateral context stack depth.
    lateral_contexts: Vec<LateralBindContext>,

    /// Next id handed out by `next_subquery_id` (starts at 0).
    next_subquery_id: usize,
    /// Next id handed out by `next_values_id` (starts at 0).
    next_values_id: usize,
    /// The `ShareId` is used to identify the share relation which could be a CTE, a source, a view
    /// and so on.
    next_share_id: ShareId,

    /// Shared handle to the session configuration.
    session_config: Arc<RwLock<SessionConfig>>,

    /// Schema search path captured from the session config when the binder is created.
    search_path: SearchPath,
    /// The type of binding statement.
    bind_for: BindFor,

    /// `ShareId`s identifying shared views.
    shared_views: HashMap<ViewId, ShareId>,

    /// The included relations while binding a query.
    included_relations: HashSet<ObjectId>,

    /// The included user-defined functions while binding a query.
    included_udfs: HashSet<FunctionId>,

    /// Types of the prepared-statement parameters (`$1`, `$2`, ...), as specified by the user
    /// or inferred during binding.
    param_types: ParameterTypes,

    /// The temporary sources that will be used during binding phase
    temporary_source_manager: TemporarySourceManager,

    /// The staging catalogs that will be used during binding phase
    staging_catalog_manager: StagingCatalogManager,

    /// Information for `secure_compare` function. It's ONLY available when binding the
    /// `VALIDATE` clause of Webhook source i.e. `VALIDATE SECRET ... AS SECURE_COMPARE(...)`.
    secure_compare_context: Option<SecureCompareContext>,
}
148
/// Names available to the `secure_compare` function while binding the `VALIDATE` clause of a
/// webhook source.
///
/// Besides the names stored here, there is one more hidden name, `HEADERS`, which is a
/// reserved identifier for HTTP headers. Its type is `JSONB`.
#[derive(Default, Clone, Debug)]
pub struct SecureCompareContext {
    /// The column name to store the whole payload in `JSONB`, but during validation it will be used as `bytea`
    pub column_name: String,
    /// The secret (usually a token provided by the webhook source user) used to validate incoming calls
    pub secret_name: Option<String>,
}
157
/// `ParameterTypes` records the types of the parameters (`$1`, `$2`, ...) while binding a
/// prepared statement. It follows these rules:
/// 1. Initially, it contains the parameter types specified by the user.
/// 2. When the binder encounters a parameter that is not yet recorded, it records it with an
///    unknown type (`record_new_param`).
/// 3. When the binder encounters a cast on a parameter whose type is still unknown, the cast
///    records its target type as the inferred type of that parameter (`record_infer_type`).
///    If the parameter's type is already known, the cast acts as a normal cast.
/// 4. After binding has finished:
///    (a) A parameter index missing from the table means the user did not specify it and it
///        did not occur in the query. `export` returns an error if such a gap exists. This
///        rule is compatible with PostgreSQL.
///    (b) A parameter recorded as `None` has an unknown type: the user did not specify it and
///        it could not be inferred from the query. It is finally treated as `VARCHAR`. This
///        rule is compatible with PostgreSQL.
///    (c) A parameter recorded as `Some` has a known type.
///
/// The table lives behind `Arc<RwLock<..>>`, so clones of a `ParameterTypes` share and observe
/// the same entries.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
176
177impl ParameterTypes {
178 pub fn new(specified_param_types: Vec<Option<DataType>>) -> Self {
179 let map = specified_param_types
180 .into_iter()
181 .enumerate()
182 .map(|(index, data_type)| ((index + 1) as u64, data_type))
183 .collect::<HashMap<u64, Option<DataType>>>();
184 Self(Arc::new(RwLock::new(map)))
185 }
186
187 pub fn has_infer(&self, index: u64) -> bool {
188 self.0.read().get(&index).unwrap().is_some()
189 }
190
191 pub fn read_type(&self, index: u64) -> Option<DataType> {
192 self.0.read().get(&index).unwrap().clone()
193 }
194
195 pub fn record_new_param(&mut self, index: u64) {
196 self.0.write().entry(index).or_insert(None);
197 }
198
199 pub fn record_infer_type(&mut self, index: u64, data_type: &DataType) {
200 assert!(
201 !self.has_infer(index),
202 "The parameter has been inferred, should not be inferred again."
203 );
204 self.0
205 .write()
206 .get_mut(&index)
207 .unwrap()
208 .replace(data_type.clone());
209 }
210
211 pub fn export(&self) -> Result<Vec<DataType>> {
212 let types = self
213 .0
214 .read()
215 .clone()
216 .into_iter()
217 .sorted_by_key(|(index, _)| *index)
218 .collect::<Vec<_>>();
219
220 // Check if all the parameters have been inferred.
221 for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
222 if *index != expect_index {
223 return Err(ErrorCode::InvalidInputSyntax(format!(
224 "Cannot infer the type of the parameter {}.",
225 expect_index
226 ))
227 .into());
228 }
229 }
230
231 Ok(types
232 .into_iter()
233 .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
234 .collect::<Vec<_>>())
235 }
236}
237
238impl Binder {
239 fn new(session: &SessionImpl, bind_for: BindFor) -> Binder {
240 Binder {
241 catalog: session.env().catalog_reader().read_guard(),
242 user: session.env().user_info_reader().read_guard(),
243 db_name: session.database(),
244 database_id: session.database_id(),
245 session_id: session.id(),
246 context: BindContext::new(),
247 auth_context: session.auth_context(),
248 upper_subquery_contexts: vec![],
249 lateral_contexts: vec![],
250 next_subquery_id: 0,
251 next_values_id: 0,
252 next_share_id: 0,
253 session_config: session.shared_config(),
254 search_path: session.config().search_path(),
255 bind_for,
256 shared_views: HashMap::new(),
257 included_relations: HashSet::new(),
258 included_udfs: HashSet::new(),
259 param_types: ParameterTypes::new(vec![]),
260 temporary_source_manager: session.temporary_source_manager(),
261 staging_catalog_manager: session.staging_catalog_manager(),
262 secure_compare_context: None,
263 }
264 }
265
266 pub fn new_for_batch(session: &SessionImpl) -> Binder {
267 Self::new(session, BindFor::Batch)
268 }
269
270 pub fn new_for_stream(session: &SessionImpl) -> Binder {
271 Self::new(session, BindFor::Stream)
272 }
273
274 pub fn new_for_ddl(session: &SessionImpl) -> Binder {
275 Self::new(session, BindFor::Ddl)
276 }
277
278 pub fn new_for_system(session: &SessionImpl) -> Binder {
279 Self::new(session, BindFor::System)
280 }
281
282 /// Set the specified parameter types.
283 pub fn with_specified_params_types(mut self, param_types: Vec<Option<DataType>>) -> Self {
284 self.param_types = ParameterTypes::new(param_types);
285 self
286 }
287
288 /// Set the secure compare context.
289 pub fn with_secure_compare(mut self, ctx: SecureCompareContext) -> Self {
290 self.secure_compare_context = Some(ctx);
291 self
292 }
293
294 fn is_for_stream(&self) -> bool {
295 matches!(self.bind_for, BindFor::Stream)
296 }
297
298 #[allow(dead_code)]
299 fn is_for_batch(&self) -> bool {
300 matches!(self.bind_for, BindFor::Batch)
301 }
302
303 fn is_for_ddl(&self) -> bool {
304 matches!(self.bind_for, BindFor::Ddl)
305 }
306
307 /// Bind a [`Statement`].
308 pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
309 self.bind_statement(stmt)
310 }
311
312 pub fn export_param_types(&self) -> Result<Vec<DataType>> {
313 self.param_types.export()
314 }
315
316 /// Get included relations in the query after binding. This is used for resolving relation
317 /// dependencies. Note that it only contains referenced relations discovered during binding.
318 /// After the plan is built, the referenced relations may be changed. We cannot rely on the
319 /// collection result of plan, because we still need to record the dependencies that have been
320 /// optimised away.
321 pub fn included_relations(&self) -> &HashSet<ObjectId> {
322 &self.included_relations
323 }
324
325 /// Get included user-defined functions in the query after binding.
326 pub fn included_udfs(&self) -> &HashSet<FunctionId> {
327 &self.included_udfs
328 }
329
330 fn push_context(&mut self) {
331 let new_context = std::mem::take(&mut self.context);
332 self.context
333 .cte_to_relation
334 .clone_from(&new_context.cte_to_relation);
335 self.context.disable_security_invoker = new_context.disable_security_invoker;
336 let new_lateral_contexts = std::mem::take(&mut self.lateral_contexts);
337 self.upper_subquery_contexts
338 .push((new_context, new_lateral_contexts));
339 }
340
341 fn pop_context(&mut self) -> Result<()> {
342 let (old_context, old_lateral_contexts) = self
343 .upper_subquery_contexts
344 .pop()
345 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?;
346 self.context = old_context;
347 self.lateral_contexts = old_lateral_contexts;
348 Ok(())
349 }
350
351 fn push_lateral_context(&mut self) {
352 let new_context = std::mem::take(&mut self.context);
353 self.context
354 .cte_to_relation
355 .clone_from(&new_context.cte_to_relation);
356 self.context.disable_security_invoker = new_context.disable_security_invoker;
357 self.lateral_contexts.push(LateralBindContext {
358 is_visible: false,
359 context: new_context,
360 });
361 }
362
363 fn pop_and_merge_lateral_context(&mut self) -> Result<()> {
364 let mut old_context = self
365 .lateral_contexts
366 .pop()
367 .ok_or_else(|| ErrorCode::InternalError("Popping non-existent context".to_owned()))?
368 .context;
369 old_context.merge_context(self.context.clone())?;
370 self.context = old_context;
371 Ok(())
372 }
373
374 fn try_mark_lateral_as_visible(&mut self) {
375 if let Some(mut ctx) = self.lateral_contexts.pop() {
376 ctx.is_visible = true;
377 self.lateral_contexts.push(ctx);
378 }
379 }
380
381 fn try_mark_lateral_as_invisible(&mut self) {
382 if let Some(mut ctx) = self.lateral_contexts.pop() {
383 ctx.is_visible = false;
384 self.lateral_contexts.push(ctx);
385 }
386 }
387
388 /// Returns a reverse iterator over the upper subquery contexts that are visible to the current
389 /// context. Not to be confused with `is_visible` in [`LateralBindContext`].
390 ///
391 /// In most cases, this should include all the upper subquery contexts. However, when binding
392 /// SQL UDFs, we should avoid resolving the context outside the UDF for hygiene.
393 fn visible_upper_subquery_contexts_rev(
394 &self,
395 ) -> impl Iterator<Item = &(BindContext, Vec<LateralBindContext>)> + '_ {
396 self.upper_subquery_contexts
397 .iter()
398 .rev()
399 .take_while(|(context, _)| context.sql_udf_arguments.is_none())
400 }
401
402 fn next_subquery_id(&mut self) -> usize {
403 let id = self.next_subquery_id;
404 self.next_subquery_id += 1;
405 id
406 }
407
408 fn next_values_id(&mut self) -> usize {
409 let id = self.next_values_id;
410 self.next_values_id += 1;
411 id
412 }
413
414 fn next_share_id(&mut self) -> ShareId {
415 let id = self.next_share_id;
416 self.next_share_id += 1;
417 id
418 }
419
420 fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
421 self.catalog.first_valid_schema(
422 &self.db_name,
423 &self.search_path,
424 &self.auth_context.user_name,
425 )
426 }
427
428 fn bind_schema_path<'a>(&'a self, schema_name: Option<&'a str>) -> SchemaPath<'a> {
429 SchemaPath::new(schema_name, &self.search_path, &self.auth_context.user_name)
430 }
431
432 pub fn set_clause(&mut self, clause: Option<Clause>) {
433 self.context.clause = clause;
434 }
435}
436
/// The column name stored in [`BindContext`] for a column without an alias.
// NOTE(review): presumably chosen to match PostgreSQL's `?column?` placeholder — confirm.
pub const UNNAMED_COLUMN: &str = "?column?";
/// The table name stored in [`BindContext`] for a subquery without an alias.
const UNNAMED_SUBQUERY: &str = "?subquery?";
/// The table name stored in [`BindContext`] for a column group.
// NOTE(review): the `_PREFIX` suffix suggests a group id is appended at use sites — confirm.
const COLUMN_GROUP_PREFIX: &str = "?column_group_id?";
443
#[cfg(test)]
pub mod test_utils {
    use risingwave_common::types::DataType;

    use super::Binder;
    use crate::session::SessionImpl;

    /// Builds a batch [`Binder`] on a mock session without any pre-specified parameter types.
    pub fn mock_binder() -> Binder {
        let no_param_types: Vec<Option<DataType>> = Vec::new();
        mock_binder_with_param_types(no_param_types)
    }

    /// Builds a batch [`Binder`] on a mock session, seeded with `param_types`
    /// (the `i`-th entry is the type of `$i+1`; `None` leaves it to be inferred).
    pub fn mock_binder_with_param_types(param_types: Vec<Option<DataType>>) -> Binder {
        let session = SessionImpl::mock();
        Binder::new_for_batch(&session).with_specified_params_types(param_types)
    }
}
459
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::test_utils::*;

    /// Snapshot test for binding an ordered-set aggregate
    /// (`approx_percentile(...) WITHIN GROUP (ORDER BY ...)`) over the table function
    /// `generate_series`.
    #[tokio::test]
    async fn test_bind_approx_percentile() {
        let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
            "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
        ).unwrap().into_iter().next().unwrap();
        // Snapshot of the parsed AST (pretty `Debug` output of the first statement).
        let parse_expected = expect![[r#"
            Query(
                Query {
                    with: None,
                    body: Select(
                        Select {
                            distinct: All,
                            projection: [
                                UnnamedExpr(
                                    Function(
                                        Function {
                                            scalar_as_agg: false,
                                            name: ObjectName(
                                                [
                                                    Ident {
                                                        value: "approx_percentile",
                                                        quote_style: None,
                                                    },
                                                ],
                                            ),
                                            arg_list: FunctionArgList {
                                                distinct: false,
                                                args: [
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.5",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                    Unnamed(
                                                        Expr(
                                                            Value(
                                                                Number(
                                                                    "0.01",
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ],
                                                variadic: false,
                                                order_by: [],
                                                ignore_nulls: false,
                                            },
                                            within_group: Some(
                                                OrderByExpr {
                                                    expr: Identifier(
                                                        Ident {
                                                            value: "generate_series",
                                                            quote_style: None,
                                                        },
                                                    ),
                                                    asc: None,
                                                    nulls_first: None,
                                                },
                                            ),
                                            filter: None,
                                            over: None,
                                        },
                                    ),
                                ),
                            ],
                            from: [
                                TableWithJoins {
                                    relation: TableFunction {
                                        name: ObjectName(
                                            [
                                                Ident {
                                                    value: "generate_series",
                                                    quote_style: None,
                                                },
                                            ],
                                        ),
                                        alias: None,
                                        args: [
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "1",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                            Unnamed(
                                                Expr(
                                                    Value(
                                                        Number(
                                                            "100",
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ],
                                        with_ordinality: false,
                                    },
                                    joins: [],
                                },
                            ],
                            lateral_views: [],
                            selection: None,
                            group_by: [],
                            having: None,
                            window: [],
                        },
                    ),
                    order_by: [],
                    limit: None,
                    offset: None,
                    fetch: None,
                },
            )"#]];
        parse_expected.assert_eq(&format!("{:#?}", stmt));

        // Bind with a batch binder on a mock session (see `test_utils::mock_binder`).
        let mut binder = mock_binder();
        let bound = binder.bind(stmt).unwrap();

        // Snapshot of the bound statement (pretty `Debug` output).
        let expected = expect![[r#"
            Query(
                BoundQuery {
                    body: Select(
                        BoundSelect {
                            distinct: All,
                            select_items: [
                                AggCall(
                                    AggCall {
                                        agg_type: Builtin(
                                            ApproxPercentile,
                                        ),
                                        return_type: Float64,
                                        args: [
                                            FunctionCall(
                                                FunctionCall {
                                                    func_type: Cast,
                                                    return_type: Float64,
                                                    inputs: [
                                                        InputRef(
                                                            InputRef {
                                                                index: 0,
                                                                data_type: Int32,
                                                            },
                                                        ),
                                                    ],
                                                },
                                            ),
                                        ],
                                        filter: Condition {
                                            conjunctions: [],
                                        },
                                        distinct: false,
                                        order_by: OrderBy {
                                            sort_exprs: [
                                                OrderByExpr {
                                                    expr: InputRef(
                                                        InputRef {
                                                            index: 0,
                                                            data_type: Int32,
                                                        },
                                                    ),
                                                    order_type: OrderType {
                                                        direction: Ascending,
                                                        nulls_are: Largest,
                                                    },
                                                },
                                            ],
                                        },
                                        direct_args: [
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.5,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                            Literal {
                                                data: Some(
                                                    Float64(
                                                        0.01,
                                                    ),
                                                ),
                                                data_type: Some(
                                                    Float64,
                                                ),
                                            },
                                        ],
                                    },
                                ),
                            ],
                            aliases: [
                                Some(
                                    "approx_percentile",
                                ),
                            ],
                            from: Some(
                                TableFunction {
                                    expr: TableFunction(
                                        FunctionCall {
                                            function_type: GenerateSeries,
                                            return_type: Int32,
                                            args: [
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                1,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                                Literal(
                                                    Literal {
                                                        data: Some(
                                                            Int32(
                                                                100,
                                                            ),
                                                        ),
                                                        data_type: Some(
                                                            Int32,
                                                        ),
                                                    },
                                                ),
                                            ],
                                        },
                                    ),
                                    with_ordinality: false,
                                },
                            ),
                            where_clause: None,
                            group_by: GroupKey(
                                [],
                            ),
                            having: None,
                            window: {},
                            schema: Schema {
                                fields: [
                                    approx_percentile:Float64,
                                ],
                            },
                        },
                    ),
                    order: [],
                    limit: None,
                    offset: None,
                    with_ties: false,
                    extra_order_exprs: [],
                },
            )"#]];

        expected.assert_eq(&format!("{:#?}", bound));
    }
}