risingwave_sqlparser/
test_utils.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13/// This module contains internal utilities used for testing the library.
14/// While technically public, the library's users are not supposed to rely
15/// on this module, as it will change without notice.
16// Integration tests (i.e. everything under `tests/`) import this
17// via `tests/test_utils/executor`.
18
19#[cfg(not(feature = "std"))]
20use alloc::{
21    boxed::Box,
22    string::{String, ToString},
23    vec,
24    vec::Vec,
25};
26use core::fmt::Debug;
27
28use crate::ast::*;
29use crate::parser::{Parser, ParserError};
30use crate::tokenizer::Tokenizer;
31
32pub fn run_parser_method<F, T: Debug + PartialEq>(sql: &str, f: F) -> T
33where
34    F: Fn(&mut Parser<'_>) -> T,
35{
36    let mut tokenizer = Tokenizer::new(sql);
37    let tokens = tokenizer.tokenize_with_location().unwrap();
38    f(&mut Parser(&tokens))
39}
40
41pub fn parse_sql_statements(sql: &str) -> Result<Vec<Statement>, ParserError> {
42    Parser::parse_sql(sql)
43    // To fail the `ensure_multiple_dialects_are_tested` test:
44    // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
45}
46
47/// Ensures that `sql` parses as a single statement and returns it.
48///
49/// If non-empty `canonical` SQL representation is provided,
50/// additionally asserts that parsing `sql` results in the same parse
51/// tree as parsing `canonical`, and that serializing it back to string
52/// results in the `canonical` representation.
53#[track_caller]
54pub fn one_statement_parses_to(sql: &str, canonical: &str) -> Statement {
55    let mut statements = parse_sql_statements(sql).unwrap();
56    assert_eq!(statements.len(), 1);
57
58    if !canonical.is_empty() && sql != canonical {
59        assert_eq!(parse_sql_statements(canonical).unwrap(), statements);
60    }
61
62    let only_statement = statements.pop().unwrap();
63    if !canonical.is_empty() {
64        assert_eq!(canonical, only_statement.to_string())
65    }
66    only_statement
67}
68
69/// Ensures that `sql` parses as a single [Statement], and is not modified
70/// after a serialization round-trip.
71#[track_caller]
72pub fn verified_stmt(query: &str) -> Statement {
73    one_statement_parses_to(query, query)
74}
75
76/// Ensures that `sql` parses as a single [Query], and is not modified
77/// after a serialization round-trip.
78#[track_caller]
79pub fn verified_query(sql: &str) -> Query {
80    match verified_stmt(sql) {
81        Statement::Query(query) => *query,
82        _ => panic!("Expected Query"),
83    }
84}
85
86#[track_caller]
87pub fn query(sql: &str, canonical: &str) -> Query {
88    match one_statement_parses_to(sql, canonical) {
89        Statement::Query(query) => *query,
90        _ => panic!("Expected Query"),
91    }
92}
93
94/// Ensures that `sql` parses as a single [Select], and is not modified
95/// after a serialization round-trip.
96#[track_caller]
97pub fn verified_only_select(query: &str) -> Select {
98    match verified_query(query).body {
99        SetExpr::Select(s) => *s,
100        _ => panic!("Expected SetExpr::Select"),
101    }
102}
103
104/// Ensures that `sql` parses as an expression, and is not modified
105/// after a serialization round-trip.
106pub fn verified_expr(sql: &str) -> Expr {
107    let ast = run_parser_method(sql, |parser| parser.parse_expr()).unwrap();
108    assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
109    ast
110}
111
/// Returns the sole item yielded by `v`.
///
/// # Panics
///
/// Panics if `v` yields no items or more than one item.
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut items = v.into_iter();
    // Pull two items up front: the first is the candidate result, the
    // second must be absent for the collection to hold exactly one item.
    let first = items.next();
    let second = items.next();
    if let (Some(single), None) = (first, second) {
        single
    } else {
        panic!("only called on collection without exactly one item")
    }
}
121
122pub fn expr_from_projection(item: &SelectItem) -> &Expr {
123    match item {
124        SelectItem::UnnamedExpr(expr) => expr,
125        _ => panic!("Expected UnnamedExpr"),
126    }
127}
128
129pub fn number(n: &'static str) -> Value {
130    Value::Number(n.parse().unwrap())
131}
132
133pub fn table_alias(name: impl Into<String>) -> Option<TableAlias> {
134    Some(TableAlias {
135        name: Ident::new_unchecked(name),
136        columns: vec![],
137    })
138}
139
140pub fn table(name: impl Into<String>) -> TableFactor {
141    TableFactor::Table {
142        name: ObjectName(vec![Ident::new_unchecked(name.into())]),
143        as_of: None,
144        alias: None,
145    }
146}
147
148pub fn join(relation: TableFactor) -> Join {
149    Join {
150        relation,
151        join_operator: JoinOperator::Inner(JoinConstraint::Natural),
152    }
153}