risingwave_simulation/
client.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use shell_words::split;
22use sqllogictest::{DBOutput, DefaultColumnType};
23
24use crate::ctl_ext::start_ctl;
25
26/// A RisingWave client.
27pub struct RisingWave {
28    client: tokio_postgres::Client,
29    task: tokio::task::JoinHandle<()>,
30    host: String,
31    dbname: String,
32    /// The `SET` statements that have been executed on this client.
33    /// We need to replay them when reconnecting.
34    set_stmts: SetStmts,
35}
36
37/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
38/// history.
39#[derive(Default)]
40pub struct SetStmts {
41    // variable name -> last set statement
42    stmts: IndexMap<String, String>,
43}
44
45impl SetStmts {
46    fn push(&mut self, sql: &str) {
47        let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
48        match ast
49            .into_iter()
50            .exactly_one()
51            .expect("should contain only one statement")
52        {
53            // record `local` for variable and `SetTransaction` if supported in the future.
54            Statement::SetVariable {
55                local: _,
56                variable,
57                value: _,
58            } => {
59                let key = variable.real_value().to_lowercase();
60                // store complete sql as value.
61                self.stmts.insert(key, sql.to_owned());
62            }
63            _ => unreachable!(),
64        }
65    }
66
67    fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
68        self.stmts.values().map(|s| s.as_str())
69    }
70}
71
72impl RisingWave {
73    pub async fn connect(
74        host: String,
75        dbname: String,
76    ) -> Result<Self, tokio_postgres::error::Error> {
77        let set_stmts = SetStmts::default();
78        let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
79        Ok(Self {
80            client,
81            task,
82            host,
83            dbname,
84            set_stmts,
85        })
86    }
87
88    pub async fn connect_inner(
89        host: &str,
90        dbname: &str,
91        set_stmts: &SetStmts,
92    ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
93    {
94        let (client, connection) = tokio_postgres::Config::new()
95            .host(host)
96            .port(4566)
97            .dbname(dbname)
98            .user("root")
99            .connect_timeout(Duration::from_secs(5))
100            .connect(tokio_postgres::NoTls)
101            .await?;
102        let task = tokio::spawn(async move {
103            if let Err(e) = connection.await {
104                tracing::error!("postgres connection error: {e}");
105            }
106        });
107        // replay all SET statements
108        for stmt in set_stmts.replay_iter() {
109            client.simple_query(stmt).await?;
110        }
111        Ok((client, task))
112    }
113
114    pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
115        let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
116        self.client = client;
117        self.task = task;
118        Ok(())
119    }
120
121    /// Returns a reference of the inner Postgres client.
122    pub fn pg_client(&self) -> &tokio_postgres::Client {
123        &self.client
124    }
125}
126
127impl Drop for RisingWave {
128    fn drop(&mut self) {
129        self.task.abort();
130    }
131}
132
133fn parse_risedev_ctl_args(command: &std::process::Command) -> Option<Vec<String>> {
134    let program = command.get_program().to_str()?;
135    if program != "bash" && !program.ends_with("/bash") {
136        return None;
137    }
138
139    let mut args = command.get_args();
140    if args.next()?.to_str()? != "-c" {
141        return None;
142    }
143    let script = args.next()?.to_str()?;
144    if args.next().is_some() {
145        return None;
146    }
147
148    let mut parts = split(script).ok()?;
149    if parts.len() < 2 || parts[0] != "./risedev" || parts[1] != "ctl" {
150        return None;
151    }
152    Some(parts.split_off(2))
153}
154
/// Builds a [`std::process::Output`] carrying `exit_code` and `stderr`, with an empty stdout.
///
/// # Panics
///
/// Panics on non-unix platforms, which are not supported.
fn command_output(exit_code: i32, stderr: Vec<u8>) -> std::process::Output {
    #[cfg(unix)]
    {
        use std::os::unix::process::ExitStatusExt;
        use std::process::{ExitStatus, Output};

        // `from_raw` takes a raw wait status, in which the exit code sits above the low 8 bits.
        let status = ExitStatus::from_raw(exit_code << 8);
        Output {
            status,
            stdout: Vec::new(),
            stderr,
        }
    }

    #[cfg(not(unix))]
    {
        unimplemented!("simulation mode does not support non-unix platforms")
    }
}
172
173#[async_trait::async_trait]
174impl sqllogictest::AsyncDB for RisingWave {
175    type ColumnType = DefaultColumnType;
176    type Error = tokio_postgres::error::Error;
177
178    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
179        use sqllogictest::DBOutput;
180
181        if self.client.is_closed() {
182            // connection error, reset the client
183            self.reconnect().await?;
184        }
185
186        if sql.trim_start().to_lowercase().starts_with("set") {
187            self.set_stmts.push(sql);
188        }
189
190        let mut output = vec![];
191
192        let rows = self.client.simple_query(sql).await?;
193        let mut cnt = 0;
194        for row in rows {
195            let mut row_vec = vec![];
196            match row {
197                tokio_postgres::SimpleQueryMessage::Row(row) => {
198                    for i in 0..row.len() {
199                        match row.get(i) {
200                            Some(v) => {
201                                if v.is_empty() {
202                                    row_vec.push("(empty)".to_owned());
203                                } else {
204                                    row_vec.push(v.to_owned());
205                                }
206                            }
207                            None => row_vec.push("NULL".to_owned()),
208                        }
209                    }
210                }
211                tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
212                    cnt = cnt_;
213                    break;
214                }
215                _ => unreachable!(),
216            }
217            output.push(row_vec);
218        }
219
220        if output.is_empty() {
221            Ok(DBOutput::StatementComplete(cnt))
222        } else {
223            Ok(DBOutput::Rows {
224                types: vec![DefaultColumnType::Any; output[0].len()],
225                rows: output,
226            })
227        }
228    }
229
230    async fn shutdown(&mut self) {}
231
232    fn engine_name(&self) -> &str {
233        "risingwave"
234    }
235
236    async fn sleep(dur: Duration) {
237        tokio::time::sleep(dur).await
238    }
239
240    async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
241        if let Some(ctl_args) = parse_risedev_ctl_args(&command) {
242            let output = match start_ctl(ctl_args).await {
243                Ok(()) => command_output(0, vec![]),
244                Err(err) => command_output(1, format!("{err:#}\n").into_bytes()),
245            };
246            return Ok(output);
247        }
248        unimplemented!("spawning process is not supported in simulation mode")
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// `command_output` must hand back the exit code and stderr it was given.
    #[test]
    fn command_output_uses_exit_code_and_stderr() {
        let stderr = b"ctl failed\n".to_vec();
        let output = command_output(1, stderr.clone());

        assert_eq!(output.stderr, stderr);
        assert_eq!(output.status.code(), Some(1));
        assert!(!output.status.success());
    }
}