risingwave_simulation/
client.rs1use std::time::Duration;
16
17use itertools::Itertools;
18use lru::{Iter, LruCache};
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use sqllogictest::{DBOutput, DefaultColumnType};
22
23pub struct RisingWave {
25 client: tokio_postgres::Client,
26 task: tokio::task::JoinHandle<()>,
27 host: String,
28 dbname: String,
29 set_stmts: SetStmts,
32}
33
34pub struct SetStmts {
37 stmts_cache: LruCache<String, String>,
38}
39
40impl Default for SetStmts {
41 fn default() -> Self {
42 Self {
43 stmts_cache: LruCache::unbounded(),
44 }
45 }
46}
47
48struct SetStmtsIterator<'a, 'b>
49where
50 'a: 'b,
51{
52 _stmts: &'a SetStmts,
53 stmts_iter: core::iter::Rev<Iter<'b, String, String>>,
54}
55
56impl<'a> SetStmtsIterator<'a, '_> {
57 fn new(stmts: &'a SetStmts) -> Self {
58 Self {
59 _stmts: stmts,
60 stmts_iter: stmts.stmts_cache.iter().rev(),
61 }
62 }
63}
64
65impl SetStmts {
66 fn push(&mut self, sql: &str) {
67 let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
68 match ast
69 .into_iter()
70 .exactly_one()
71 .expect("should contain only one statement")
72 {
73 Statement::SetVariable {
75 local: _,
76 variable,
77 value: _,
78 } => {
79 let key = variable.real_value().to_lowercase();
80 self.stmts_cache.put(key, sql.to_owned());
82 }
83 _ => unreachable!(),
84 }
85 }
86}
87
88impl Iterator for SetStmtsIterator<'_, '_> {
89 type Item = String;
90
91 fn next(&mut self) -> Option<Self::Item> {
92 let (_, stmt) = self.stmts_iter.next()?;
93 Some(stmt.clone())
94 }
95}
96
97impl RisingWave {
98 pub async fn connect(
99 host: String,
100 dbname: String,
101 ) -> Result<Self, tokio_postgres::error::Error> {
102 let set_stmts = SetStmts::default();
103 let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
104 Ok(Self {
105 client,
106 task,
107 host,
108 dbname,
109 set_stmts,
110 })
111 }
112
113 pub async fn connect_inner(
114 host: &str,
115 dbname: &str,
116 set_stmts: &SetStmts,
117 ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
118 {
119 let (client, connection) = tokio_postgres::Config::new()
120 .host(host)
121 .port(4566)
122 .dbname(dbname)
123 .user("root")
124 .connect_timeout(Duration::from_secs(5))
125 .connect(tokio_postgres::NoTls)
126 .await?;
127 let task = tokio::spawn(async move {
128 if let Err(e) = connection.await {
129 tracing::error!("postgres connection error: {e}");
130 }
131 });
132 for stmt in SetStmtsIterator::new(set_stmts) {
134 client.simple_query(&stmt).await?;
135 }
136 Ok((client, task))
137 }
138
139 pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
140 let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
141 self.client = client;
142 self.task = task;
143 Ok(())
144 }
145
146 pub fn pg_client(&self) -> &tokio_postgres::Client {
148 &self.client
149 }
150}
151
152impl Drop for RisingWave {
153 fn drop(&mut self) {
154 self.task.abort();
155 }
156}
157
158#[async_trait::async_trait]
159impl sqllogictest::AsyncDB for RisingWave {
160 type ColumnType = DefaultColumnType;
161 type Error = tokio_postgres::error::Error;
162
163 async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
164 use sqllogictest::DBOutput;
165
166 if self.client.is_closed() {
167 self.reconnect().await?;
169 }
170
171 if sql.trim_start().to_lowercase().starts_with("set") {
172 self.set_stmts.push(sql);
173 }
174
175 let mut output = vec![];
176
177 let rows = self.client.simple_query(sql).await?;
178 let mut cnt = 0;
179 for row in rows {
180 let mut row_vec = vec![];
181 match row {
182 tokio_postgres::SimpleQueryMessage::Row(row) => {
183 for i in 0..row.len() {
184 match row.get(i) {
185 Some(v) => {
186 if v.is_empty() {
187 row_vec.push("(empty)".to_owned());
188 } else {
189 row_vec.push(v.to_owned());
190 }
191 }
192 None => row_vec.push("NULL".to_owned()),
193 }
194 }
195 }
196 tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
197 cnt = cnt_;
198 break;
199 }
200 _ => unreachable!(),
201 }
202 output.push(row_vec);
203 }
204
205 if output.is_empty() {
206 Ok(DBOutput::StatementComplete(cnt))
207 } else {
208 Ok(DBOutput::Rows {
209 types: vec![DefaultColumnType::Any; output[0].len()],
210 rows: output,
211 })
212 }
213 }
214
215 async fn shutdown(&mut self) {}
216
217 fn engine_name(&self) -> &str {
218 "risingwave"
219 }
220
221 async fn sleep(dur: Duration) {
222 tokio::time::sleep(dur).await
223 }
224
225 async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
226 unimplemented!("spawning process is not supported in simulation mode")
227 }
228}