risingwave_simulation/
client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use indexmap::IndexMap;
18use itertools::Itertools;
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use sqllogictest::{DBOutput, DefaultColumnType};
22
/// A RisingWave client.
///
/// Wraps a `tokio_postgres` client together with the information needed to re-establish the
/// session (host, database name and the executed `SET` statements) after the connection drops.
pub struct RisingWave {
    /// The underlying Postgres-protocol client that queries are sent through.
    client: tokio_postgres::Client,
    /// Handle of the spawned task that drives the connection future; aborted in `Drop`.
    task: tokio::task::JoinHandle<()>,
    /// Host used for (re)connecting.
    host: String,
    /// Database name used for (re)connecting.
    dbname: String,
    /// The `SET` statements that have been executed on this client.
    /// We need to replay them when reconnecting.
    set_stmts: SetStmts,
}
33
34/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
35/// history.
36#[derive(Default)]
37pub struct SetStmts {
38    // variable name -> last set statement
39    stmts: IndexMap<String, String>,
40}
41
42impl SetStmts {
43    fn push(&mut self, sql: &str) {
44        let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
45        match ast
46            .into_iter()
47            .exactly_one()
48            .expect("should contain only one statement")
49        {
50            // record `local` for variable and `SetTransaction` if supported in the future.
51            Statement::SetVariable {
52                local: _,
53                variable,
54                value: _,
55            } => {
56                let key = variable.real_value().to_lowercase();
57                // store complete sql as value.
58                self.stmts.insert(key, sql.to_owned());
59            }
60            _ => unreachable!(),
61        }
62    }
63
64    fn replay_iter(&self) -> impl Iterator<Item = &str> + '_ {
65        self.stmts.values().map(|s| s.as_str())
66    }
67}
68
69impl RisingWave {
70    pub async fn connect(
71        host: String,
72        dbname: String,
73    ) -> Result<Self, tokio_postgres::error::Error> {
74        let set_stmts = SetStmts::default();
75        let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
76        Ok(Self {
77            client,
78            task,
79            host,
80            dbname,
81            set_stmts,
82        })
83    }
84
85    pub async fn connect_inner(
86        host: &str,
87        dbname: &str,
88        set_stmts: &SetStmts,
89    ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
90    {
91        let (client, connection) = tokio_postgres::Config::new()
92            .host(host)
93            .port(4566)
94            .dbname(dbname)
95            .user("root")
96            .connect_timeout(Duration::from_secs(5))
97            .connect(tokio_postgres::NoTls)
98            .await?;
99        let task = tokio::spawn(async move {
100            if let Err(e) = connection.await {
101                tracing::error!("postgres connection error: {e}");
102            }
103        });
104        // replay all SET statements
105        for stmt in set_stmts.replay_iter() {
106            client.simple_query(stmt).await?;
107        }
108        Ok((client, task))
109    }
110
111    pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
112        let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
113        self.client = client;
114        self.task = task;
115        Ok(())
116    }
117
118    /// Returns a reference of the inner Postgres client.
119    pub fn pg_client(&self) -> &tokio_postgres::Client {
120        &self.client
121    }
122}
123
124impl Drop for RisingWave {
125    fn drop(&mut self) {
126        self.task.abort();
127    }
128}
129
130#[async_trait::async_trait]
131impl sqllogictest::AsyncDB for RisingWave {
132    type ColumnType = DefaultColumnType;
133    type Error = tokio_postgres::error::Error;
134
135    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
136        use sqllogictest::DBOutput;
137
138        if self.client.is_closed() {
139            // connection error, reset the client
140            self.reconnect().await?;
141        }
142
143        if sql.trim_start().to_lowercase().starts_with("set") {
144            self.set_stmts.push(sql);
145        }
146
147        let mut output = vec![];
148
149        let rows = self.client.simple_query(sql).await?;
150        let mut cnt = 0;
151        for row in rows {
152            let mut row_vec = vec![];
153            match row {
154                tokio_postgres::SimpleQueryMessage::Row(row) => {
155                    for i in 0..row.len() {
156                        match row.get(i) {
157                            Some(v) => {
158                                if v.is_empty() {
159                                    row_vec.push("(empty)".to_owned());
160                                } else {
161                                    row_vec.push(v.to_owned());
162                                }
163                            }
164                            None => row_vec.push("NULL".to_owned()),
165                        }
166                    }
167                }
168                tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
169                    cnt = cnt_;
170                    break;
171                }
172                _ => unreachable!(),
173            }
174            output.push(row_vec);
175        }
176
177        if output.is_empty() {
178            Ok(DBOutput::StatementComplete(cnt))
179        } else {
180            Ok(DBOutput::Rows {
181                types: vec![DefaultColumnType::Any; output[0].len()],
182                rows: output,
183            })
184        }
185    }
186
187    async fn shutdown(&mut self) {}
188
189    fn engine_name(&self) -> &str {
190        "risingwave"
191    }
192
193    async fn sleep(dur: Duration) {
194        tokio::time::sleep(dur).await
195    }
196
197    async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
198        unimplemented!("spawning process is not supported in simulation mode")
199    }
200}