sqlsmith_reducer/
reducer.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use clap::{Parser, ValueEnum};
18use risingwave_sqlsmith::reducer::shrink_file;
19use risingwave_sqlsmith::sqlreduce::Strategy;
20use thiserror_ext::AsReport;
21use tokio_postgres::NoTls;
22use tracing_subscriber::EnvFilter;
23
24#[derive(Debug, Clone, ValueEnum)]
25enum ReductionStrategy {
26    Single,
27    Aggressive,
28    Consecutive,
29}
30
31/// Reduce an sql query
32#[derive(Parser, Debug)]
33#[command(author, version, about, long_about = None)]
34struct Args {
35    /// Input file
36    #[arg(short, long)]
37    input_file: String,
38
39    /// Output file
40    #[arg(short, long)]
41    output_file: String,
42
43    /// Reducer strategy
44    #[arg(short, long, default_value = "single")]
45    strategy: ReductionStrategy,
46
47    /// For consecutive strategy, number of elements to reduce at once (used only when strategy = consecutive)
48    #[arg(short, long, default_value_t = 2)]
49    consecutive_k: usize,
50
51    /// Command to restore RW
52    #[clap(long)]
53    run_rw_cmd: String,
54}
55
56#[tokio::main(flavor = "multi_thread", worker_threads = 5)]
57async fn main() {
58    _ = tracing_subscriber::fmt()
59        .with_env_filter(EnvFilter::from_default_env())
60        .with_ansi(console::colors_enabled_stderr() && console::colors_enabled())
61        .with_writer(std::io::stderr)
62        .try_init();
63
64    let args = Args::parse();
65
66    let (client, connection) = tokio_postgres::Config::new()
67        .host("localhost")
68        .port(4566)
69        .dbname("dev")
70        .user("root")
71        .password("")
72        .connect_timeout(Duration::from_secs(5))
73        .connect(NoTls)
74        .await
75        .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e.as_report()));
76
77    tokio::spawn(async move {
78        if let Err(e) = connection.await {
79            tracing::error!(error = %e.as_report(), "Postgres connection error");
80        }
81    });
82
83    let strategy = match args.strategy {
84        ReductionStrategy::Single => Strategy::Single,
85        ReductionStrategy::Aggressive => Strategy::Aggressive,
86        ReductionStrategy::Consecutive => Strategy::Consecutive(args.consecutive_k),
87    };
88
89    shrink_file(
90        &args.input_file,
91        &args.output_file,
92        strategy,
93        client,
94        &args.run_rw_cmd,
95    )
96    .await
97    .unwrap();
98}