risingwave_simulation/
client.rs1use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use shell_words::split;
22use sqllogictest::{DBOutput, DefaultColumnType};
23
24use crate::ctl_ext::start_ctl;
25
/// A sqllogictest database handle backed by a single `tokio_postgres` session.
///
/// It keeps the connection parameters and the recorded `SET` statements so that
/// [`RisingWave::reconnect`] can open an equivalent session after the old one is closed.
pub struct RisingWave {
    /// Client of the current session.
    client: tokio_postgres::Client,
    /// Spawned task that polls the connection belonging to `client`; aborted on drop.
    task: tokio::task::JoinHandle<()>,
    /// Host passed to `connect`, reused on reconnect.
    host: String,
    /// Database name passed to `connect`, reused on reconnect.
    dbname: String,
    /// `SET` statements replayed on every reconnect.
    set_stmts: SetStmts,
}
36
/// `SET` statements issued on a session, kept so they can be replayed after a reconnect.
#[derive(Default)]
pub struct SetStmts {
    /// Full SQL text of the latest `SET` for each variable, keyed by the lower-cased variable
    /// name. Iteration (and therefore replay) follows the map's insertion order.
    stmts: IndexMap<String, String>,
}
44
45impl SetStmts {
46 fn push(&mut self, sql: &str) {
47 let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
48 match Itertools::exactly_one(ast.into_iter()).expect("should contain only one statement") {
49 Statement::SetVariable {
51 local: _,
52 variable,
53 value: _,
54 } => {
55 let key = variable.real_value().to_lowercase();
56 self.stmts.insert(key, sql.to_owned());
58 }
59 _ => unreachable!(),
60 }
61 }
62
63 fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
64 self.stmts.values().map(|s| s.as_str())
65 }
66}
67
68impl RisingWave {
69 pub async fn connect(
70 host: String,
71 dbname: String,
72 ) -> Result<Self, tokio_postgres::error::Error> {
73 let set_stmts = SetStmts::default();
74 let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
75 Ok(Self {
76 client,
77 task,
78 host,
79 dbname,
80 set_stmts,
81 })
82 }
83
84 pub async fn connect_inner(
85 host: &str,
86 dbname: &str,
87 set_stmts: &SetStmts,
88 ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
89 {
90 let (client, connection) = tokio_postgres::Config::new()
91 .host(host)
92 .port(4566)
93 .dbname(dbname)
94 .user("root")
95 .connect_timeout(Duration::from_secs(5))
96 .connect(tokio_postgres::NoTls)
97 .await?;
98 let task = tokio::spawn(async move {
99 if let Err(e) = connection.await {
100 tracing::error!("postgres connection error: {e}");
101 }
102 });
103 for stmt in set_stmts.replay_iter() {
105 client.simple_query(stmt).await?;
106 }
107 Ok((client, task))
108 }
109
110 pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
111 let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
112 self.client = client;
113 self.task = task;
114 Ok(())
115 }
116
117 pub fn pg_client(&self) -> &tokio_postgres::Client {
119 &self.client
120 }
121}
122
123impl Drop for RisingWave {
124 fn drop(&mut self) {
125 self.task.abort();
126 }
127}
128
129fn parse_risedev_ctl_args(command: &std::process::Command) -> Option<Vec<String>> {
130 let program = command.get_program().to_str()?;
131 if program != "bash" && !program.ends_with("/bash") {
132 return None;
133 }
134
135 let mut args = command.get_args();
136 if args.next()?.to_str()? != "-c" {
137 return None;
138 }
139 let script = args.next()?.to_str()?;
140 if args.next().is_some() {
141 return None;
142 }
143
144 let mut parts = split(script).ok()?;
145 if parts.len() < 2 || parts[0] != "./risedev" || parts[1] != "ctl" {
146 return None;
147 }
148 Some(parts.split_off(2))
149}
150
/// Builds a synthetic [`std::process::Output`] for an emulated command.
///
/// `stdout` is always empty. The status is built so that `ExitStatus::code()` reports
/// `exit_code`: the value is placed in the second byte of a raw wait status.
///
/// # Panics
///
/// Always panics on non-unix targets, which simulation mode does not support.
fn command_output(exit_code: i32, stderr: Vec<u8>) -> std::process::Output {
    #[cfg(unix)]
    {
        use std::os::unix::process::ExitStatusExt;

        let raw_wait_status = exit_code << 8;
        std::process::Output {
            stdout: Vec::new(),
            stderr,
            status: std::process::ExitStatus::from_raw(raw_wait_status),
        }
    }

    #[cfg(not(unix))]
    {
        unimplemented!("simulation mode does not support non-unix platforms")
    }
}
168
169#[async_trait::async_trait]
170impl sqllogictest::AsyncDB for RisingWave {
171 type ColumnType = DefaultColumnType;
172 type Error = tokio_postgres::error::Error;
173
174 async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
175 use sqllogictest::DBOutput;
176
177 if self.client.is_closed() {
178 self.reconnect().await?;
180 }
181
182 if sql.trim_start().to_lowercase().starts_with("set") {
183 self.set_stmts.push(sql);
184 }
185
186 let mut output = vec![];
187
188 let rows = self.client.simple_query(sql).await?;
189 let mut cnt = 0;
190 for row in rows {
191 let mut row_vec = vec![];
192 match row {
193 tokio_postgres::SimpleQueryMessage::Row(row) => {
194 for i in 0..row.len() {
195 match row.get(i) {
196 Some(v) => {
197 if v.is_empty() {
198 row_vec.push("(empty)".to_owned());
199 } else {
200 row_vec.push(v.to_owned());
201 }
202 }
203 None => row_vec.push("NULL".to_owned()),
204 }
205 }
206 }
207 tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
208 cnt = cnt_;
209 break;
210 }
211 _ => unreachable!(),
212 }
213 output.push(row_vec);
214 }
215
216 if output.is_empty() {
217 Ok(DBOutput::StatementComplete(cnt))
218 } else {
219 Ok(DBOutput::Rows {
220 types: vec![DefaultColumnType::Any; output[0].len()],
221 rows: output,
222 })
223 }
224 }
225
226 async fn shutdown(&mut self) {}
227
228 fn engine_name(&self) -> &str {
229 "risingwave"
230 }
231
232 async fn sleep(dur: Duration) {
233 tokio::time::sleep(dur).await
234 }
235
236 async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
237 if let Some(ctl_args) = parse_risedev_ctl_args(&command) {
238 let output = match start_ctl(ctl_args).await {
239 Ok(()) => command_output(0, vec![]),
240 Err(err) => command_output(1, format!("{err:#}\n").into_bytes()),
241 };
242 return Ok(output);
243 }
244 unimplemented!("spawning process is not supported in simulation mode")
245 }
246}
247
#[cfg(test)]
mod tests {
    use std::process::Command;

