sqlsmith/
main.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(register_tool)]
16#![register_tool(rw)]
17#![allow(rw::format_error)] // test code
18
19use core::panic;
20use std::time::Duration;
21
22use clap::Parser as ClapParser;
23use risingwave_sqlsmith::config::Configuration;
24use risingwave_sqlsmith::print_function_table;
25use risingwave_sqlsmith::test_runners::{generate, run, run_differential_testing};
26use tokio_postgres::NoTls;
27
// Top-level command-line interface of the SQLsmith binary.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///` doc
// comments into help/about text, so adding them would change the CLI's `--help` output.
// `about`, `version` and `author` are given without values, so clap fills them in itself
// (from package metadata, per clap's derive docs).
#[derive(ClapParser, Debug, Clone)]
#[clap(about, version, author)]
struct Opt {
    // The selected subcommand; `main` matches on it to decide what to run.
    #[clap(subcommand)]
    command: Commands,
}
34
35#[derive(clap::Args, Clone, Debug)]
36struct TestOptions {
37    /// The database server host.
38    #[clap(long, default_value = "localhost")]
39    host: String,
40
41    /// The database server port.
42    #[clap(short, long, default_value = "4566")]
43    port: u16,
44
45    /// The database name to connect.
46    #[clap(short, long, default_value = "dev")]
47    db: String,
48
49    /// The database username.
50    #[clap(short, long, default_value = "root")]
51    user: String,
52
53    /// The database password.
54    #[clap(short = 'w', long, default_value = "")]
55    pass: String,
56
57    /// Path to the testing data files.
58    #[clap(short, long)]
59    testdata: String,
60
61    /// The number of test cases to generate.
62    #[clap(long, default_value = "100")]
63    count: usize,
64
65    /// Output directory - only applicable if we are generating
66    /// query while testing.
67    #[clap(long)]
68    generate: Option<String>,
69
70    /// Whether to run differential testing mode.
71    #[clap(long)]
72    differential_testing: bool,
73
74    /// Path to weight configuration file.
75    #[clap(long, default_value = "src/tests/sqlsmith/config.yml")]
76    weight_config_path: String,
77
78    /// Features to enable (e.g. eowc).
79    #[clap(long = "enable", value_delimiter = ',', action = clap::ArgAction::Append)]
80    enabled_features: Vec<String>,
81}
82
// Subcommands of the SQLsmith binary.
//
// The `///` comments on the variants double as clap help text, and the variant order is the
// order in which subcommands are listed in `--help`; maintainer notes therefore use `//`.
#[derive(clap::Subcommand, Clone, Debug)]
enum Commands {
    // Handled entirely in `main`: prints the table and exits without connecting to a database.
    /// Prints the currently supported function/operator table.
    #[clap(name = "print-function-table")]
    PrintFunctionTable,

    // Connects to the database and runs one of the testing modes; see `TestOptions`.
    /// Run testing.
    Test(TestOptions),
}
92
93#[tokio::main(flavor = "multi_thread", worker_threads = 5)]
94async fn main() {
95    tracing_subscriber::fmt::init();
96
97    let opt = Opt::parse();
98    let command = opt.command;
99    let opt = match command {
100        Commands::PrintFunctionTable => {
101            println!("{}", print_function_table());
102            return;
103        }
104        Commands::Test(test_opts) => test_opts,
105    };
106    let (client, connection) = tokio_postgres::Config::new()
107        .host(&opt.host)
108        .port(opt.port)
109        .dbname(&opt.db)
110        .user(&opt.user)
111        .password(&opt.pass)
112        .connect_timeout(Duration::from_secs(5))
113        .connect(NoTls)
114        .await
115        .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e));
116    tokio::spawn(async move {
117        if let Err(e) = connection.await {
118            tracing::error!("Postgres connection error: {:?}", e);
119        }
120    });
121    let mut config = Configuration::new(&opt.weight_config_path);
122    config.enable_features_from_args(&opt.enabled_features);
123    if opt.differential_testing {
124        return run_differential_testing(&client, &opt.testdata, opt.count, &config, None)
125            .await
126            .unwrap();
127    }
128    if let Some(outdir) = opt.generate {
129        generate(&client, &opt.testdata, opt.count, &outdir, &config, None).await;
130    } else {
131        run(&client, &opt.testdata, opt.count, &config, None).await;
132    }
133}