risingwave_simulation/
client.rs1use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use sqllogictest::{DBOutput, DefaultColumnType};
22
23pub struct RisingWave {
25 client: tokio_postgres::Client,
26 task: tokio::task::JoinHandle<()>,
27 host: String,
28 dbname: String,
29 set_stmts: SetStmts,
32}
33
34#[derive(Default)]
37pub struct SetStmts {
38 stmts: IndexMap<String, String>,
40}
41
42impl SetStmts {
43 fn push(&mut self, sql: &str) {
44 let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
45 match ast
46 .into_iter()
47 .exactly_one()
48 .expect("should contain only one statement")
49 {
50 Statement::SetVariable {
52 local: _,
53 variable,
54 value: _,
55 } => {
56 let key = variable.real_value().to_lowercase();
57 self.stmts.insert(key, sql.to_owned());
59 }
60 _ => unreachable!(),
61 }
62 }
63
64 fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
65 self.stmts.values().map(|s| s.as_str())
66 }
67}
68
69impl RisingWave {
70 pub async fn connect(
71 host: String,
72 dbname: String,
73 ) -> Result<Self, tokio_postgres::error::Error> {
74 let set_stmts = SetStmts::default();
75 let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
76 Ok(Self {
77 client,
78 task,
79 host,
80 dbname,
81 set_stmts,
82 })
83 }
84
85 pub async fn connect_inner(
86 host: &str,
87 dbname: &str,
88 set_stmts: &SetStmts,
89 ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
90 {
91 let (client, connection) = tokio_postgres::Config::new()
92 .host(host)
93 .port(4566)
94 .dbname(dbname)
95 .user("root")
96 .connect_timeout(Duration::from_secs(5))
97 .connect(tokio_postgres::NoTls)
98 .await?;
99 let task = tokio::spawn(async move {
100 if let Err(e) = connection.await {
101 tracing::error!("postgres connection error: {e}");
102 }
103 });
104 for stmt in set_stmts.replay_iter() {
106 client.simple_query(stmt).await?;
107 }
108 Ok((client, task))
109 }
110
111 pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
112 let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
113 self.client = client;
114 self.task = task;
115 Ok(())
116 }
117
118 pub fn pg_client(&self) -> &tokio_postgres::Client {
120 &self.client
121 }
122}
123
124impl Drop for RisingWave {
125 fn drop(&mut self) {
126 self.task.abort();
127 }
128}
129
130#[async_trait::async_trait]
131impl sqllogictest::AsyncDB for RisingWave {
132 type ColumnType = DefaultColumnType;
133 type Error = tokio_postgres::error::Error;
134
135 async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
136 use sqllogictest::DBOutput;
137
138 if self.client.is_closed() {
139 self.reconnect().await?;
141 }
142
143 if sql.trim_start().to_lowercase().starts_with("set") {
144 self.set_stmts.push(sql);
145 }
146
147 let mut output = vec![];
148
149 let rows = self.client.simple_query(sql).await?;
150 let mut cnt = 0;
151 for row in rows {
152 let mut row_vec = vec![];
153 match row {
154 tokio_postgres::SimpleQueryMessage::Row(row) => {
155 for i in 0..row.len() {
156 match row.get(i) {
157 Some(v) => {
158 if v.is_empty() {
159 row_vec.push("(empty)".to_owned());
160 } else {
161 row_vec.push(v.to_owned());
162 }
163 }
164 None => row_vec.push("NULL".to_owned()),
165 }
166 }
167 }
168 tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
169 cnt = cnt_;
170 break;
171 }
172 _ => unreachable!(),
173 }
174 output.push(row_vec);
175 }
176
177 if output.is_empty() {
178 Ok(DBOutput::StatementComplete(cnt))
179 } else {
180 Ok(DBOutput::Rows {
181 types: vec![DefaultColumnType::Any; output[0].len()],
182 rows: output,
183 })
184 }
185 }
186
187 async fn shutdown(&mut self) {}
188
189 fn engine_name(&self) -> &str {
190 "risingwave"
191 }
192
193 async fn sleep(dur: Duration) {
194 tokio::time::sleep(dur).await
195 }
196
197 async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
198 unimplemented!("spawning process is not supported in simulation mode")
199 }
200}