risingwave_state_cleaning_test/
main.rs1#![feature(register_tool)]
16#![register_tool(rw)]
17#![allow(rw::format_error)] use std::collections::HashSet;
20use std::path::PathBuf;
21use std::str::FromStr;
22use std::time::{Duration, SystemTime, UNIX_EPOCH};
23
24use clap::Parser;
25use futures::{StreamExt, TryStreamExt};
26use regex::Regex;
27use serde::Deserialize;
28use serde_with::{OneOrMany, serde_as};
29use tokio::fs;
30use tokio_postgres::{NoTls, SimpleQueryMessage};
31use tokio_stream::wrappers::ReadDirStream;
32use tracing::{debug, error, info};
33
34#[derive(clap::Parser, Clone, Debug)]
35struct TestOptions {
36 #[clap(long, default_value = "localhost")]
38 host: String,
39
40 #[clap(short, long, default_value = "4566")]
42 port: u16,
43
44 #[clap(short, long, default_value = "dev")]
46 db: String,
47
48 #[clap(short, long, default_value = "root")]
50 user: String,
51
52 #[clap(short = 'w', long, default_value = "")]
54 pass: String,
55}
56
57#[derive(Debug, Clone, Deserialize)]
58struct BoundTable {
59 pattern: String,
60 limit: usize,
61}
62
63#[serde_as]
64#[derive(Debug, Clone, Deserialize)]
65struct TestCase {
66 name: String,
67 init_sqls: Vec<String>,
68 #[serde_as(deserialize_as = "OneOrMany<_>")]
69 bound_tables: Vec<BoundTable>,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct TestFile {
74 test: Vec<TestCase>,
75}
76
77async fn validate_case(
78 client: &tokio_postgres::Client,
79 TestCase {
80 name,
81 init_sqls,
82 bound_tables,
83 }: TestCase,
84) -> anyhow::Result<()> {
85 info!(%name, "validating");
86
87 for sql in init_sqls {
88 client.simple_query(&sql).await?;
89 }
90
91 let msgs = client.simple_query("SHOW INTERNAL TABLES").await?;
92 let internal_tables: HashSet<String> = msgs
93 .into_iter()
94 .filter_map(|msg| {
95 let SimpleQueryMessage::Row(row) = msg else {
96 return None;
97 };
98 Some(row.get("Name").unwrap().to_owned())
99 })
100 .collect();
101 info!(?internal_tables, "found tables");
102
103 #[derive(Debug)]
104 struct ProcessedBoundTable {
105 interested_tables: Vec<String>,
106 limit: usize,
107 }
108
109 let tables: Vec<_> = bound_tables
110 .into_iter()
111 .map(|t| {
112 let pattern = Regex::new(&t.pattern).unwrap();
113 let interested_tables = internal_tables
114 .iter()
115 .filter(|t| pattern.is_match(t))
116 .cloned()
117 .collect::<Vec<_>>();
118 ProcessedBoundTable {
119 interested_tables,
120 limit: t.limit,
121 }
122 })
123 .collect();
124
125 info!(?tables, "start checking");
126
127 const CHECK_COUNT: usize = 100;
128 const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
129
130 for i in 0..CHECK_COUNT {
131 for ProcessedBoundTable {
132 interested_tables,
133 limit,
134 } in &tables
135 {
136 for table in interested_tables {
137 let sql = format!("SELECT COUNT(*) FROM {}", table);
138 let res = client.query_one(&sql, &[]).await?;
139 let cnt: i64 = res.get(0);
140 debug!(iter=i, %table, %cnt, "checking");
141 if cnt > *limit as i64 {
142 anyhow::bail!(
143 "Table {} has {} rows, which is more than limit {}",
144 table,
145 cnt,
146 limit
147 );
148 }
149 }
150 }
151
152 tokio::time::sleep(CHECK_INTERVAL).await;
153 }
154
155 Ok(())
156}
157
158#[tokio::main]
159async fn main() -> anyhow::Result<()> {
160 risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default());
161
162 let opt = TestOptions::parse();
163
164 let conn_builder = tokio_postgres::Config::new()
165 .host(&opt.host)
166 .port(opt.port)
167 .user(&opt.user)
168 .password(&opt.pass)
169 .connect_timeout(Duration::from_secs(5))
170 .clone();
171
172 let (main_client, connection) = conn_builder
173 .clone()
174 .dbname(&opt.db)
175 .connect(NoTls)
176 .await
177 .unwrap_or_else(|e| panic!("Failed to connect to database: {}", e));
178
179 tokio::spawn(async move {
180 if let Err(e) = connection.await {
181 error!(?e, "connection error");
182 }
183 });
184
185 let now = SystemTime::now()
186 .duration_since(UNIX_EPOCH)
187 .expect("Time went backwards")
188 .as_secs();
189
190 let manifest = env!("CARGO_MANIFEST_DIR");
191
192 let data_dir = PathBuf::from_str(manifest).unwrap().join("data");
193
194 ReadDirStream::new(fs::read_dir(data_dir).await?)
195 .map(|path| async {
196 let path = path?.path();
197 let content = tokio::fs::read_to_string(&path).await?;
198 let test_file: TestFile = toml::from_str(&content)?;
199 let cases = test_file.test;
200
201 let test_name = path.file_stem().unwrap().to_string_lossy();
202
203 let cur_db_name = format!("state_cleaning_test_{}_{}", test_name, now);
204
205 main_client
206 .simple_query(&format!("CREATE DATABASE {}", cur_db_name))
207 .await?;
208
209 let (client, connection) = conn_builder
210 .clone()
211 .dbname(&cur_db_name)
212 .connect(NoTls)
213 .await?;
214
215 info!(%test_name, %cur_db_name, "run test in new database");
216
217 tokio::spawn(async move {
218 if let Err(e) = connection.await {
219 error!(?e, "connection error");
220 }
221 });
222
223 for case in cases {
224 validate_case(&client, case).await?;
225 }
226
227 Ok::<_, anyhow::Error>(())
228 })
229 .buffer_unordered(16)
230 .try_collect::<()>()
231 .await?;
232
233 Ok(())
234}