risingwave_sqlsmith/sqlreduce/
checker.rs1use std::process::Command;
16use std::time::Duration;
17
18use risingwave_sqlparser::ast::Statement;
19use tokio_postgres::{Client, NoTls};
20
/// State needed to check whether a rewritten query still produces the same
/// outcome as the original one.
pub struct Checker {
    /// Connection on which schema, setup, and candidate statements are executed.
    pub client: Client,
    /// Statements replayed, in order, after every schema reset.
    pub setup_stmts: Vec<Statement>,
    /// Shell command handed to `run_query`, which runs it through `sh -c`
    /// when the connection is found closed.
    restore_cmd: String,
}
27
28impl Checker {
29 pub fn new(client: Client, setup_stmts: Vec<Statement>, restore_cmd: String) -> Self {
30 Self {
31 client,
32 setup_stmts,
33 restore_cmd,
34 }
35 }
36
37 pub async fn prepare_schema(&self) {
42 let _ = self
43 .client
44 .simple_query("CREATE SCHEMA IF NOT EXISTS sqlsmith_reducer;")
45 .await;
46 let _ = self
47 .client
48 .simple_query("SET search_path TO sqlsmith_reducer;")
49 .await;
50 }
51
52 pub async fn drop_schema(&self) {
56 let _ = self
57 .client
58 .simple_query("DROP SCHEMA IF EXISTS sqlsmith_reducer CASCADE;")
59 .await;
60 }
61
62 pub async fn is_failure_preserved(&mut self, old: &str, new: &str) -> bool {
66 self.reset_schema().await;
67 self.replay_setup().await;
68 let old_result = run_query(&mut self.client, old, &self.restore_cmd).await;
69
70 self.reset_schema().await;
71 self.replay_setup().await;
72 let new_result = run_query(&mut self.client, new, &self.restore_cmd).await;
73
74 tracing::info!("old_result: {:?}", old_result);
75 tracing::info!("new_result: {:?}", new_result);
76
77 old_result == new_result
78 }
79
80 async fn reset_schema(&self) {
82 let _ = self
83 .client
84 .simple_query("DROP SCHEMA IF EXISTS sqlsmith_reducer CASCADE;")
85 .await;
86 let _ = self
87 .client
88 .simple_query("CREATE SCHEMA sqlsmith_reducer;")
89 .await;
90 let _ = self
91 .client
92 .simple_query("SET search_path TO sqlsmith_reducer;")
93 .await;
94 }
95
96 async fn replay_setup(&self) {
98 for stmt in &self.setup_stmts {
99 let _ = self.client.simple_query(&stmt.to_string()).await;
100 }
101 }
102}
103
104pub async fn run_query(client: &mut Client, query: &str, restore_cmd: &str) -> (bool, String) {
106 match client.simple_query(query).await {
107 Ok(_) => (true, String::new()),
108 Err(e) => {
109 if e.is_closed() {
110 tracing::error!("Frontend panic detected, restoring with `{restore_cmd}`...");
111
112 let status = Command::new("sh").arg("-c").arg(restore_cmd).status();
113 match status {
114 Ok(s) if s.success() => tracing::info!("restore cmd executed successfully"),
115 Ok(s) => tracing::error!("restore cmd failed with status: {s}"),
116 Err(err) => tracing::error!("failed to execute restore cmd: {err}"),
117 }
118
119 match tokio_postgres::Config::new()
125 .host("localhost")
126 .port(4566)
127 .dbname("dev")
128 .user("root")
129 .password("")
130 .connect_timeout(Duration::from_secs(5))
131 .connect(NoTls)
132 .await
133 {
134 Ok((new_client, connection)) => {
135 tokio::spawn(async move {
136 if let Err(e) = connection.await {
137 tracing::error!("connection error: {}", e);
138 }
139 });
140 *client = new_client;
141 tracing::info!("Reconnected to Frontend after panic");
142
143 if let Err(err) = wait_for_recovery(client).await {
144 tracing::error!("RW failed to recover after frontend panic: {:?}", err);
145 } else {
146 tracing::info!("RW recovery complete (frontend case)");
147 }
148 }
149 Err(err) => {
150 tracing::error!("Failed to reconnect frontend: {}", err);
151 }
152 }
153 } else if e.as_db_error().is_some() {
154 tracing::error!("Compute panic detected, waiting for recovery...");
155 if let Err(err) = wait_for_recovery(client).await {
156 tracing::error!("RW failed to recover after compute panic: {:?}", err);
157 } else {
158 tracing::info!("RW recovery complete (compute case)");
159 }
160 } else {
161 tracing::error!("Other panics detected...");
162 }
163
164 (false, e.to_string())
165 }
166 }
167}
168
169pub async fn wait_for_recovery(client: &Client) -> anyhow::Result<()> {
171 let timeout = Duration::from_secs(300);
172 let mut interval = tokio::time::interval(Duration::from_millis(100));
173
174 let res: Result<(), anyhow::Error> = tokio::time::timeout(timeout, async {
175 loop {
176 let query_res = client.simple_query("select rw_recovery_status();").await;
177 if let Ok(messages) = query_res {
178 for msg in messages {
179 if let tokio_postgres::SimpleQueryMessage::Row(row) = msg
180 && let Some(status) = row.get(0)
181 && status == "RUNNING"
182 {
183 return Ok(());
184 }
185 }
186 }
187 interval.tick().await;
188 }
189 })
190 .await
191 .map_err(|_| anyhow::anyhow!("timed out waiting for recovery"))?;
192
193 res
194}