    use super::*;

    #[test]
    fn command_output_uses_exit_code_and_stderr() {
        let output = command_output(1, b"ctl failed\n".to_vec());

        assert!(!output.status.success());
        assert_eq!(output.status.code(), Some(1));
        assert_eq!(output.stderr, b"ctl failed\n");
    }

    #[test]
    fn command_output_success_has_empty_streams() {
        let output = command_output(0, vec![]);

        assert!(output.status.success());
        assert_eq!(output.status.code(), Some(0));
        assert!(output.stdout.is_empty());
        assert!(output.stderr.is_empty());
    }

    #[test]
    fn parse_risedev_ctl_args_extracts_ctl_arguments() {
        let mut command = Command::new("bash");
        command.arg("-c").arg("./risedev ctl meta pause");
        assert_eq!(
            parse_risedev_ctl_args(&command),
            Some(vec!["meta".to_owned(), "pause".to_owned()])
        );

        // A full path to bash is accepted, and quoted words stay together.
        let mut command = Command::new("/bin/bash");
        command.args(["-c", "./risedev ctl 'a b'"]);
        assert_eq!(
            parse_risedev_ctl_args(&command),
            Some(vec!["a b".to_owned()])
        );
    }

    #[test]
    fn parse_risedev_ctl_args_rejects_other_commands() {
        let cases: [(&str, &[&str]); 5] = [
            ("sh", &["-c", "./risedev ctl meta pause"]),
            ("bash", &["-x", "./risedev ctl meta pause"]),
            ("bash", &["-c", "./risedev d"]),
            ("bash", &["-c", "./risedev ctl", "extra"]),
            ("bash", &["-c", "./risedev ctl 'unterminated"]),
        ];
        for (program, args) in cases {
            let mut command = Command::new(program);
            command.args(args);
            assert_eq!(
                parse_risedev_ctl_args(&command),
                None,
                "{program} {args:?}"
            );
        }
    }
}