risingwave_simulation/
client.rs1use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use shell_words::split;
22use sqllogictest::{DBOutput, DefaultColumnType};
23
24use crate::ctl_ext::start_ctl;
25
/// A sqllogictest database handle that talks to a RisingWave frontend over the
/// postgres wire protocol and can transparently re-establish its connection.
pub struct RisingWave {
    /// The current postgres client; swapped out by `reconnect`.
    client: tokio_postgres::Client,
    /// Background task driving the client's connection; aborted when this value is dropped.
    task: tokio::task::JoinHandle<()>,
    /// Frontend host, kept so that `reconnect` can open a new connection.
    host: String,
    /// Database name, kept so that `reconnect` can open a new connection.
    dbname: String,
    /// `SET` statements executed so far, replayed on every new connection.
    set_stmts: SetStmts,
}
36
37#[derive(Default)]
40pub struct SetStmts {
41 stmts: IndexMap<String, String>,
43}
44
45impl SetStmts {
46 fn push(&mut self, sql: &str) {
47 let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
48 match ast
49 .into_iter()
50 .exactly_one()
51 .expect("should contain only one statement")
52 {
53 Statement::SetVariable {
55 local: _,
56 variable,
57 value: _,
58 } => {
59 let key = variable.real_value().to_lowercase();
60 self.stmts.insert(key, sql.to_owned());
62 }
63 _ => unreachable!(),
64 }
65 }
66
67 fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
68 self.stmts.values().map(|s| s.as_str())
69 }
70}
71
72impl RisingWave {
73 pub async fn connect(
74 host: String,
75 dbname: String,
76 ) -> Result<Self, tokio_postgres::error::Error> {
77 let set_stmts = SetStmts::default();
78 let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
79 Ok(Self {
80 client,
81 task,
82 host,
83 dbname,
84 set_stmts,
85 })
86 }
87
88 pub async fn connect_inner(
89 host: &str,
90 dbname: &str,
91 set_stmts: &SetStmts,
92 ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
93 {
94 let (client, connection) = tokio_postgres::Config::new()
95 .host(host)
96 .port(4566)
97 .dbname(dbname)
98 .user("root")
99 .connect_timeout(Duration::from_secs(5))
100 .connect(tokio_postgres::NoTls)
101 .await?;
102 let task = tokio::spawn(async move {
103 if let Err(e) = connection.await {
104 tracing::error!("postgres connection error: {e}");
105 }
106 });
107 for stmt in set_stmts.replay_iter() {
109 client.simple_query(stmt).await?;
110 }
111 Ok((client, task))
112 }
113
114 pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
115 let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
116 self.client = client;
117 self.task = task;
118 Ok(())
119 }
120
121 pub fn pg_client(&self) -> &tokio_postgres::Client {
123 &self.client
124 }
125}
126
127impl Drop for RisingWave {
128 fn drop(&mut self) {
129 self.task.abort();
130 }
131}
132
133fn parse_risedev_ctl_args(command: &std::process::Command) -> Option<Vec<String>> {
134 let program = command.get_program().to_str()?;
135 if program != "bash" && !program.ends_with("/bash") {
136 return None;
137 }
138
139 let mut args = command.get_args();
140 if args.next()?.to_str()? != "-c" {
141 return None;
142 }
143 let script = args.next()?.to_str()?;
144 if args.next().is_some() {
145 return None;
146 }
147
148 let mut parts = split(script).ok()?;
149 if parts.len() < 2 || parts[0] != "./risedev" || parts[1] != "ctl" {
150 return None;
151 }
152 Some(parts.split_off(2))
153}
154
/// Builds a synthetic `std::process::Output` for a command that was executed
/// in-process: the given exit code, empty stdout, and the given stderr bytes.
///
/// # Panics
/// On non-unix targets, which the simulation does not support.
fn command_output(exit_code: i32, stderr: Vec<u8>) -> std::process::Output {
    #[cfg(unix)]
    {
        use std::os::unix::process::ExitStatusExt;

        // `from_raw` takes a wait(2)-style status; the exit code of a normal exit sits
        // in the second byte, with the low byte left as zero.
        let raw_wait_status = exit_code << 8;
        let status = std::process::ExitStatus::from_raw(raw_wait_status);
        std::process::Output {
            status,
            stdout: Vec::new(),
            stderr,
        }
    }

    #[cfg(not(unix))]
    {
        unimplemented!("simulation mode does not support non-unix platforms")
    }
}
172
173#[async_trait::async_trait]
174impl sqllogictest::AsyncDB for RisingWave {
175 type ColumnType = DefaultColumnType;
176 type Error = tokio_postgres::error::Error;
177
178 async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
179 use sqllogictest::DBOutput;
180
181 if self.client.is_closed() {
182 self.reconnect().await?;
184 }
185
186 if sql.trim_start().to_lowercase().starts_with("set") {
187 self.set_stmts.push(sql);
188 }
189
190 let mut output = vec![];
191
192 let rows = self.client.simple_query(sql).await?;
193 let mut cnt = 0;
194 for row in rows {
195 let mut row_vec = vec![];
196 match row {
197 tokio_postgres::SimpleQueryMessage::Row(row) => {
198 for i in 0..row.len() {
199 match row.get(i) {
200 Some(v) => {
201 if v.is_empty() {
202 row_vec.push("(empty)".to_owned());
203 } else {
204 row_vec.push(v.to_owned());
205 }
206 }
207 None => row_vec.push("NULL".to_owned()),
208 }
209 }
210 }
211 tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
212 cnt = cnt_;
213 break;
214 }
215 _ => unreachable!(),
216 }
217 output.push(row_vec);
218 }
219
220 if output.is_empty() {
221 Ok(DBOutput::StatementComplete(cnt))
222 } else {
223 Ok(DBOutput::Rows {
224 types: vec![DefaultColumnType::Any; output[0].len()],
225 rows: output,
226 })
227 }
228 }
229
230 async fn shutdown(&mut self) {}
231
232 fn engine_name(&self) -> &str {
233 "risingwave"
234 }
235
236 async fn sleep(dur: Duration) {
237 tokio::time::sleep(dur).await
238 }
239
240 async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
241 if let Some(ctl_args) = parse_risedev_ctl_args(&command) {
242 let output = match start_ctl(ctl_args).await {
243 Ok(()) => command_output(0, vec![]),
244 Err(err) => command_output(1, format!("{err:#}\n").into_bytes()),
245 };
246 return Ok(output);
247 }
248 unimplemented!("spawning process is not supported in simulation mode")
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a (never spawned) command `program arg0 arg1 ...`.
    fn command(program: &str, args: &[&str]) -> std::process::Command {
        let mut cmd = std::process::Command::new(program);
        cmd.args(args);
        cmd
    }

    #[test]
    fn command_output_uses_exit_code_and_stderr() {
        let output = command_output(1, b"ctl failed\n".to_vec());

        assert!(!output.status.success());
        assert_eq!(output.status.code(), Some(1));
        assert_eq!(output.stderr, b"ctl failed\n");
    }

    #[test]
    fn command_output_success_is_zero_with_empty_streams() {
        let output = command_output(0, vec![]);

        assert!(output.status.success());
        assert_eq!(output.status.code(), Some(0));
        assert!(output.stdout.is_empty());
        assert!(output.stderr.is_empty());
    }

    #[test]
    fn parse_risedev_ctl_args_extracts_ctl_arguments() {
        assert_eq!(
            parse_risedev_ctl_args(&command("bash", &["-c", "./risedev ctl meta pause"])),
            Some(vec!["meta".to_owned(), "pause".to_owned()])
        );
        assert_eq!(
            parse_risedev_ctl_args(&command("/usr/bin/bash", &["-c", "./risedev ctl 'a b' c"])),
            Some(vec!["a b".to_owned(), "c".to_owned()])
        );
        assert_eq!(
            parse_risedev_ctl_args(&command("bash", &["-c", "./risedev ctl"])),
            Some(vec![])
        );
    }

    #[test]
    fn parse_risedev_ctl_args_rejects_other_commands() {
        assert_eq!(parse_risedev_ctl_args(&command("sh", &["-c", "./risedev ctl x"])), None);
        assert_eq!(parse_risedev_ctl_args(&command("bash", &["./risedev ctl x"])), None);
        assert_eq!(
            parse_risedev_ctl_args(&command("bash", &["-c", "./risedev ctl x", "extra"])),
            None
        );
        assert_eq!(parse_risedev_ctl_args(&command("bash", &["-c", "./risedev d"])), None);
        assert_eq!(parse_risedev_ctl_args(&command("bash", &["-c", "./risedev"])), None);
        assert_eq!(parse_risedev_ctl_args(&command("bash", &["-c", "echo ctl"])), None);
    }
}