// sqlsmith/main.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#![feature(register_tool)]
16#![register_tool(rw)]
17#![allow(rw::format_error)] // test code
18
19use core::panic;
20use std::time::Duration;
21
22use clap::Parser as ClapParser;
23use risingwave_sqlsmith::print_function_table;
24use risingwave_sqlsmith::test_runners::{generate, run, run_differential_testing};
25use tokio_postgres::NoTls;
26
27#[derive(ClapParser, Debug, Clone)]
28#[clap(about, version, author)]
29struct Opt {
30    #[clap(subcommand)]
31    command: Commands,
32}
33
34#[derive(clap::Args, Clone, Debug)]
35struct TestOptions {
36    /// The database server host.
37    #[clap(long, default_value = "localhost")]
38    host: String,
39
40    /// The database server port.
41    #[clap(short, long, default_value = "4566")]
42    port: u16,
43
44    /// The database name to connect.
45    #[clap(short, long, default_value = "dev")]
46    db: String,
47
48    /// The database username.
49    #[clap(short, long, default_value = "root")]
50    user: String,
51
52    /// The database password.
53    #[clap(short = 'w', long, default_value = "")]
54    pass: String,
55
56    /// Path to the testing data files.
57    #[clap(short, long)]
58    testdata: String,
59
60    /// The number of test cases to generate.
61    #[clap(long, default_value = "100")]
62    count: usize,
63
64    /// Output directory - only applicable if we are generating
65    /// query while testing.
66    #[clap(long)]
67    generate: Option<String>,
68
69    /// Whether to run differential testing mode.
70    #[clap(long)]
71    differential_testing: bool,
72}
73
74#[derive(clap::Subcommand, Clone, Debug)]
75enum Commands {
76    /// Prints the currently supported function/operator table.
77    #[clap(name = "print-function-table")]
78    PrintFunctionTable,
79
80    /// Run testing.
81    Test(TestOptions),
82}
83
84#[tokio::main(flavor = "multi_thread", worker_threads = 5)]
85async fn main() {
86    tracing_subscriber::fmt::init();
87
88    let opt = Opt::parse();
89    let command = opt.command;
90    let opt = match command {
91        Commands::PrintFunctionTable => {
92            println!("{}", print_function_table());
93            return;
94        }
95        Commands::Test(test_opts) => test_opts,
96    };
97    let (client, connection) = tokio_postgres::Config::new()
98        .host(&opt.host)
99        .port(opt.port)
100        .dbname(&opt.db)
101        .user(&opt.user)
102        .password(&opt.pass)
103        .connect_timeout(Duration::from_secs(5))
104        .connect(NoTls)
105        .await
106        .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e));
107    tokio::spawn(async move {
108        if let Err(e) = connection.await {
109            tracing::error!("Postgres connection error: {:?}", e);
110        }
111    });
112    if opt.differential_testing {
113        return run_differential_testing(&client, &opt.testdata, opt.count, None)
114            .await
115            .unwrap();
116    }
117    if let Some(outdir) = opt.generate {
118        generate(&client, &opt.testdata, opt.count, &outdir, None).await;
119    } else {
120        run(&client, &opt.testdata, opt.count, None).await;
121    }
122}