risingwave_state_cleaning_test/
main.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(register_tool)]
16#![register_tool(rw)]
17#![allow(rw::format_error)] // test code
18
19use std::collections::HashSet;
20use std::path::PathBuf;
21use std::str::FromStr;
22use std::time::{Duration, SystemTime, UNIX_EPOCH};
23
24use clap::Parser;
25use futures::{StreamExt, TryStreamExt};
26use regex::Regex;
27use serde::Deserialize;
28use serde_with::{OneOrMany, serde_as};
29use tokio::fs;
30use tokio_postgres::{NoTls, SimpleQueryMessage};
31use tokio_stream::wrappers::ReadDirStream;
32use tracing::{debug, error, info};
33
34#[derive(clap::Parser, Clone, Debug)]
35struct TestOptions {
36    /// The database server host.
37    #[clap(long, default_value = "localhost")]
38    host: String,
39
40    /// The database server port.
41    #[clap(short, long, default_value = "4566")]
42    port: u16,
43
44    /// The database name to connect.
45    #[clap(short, long, default_value = "dev")]
46    db: String,
47
48    /// The database username.
49    #[clap(short, long, default_value = "root")]
50    user: String,
51
52    /// The database password.
53    #[clap(short = 'w', long, default_value = "")]
54    pass: String,
55}
56
57#[derive(Debug, Clone, Deserialize)]
58struct BoundTable {
59    pattern: String,
60    limit: usize,
61}
62
63#[serde_as]
64#[derive(Debug, Clone, Deserialize)]
65struct TestCase {
66    name: String,
67    init_sqls: Vec<String>,
68    #[serde_as(deserialize_as = "OneOrMany<_>")]
69    bound_tables: Vec<BoundTable>,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct TestFile {
74    test: Vec<TestCase>,
75}
76
77async fn validate_case(
78    client: &tokio_postgres::Client,
79    TestCase {
80        name,
81        init_sqls,
82        bound_tables,
83    }: TestCase,
84) -> anyhow::Result<()> {
85    info!(%name, "validating");
86
87    for sql in init_sqls {
88        client.simple_query(&sql).await?;
89    }
90
91    let msgs = client.simple_query("SHOW INTERNAL TABLES").await?;
92    let internal_tables: HashSet<String> = msgs
93        .into_iter()
94        .filter_map(|msg| {
95            let SimpleQueryMessage::Row(row) = msg else {
96                return None;
97            };
98            Some(row.get("Name").unwrap().to_owned())
99        })
100        .collect();
101    info!(?internal_tables, "found tables");
102
103    #[derive(Debug)]
104    struct ProcessedBoundTable {
105        interested_tables: Vec<String>,
106        limit: usize,
107    }
108
109    let tables: Vec<_> = bound_tables
110        .into_iter()
111        .map(|t| {
112            let pattern = Regex::new(&t.pattern).unwrap();
113            let interested_tables = internal_tables
114                .iter()
115                .filter(|t| pattern.is_match(t))
116                .cloned()
117                .collect::<Vec<_>>();
118            ProcessedBoundTable {
119                interested_tables,
120                limit: t.limit,
121            }
122        })
123        .collect();
124
125    info!(?tables, "start checking");
126
127    const CHECK_COUNT: usize = 100;
128    const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
129
130    for i in 0..CHECK_COUNT {
131        for ProcessedBoundTable {
132            interested_tables,
133            limit,
134        } in &tables
135        {
136            for table in interested_tables {
137                let sql = format!("SELECT COUNT(*) FROM {}", table);
138                let res = client.query_one(&sql, &[]).await?;
139                let cnt: i64 = res.get(0);
140                debug!(iter=i, %table, %cnt, "checking");
141                if cnt > *limit as i64 {
142                    anyhow::bail!(
143                        "Table {} has {} rows, which is more than limit {}",
144                        table,
145                        cnt,
146                        limit
147                    );
148                }
149            }
150        }
151
152        tokio::time::sleep(CHECK_INTERVAL).await;
153    }
154
155    Ok(())
156}
157
158#[tokio::main]
159async fn main() -> anyhow::Result<()> {
160    risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default());
161
162    let opt = TestOptions::parse();
163
164    let conn_builder = tokio_postgres::Config::new()
165        .host(&opt.host)
166        .port(opt.port)
167        .user(&opt.user)
168        .password(&opt.pass)
169        .connect_timeout(Duration::from_secs(5))
170        .clone();
171
172    let (main_client, connection) = conn_builder
173        .clone()
174        .dbname(&opt.db)
175        .connect(NoTls)
176        .await
177        .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e));
178
179    tokio::spawn(async move {
180        if let Err(e) = connection.await {
181            error!(?e, "connection error");
182        }
183    });
184
185    let now = SystemTime::now()
186        .duration_since(UNIX_EPOCH)
187        .expect("Time went backwards")
188        .as_secs();
189
190    let manifest = env!("CARGO_MANIFEST_DIR");
191
192    let data_dir = PathBuf::from_str(manifest).unwrap().join("data");
193
194    ReadDirStream::new(fs::read_dir(data_dir).await?)
195        .map(|path| async {
196            let path = path?.path();
197            let content = tokio::fs::read_to_string(&path).await?;
198            let test_file: TestFile = toml::from_str(&content)?;
199            let cases = test_file.test;
200
201            let test_name = path.file_stem().unwrap().to_string_lossy();
202
203            let cur_db_name = format!("state_cleaning_test_{}_{}", test_name, now);
204
205            main_client
206                .simple_query(&format!("CREATE DATABASE {}", cur_db_name))
207                .await?;
208
209            let (client, connection) = conn_builder
210                .clone()
211                .dbname(&cur_db_name)
212                .connect(NoTls)
213                .await?;
214
215            info!(%test_name, %cur_db_name, "run test in new database");
216
217            tokio::spawn(async move {
218                if let Err(e) = connection.await {
219                    error!(?e, "connection error");
220                }
221            });
222
223            for case in cases {
224                validate_case(&client, case).await?;
225            }
226
227            Ok::<_, anyhow::Error>(())
228        })
229        .buffer_unordered(16)
230        .try_collect::<()>()
231        .await?;
232
233    Ok(())
234}