1#![feature(register_tool)]
16#![register_tool(rw)]
17#![allow(rw::format_error)] use core::panic;
20use std::time::Duration;
21
22use clap::Parser as ClapParser;
23use risingwave_sqlsmith::config::Configuration;
24use risingwave_sqlsmith::print_function_table;
25use risingwave_sqlsmith::test_runners::{generate, run, run_differential_testing};
26use tokio_postgres::NoTls;
27
// Top-level command-line options. All work is selected by the subcommand in
// `command`, which `main` dispatches on.
//
// Plain `//` comments are used in this file's clap types on purpose: clap
// derive turns `///` doc comments into `--help` text, so adding them would
// change the CLI's output.
#[derive(ClapParser, Debug, Clone)]
#[clap(about, version, author)]
struct Opt {
    // The subcommand to run (`print-function-table` or `test`).
    #[clap(subcommand)]
    command: Commands,
}
34
// Options for the `test` subcommand: how to reach the database and which
// testing mode to run.
//
// Plain `//` comments are used (not `///`) because clap derive would turn doc
// comments into `--help` text and change the CLI's output.
#[derive(clap::Args, Clone, Debug)]
struct TestOptions {
    // Database host (`--host`), passed to `tokio_postgres::Config::host`.
    #[clap(long, default_value = "localhost")]
    host: String,

    // Database port (`-p`/`--port`). The default 4566 is used when omitted.
    #[clap(short, long, default_value = "4566")]
    port: u16,

    // Database name (`-d`/`--db`).
    #[clap(short, long, default_value = "dev")]
    db: String,

    // Database user (`-u`/`--user`).
    #[clap(short, long, default_value = "root")]
    user: String,

    // Database password (`-w`/`--pass`). The short flag is set explicitly
    // because the field's first letter `p` is already taken by `--port`.
    #[clap(short = 'w', long, default_value = "")]
    pass: String,

    // Required (`-t`/`--testdata`). Forwarded unchanged to `run`, `generate`
    // and `run_differential_testing`. Presumably the location of the
    // test-data files; TODO confirm against those runners.
    #[clap(short, long)]
    testdata: String,

    // Forwarded as the `count` argument to every test runner (`--count`).
    #[clap(long, default_value = "100")]
    count: usize,

    // When set (`--generate <DIR>`), `main` calls `generate` with this value
    // as the output directory instead of `run`. Ignored if
    // `--differential-testing` is also given.
    #[clap(long)]
    generate: Option<String>,

    // Flag (`--differential-testing`). When set, `main` runs
    // `run_differential_testing` and takes precedence over `--generate`.
    #[clap(long)]
    differential_testing: bool,

    // Path handed to `Configuration::new`. The default is a relative path,
    // so it is resolved against the current working directory (presumably
    // the repository root; verify how the binary is launched).
    #[clap(long, default_value = "src/tests/sqlsmith/config.yml")]
    weight_config_path: String,

    // Feature names given via `--enable`. Values may be comma-separated and
    // the flag may be repeated; all values are appended to this list and
    // passed to `Configuration::enable_features_from_args`.
    #[clap(long = "enable", value_delimiter = ',', action = clap::ArgAction::Append)]
    enabled_features: Vec<String>,
}
82
// Subcommands accepted by the binary.
//
// Plain `//` comments are used (not `///`) because clap derive would turn doc
// comments into `--help` text and change the CLI's output.
#[derive(clap::Subcommand, Clone, Debug)]
enum Commands {
    // Print the output of `print_function_table` and exit. `main` returns
    // before any database connection is made.
    #[clap(name = "print-function-table")]
    PrintFunctionTable,

    // Connect to the database and run one of the test modes described by
    // `TestOptions`.
    Test(TestOptions),
}
92
93#[tokio::main(flavor = "multi_thread", worker_threads = 5)]
94async fn main() {
95    tracing_subscriber::fmt::init();
96
97    let opt = Opt::parse();
98    let command = opt.command;
99    let opt = match command {
100        Commands::PrintFunctionTable => {
101            println!("{}", print_function_table());
102            return;
103        }
104        Commands::Test(test_opts) => test_opts,
105    };
106    let (client, connection) = tokio_postgres::Config::new()
107        .host(&opt.host)
108        .port(opt.port)
109        .dbname(&opt.db)
110        .user(&opt.user)
111        .password(&opt.pass)
112        .connect_timeout(Duration::from_secs(5))
113        .connect(NoTls)
114        .await
115        .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e));
116    tokio::spawn(async move {
117        if let Err(e) = connection.await {
118            tracing::error!("Postgres connection error: {:?}", e);
119        }
120    });
121    let mut config = Configuration::new(&opt.weight_config_path);
122    config.enable_features_from_args(&opt.enabled_features);
123    if opt.differential_testing {
124        return run_differential_testing(&client, &opt.testdata, opt.count, &config, None)
125            .await
126            .unwrap();
127    }
128    if let Some(outdir) = opt.generate {
129        generate(&client, &opt.testdata, opt.count, &outdir, &config, None).await;
130    } else {
131        run(&client, &opt.testdata, opt.count, &config, None).await;
132    }
133}