Skip to main content

risingwave_simulation/
client.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use shell_words::split;
22use sqllogictest::{DBOutput, DefaultColumnType};
23
24use crate::ctl_ext::start_ctl;
25
26/// A RisingWave client.
27pub struct RisingWave {
28    client: tokio_postgres::Client,
29    task: tokio::task::JoinHandle<()>,
30    host: String,
31    dbname: String,
32    /// The `SET` statements that have been executed on this client.
33    /// We need to replay them when reconnecting.
34    set_stmts: SetStmts,
35}
36
37/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
38/// history.
39#[derive(Default)]
40pub struct SetStmts {
41    // variable name -> last set statement
42    stmts: IndexMap<String, String>,
43}
44
45impl SetStmts {
46    fn push(&mut self, sql: &str) {
47        let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
48        match Itertools::exactly_one(ast.into_iter()).expect("should contain only one statement") {
49            // record `local` for variable and `SetTransaction` if supported in the future.
50            Statement::SetVariable {
51                local: _,
52                variable,
53                value: _,
54            } => {
55                let key = variable.real_value().to_lowercase();
56                // store complete sql as value.
57                self.stmts.insert(key, sql.to_owned());
58            }
59            _ => unreachable!(),
60        }
61    }
62
63    fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
64        self.stmts.values().map(|s| s.as_str())
65    }
66}
67
68impl RisingWave {
69    pub async fn connect(
70        host: String,
71        dbname: String,
72    ) -> Result<Self, tokio_postgres::error::Error> {
73        let set_stmts = SetStmts::default();
74        let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
75        Ok(Self {
76            client,
77            task,
78            host,
79            dbname,
80            set_stmts,
81        })
82    }
83
84    pub async fn connect_inner(
85        host: &str,
86        dbname: &str,
87        set_stmts: &SetStmts,
88    ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
89    {
90        let (client, connection) = tokio_postgres::Config::new()
91            .host(host)
92            .port(4566)
93            .dbname(dbname)
94            .user("root")
95            .connect_timeout(Duration::from_secs(5))
96            .connect(tokio_postgres::NoTls)
97            .await?;
98        let task = tokio::spawn(async move {
99            if let Err(e) = connection.await {
100                tracing::error!("postgres connection error: {e}");
101            }
102        });
103        // replay all SET statements
104        for stmt in set_stmts.replay_iter() {
105            client.simple_query(stmt).await?;
106        }
107        Ok((client, task))
108    }
109
110    pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
111        let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
112        self.client = client;
113        self.task = task;
114        Ok(())
115    }
116
117    /// Returns a reference of the inner Postgres client.
118    pub fn pg_client(&self) -> &tokio_postgres::Client {
119        &self.client
120    }
121}
122
123impl Drop for RisingWave {
124    fn drop(&mut self) {
125        self.task.abort();
126    }
127}
128
129fn parse_risedev_ctl_args(command: &std::process::Command) -> Option<Vec<String>> {
130    let program = command.get_program().to_str()?;
131    if program != "bash" && !program.ends_with("/bash") {
132        return None;
133    }
134
135    let mut args = command.get_args();
136    if args.next()?.to_str()? != "-c" {
137        return None;
138    }
139    let script = args.next()?.to_str()?;
140    if args.next().is_some() {
141        return None;
142    }
143
144    let mut parts = split(script).ok()?;
145    if parts.len() < 2 || parts[0] != "./risedev" || parts[1] != "ctl" {
146        return None;
147    }
148    Some(parts.split_off(2))
149}
150
/// Builds a synthetic [`std::process::Output`] for a command that was handled in-process.
///
/// The output has an empty stdout and the given `stderr`. Its exit status reports `exit_code`
/// (meaningful for codes in `0..=255`).
///
/// # Panics
///
/// Always panics on non-unix platforms, which the simulation does not support.
fn command_output(exit_code: i32, stderr: Vec<u8>) -> std::process::Output {
    #[cfg(unix)]
    {
        use std::os::unix::process::ExitStatusExt;

        // A raw wait status stores a normal exit code in bits 8..16, so shift it into place.
        let status = std::process::ExitStatus::from_raw(exit_code << 8);
        let stdout = Vec::new();
        std::process::Output {
            status,
            stdout,
            stderr,
        }
    }

    #[cfg(not(unix))]
    {
        unimplemented!("simulation mode does not support non-unix platforms")
    }
}
168
169#[async_trait::async_trait]
170impl sqllogictest::AsyncDB for RisingWave {
171    type ColumnType = DefaultColumnType;
172    type Error = tokio_postgres::error::Error;
173
174    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
175        use sqllogictest::DBOutput;
176
177        if self.client.is_closed() {
178            // connection error, reset the client
179            self.reconnect().await?;
180        }
181
182        if sql.trim_start().to_lowercase().starts_with("set") {
183            self.set_stmts.push(sql);
184        }
185
186        let mut output = vec![];
187
188        let rows = self.client.simple_query(sql).await?;
189        let mut cnt = 0;
190        for row in rows {
191            let mut row_vec = vec![];
192            match row {
193                tokio_postgres::SimpleQueryMessage::Row(row) => {
194                    for i in 0..row.len() {
195                        match row.get(i) {
196                            Some(v) => {
197                                if v.is_empty() {
198                                    row_vec.push("(empty)".to_owned());
199                                } else {
200                                    row_vec.push(v.to_owned());
201                                }
202                            }
203                            None => row_vec.push("NULL".to_owned()),
204                        }
205                    }
206                }
207                tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
208                    cnt = cnt_;
209                    break;
210                }
211                _ => unreachable!(),
212            }
213            output.push(row_vec);
214        }
215
216        if output.is_empty() {
217            Ok(DBOutput::StatementComplete(cnt))
218        } else {
219            Ok(DBOutput::Rows {
220                types: vec![DefaultColumnType::Any; output[0].len()],
221                rows: output,
222            })
223        }
224    }
225
226    async fn shutdown(&mut self) {}
227
228    fn engine_name(&self) -> &str {
229        "risingwave"
230    }
231
232    async fn sleep(dur: Duration) {
233        tokio::time::sleep(dur).await
234    }
235
236    async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
237        if let Some(ctl_args) = parse_risedev_ctl_args(&command) {
238            let output = match start_ctl(ctl_args).await {
239                Ok(()) => command_output(0, vec![]),
240                Err(err) => command_output(1, format!("{err:#}\n").into_bytes()),
241            };
242            return Ok(output);
243        }
244        unimplemented!("spawning process is not supported in simulation mode")
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `bash -c <script>`.
    fn bash(script: &str) -> std::process::Command {
        let mut command = std::process::Command::new("bash");
        command.arg("-c").arg(script);
        command
    }

    #[test]
    fn command_output_uses_exit_code_and_stderr() {
        let output = command_output(1, b"ctl failed\n".to_vec());

        assert!(!output.status.success());
        assert_eq!(output.status.code(), Some(1));
        assert_eq!(output.stderr, b"ctl failed\n");
    }

    #[test]
    fn command_output_reports_success() {
        let output = command_output(0, vec![]);

        assert!(output.status.success());
        assert_eq!(output.status.code(), Some(0));
        assert!(output.stdout.is_empty());
        assert!(output.stderr.is_empty());
    }

    #[test]
    fn parse_risedev_ctl_args_extracts_ctl_arguments() {
        assert_eq!(
            parse_risedev_ctl_args(&bash("./risedev ctl meta pause 'a b'")),
            Some(vec!["meta".to_owned(), "pause".to_owned(), "a b".to_owned()])
        );
        assert_eq!(parse_risedev_ctl_args(&bash("./risedev ctl")), Some(vec![]));

        // `bash` may be referenced by an absolute path.
        let mut absolute = std::process::Command::new("/bin/bash");
        absolute.args(["-c", "./risedev ctl hummock"]);
        assert_eq!(
            parse_risedev_ctl_args(&absolute),
            Some(vec!["hummock".to_owned()])
        );
    }

    #[test]
    fn parse_risedev_ctl_args_rejects_other_commands() {
        // Scripts that are not `./risedev ctl ...`.
        assert_eq!(parse_risedev_ctl_args(&bash("./risedev d")), None);
        assert_eq!(parse_risedev_ctl_args(&bash("./risedev")), None);
        assert_eq!(parse_risedev_ctl_args(&bash("risedev ctl meta pause")), None);
        // Unbalanced quoting cannot be shell-split.
        assert_eq!(parse_risedev_ctl_args(&bash("./risedev ctl 'oops")), None);

        // A program other than bash.
        let mut sh = std::process::Command::new("sh");
        sh.args(["-c", "./risedev ctl meta pause"]);
        assert_eq!(parse_risedev_ctl_args(&sh), None);

        // An extra trailing argument.
        let mut extra = bash("./risedev ctl meta pause");
        extra.arg("extra");
        assert_eq!(parse_risedev_ctl_args(&extra), None);

        // A missing `-c` flag.
        let mut no_flag = std::process::Command::new("bash");
        no_flag.arg("./risedev ctl meta pause");
        assert_eq!(parse_risedev_ctl_args(&no_flag), None);
    }
}