risingwave_simulation/
client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use itertools::Itertools;
18use lru::{Iter, LruCache};
19use risingwave_sqlparser::ast::Statement;
20use risingwave_sqlparser::parser::Parser;
21use sqllogictest::{DBOutput, DefaultColumnType};
22
/// A RisingWave client implementing [`sqllogictest::AsyncDB`] over `tokio_postgres`.
///
/// Besides the live connection, it keeps the connection parameters and the history of executed
/// `SET` statements so that an equivalent session can be re-established by `reconnect`.
pub struct RisingWave {
    /// The Postgres-protocol client used to send queries.
    client: tokio_postgres::Client,
    /// Spawned task that drives the `tokio_postgres` connection; it is aborted when this client
    /// is dropped.
    task: tokio::task::JoinHandle<()>,
    /// Host to connect to (port and user are fixed in `connect_inner`).
    host: String,
    /// Name of the database to connect to.
    dbname: String,
    /// The `SET` statements that have been executed on this client.
    /// We need to replay them when reconnecting.
    set_stmts: SetStmts,
}
33
/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
/// history.
///
/// Only the latest statement per variable is kept, so replaying the cache restores the final
/// value of every recorded variable without re-running superseded assignments.
pub struct SetStmts {
    /// Maps a lower-cased variable name to the full SQL text of the latest `SET` statement for
    /// it. The cache is created unbounded, so nothing is ever evicted; its recency order is only
    /// used to decide the replay order.
    stmts_cache: LruCache<String, String>,
}
39
40impl Default for SetStmts {
41    fn default() -> Self {
42        Self {
43            stmts_cache: LruCache::unbounded(),
44        }
45    }
46}
47
/// Iterator over the recorded `SET` statements, yielding owned copies of their SQL text.
struct SetStmtsIterator<'a, 'b>
where
    'a: 'b,
{
    /// Borrow of the statement history being iterated; not read directly.
    _stmts: &'a SetStmts,
    /// Reversed cache iterator. NOTE: relies on `lru::LruCache::iter` yielding most-recently-used
    /// entries first (per the `lru` docs), so reversing it replays the oldest assignment first.
    stmts_iter: core::iter::Rev<Iter<'b, String, String>>,
}
55
56impl<'a> SetStmtsIterator<'a, '_> {
57    fn new(stmts: &'a SetStmts) -> Self {
58        Self {
59            _stmts: stmts,
60            stmts_iter: stmts.stmts_cache.iter().rev(),
61        }
62    }
63}
64
65impl SetStmts {
66    fn push(&mut self, sql: &str) {
67        let ast = Parser::parse_sql(sql).expect("a set statement should be parsed successfully");
68        match ast
69            .into_iter()
70            .exactly_one()
71            .expect("should contain only one statement")
72        {
73            // record `local` for variable and `SetTransaction` if supported in the future.
74            Statement::SetVariable {
75                local: _,
76                variable,
77                value: _,
78            } => {
79                let key = variable.real_value().to_lowercase();
80                // store complete sql as value.
81                self.stmts_cache.put(key, sql.to_owned());
82            }
83            _ => unreachable!(),
84        }
85    }
86}
87
88impl Iterator for SetStmtsIterator<'_, '_> {
89    type Item = String;
90
91    fn next(&mut self) -> Option<Self::Item> {
92        let (_, stmt) = self.stmts_iter.next()?;
93        Some(stmt.clone())
94    }
95}
96
97impl RisingWave {
98    pub async fn connect(
99        host: String,
100        dbname: String,
101    ) -> Result<Self, tokio_postgres::error::Error> {
102        let set_stmts = SetStmts::default();
103        let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
104        Ok(Self {
105            client,
106            task,
107            host,
108            dbname,
109            set_stmts,
110        })
111    }
112
113    pub async fn connect_inner(
114        host: &str,
115        dbname: &str,
116        set_stmts: &SetStmts,
117    ) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
118    {
119        let (client, connection) = tokio_postgres::Config::new()
120            .host(host)
121            .port(4566)
122            .dbname(dbname)
123            .user("root")
124            .connect_timeout(Duration::from_secs(5))
125            .connect(tokio_postgres::NoTls)
126            .await?;
127        let task = tokio::spawn(async move {
128            if let Err(e) = connection.await {
129                tracing::error!("postgres connection error: {e}");
130            }
131        });
132        // replay all SET statements
133        for stmt in SetStmtsIterator::new(set_stmts) {
134            client.simple_query(&stmt).await?;
135        }
136        Ok((client, task))
137    }
138
139    pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
140        let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
141        self.client = client;
142        self.task = task;
143        Ok(())
144    }
145
146    /// Returns a reference of the inner Postgres client.
147    pub fn pg_client(&self) -> &tokio_postgres::Client {
148        &self.client
149    }
150}
151
152impl Drop for RisingWave {
153    fn drop(&mut self) {
154        self.task.abort();
155    }
156}
157
158#[async_trait::async_trait]
159impl sqllogictest::AsyncDB for RisingWave {
160    type ColumnType = DefaultColumnType;
161    type Error = tokio_postgres::error::Error;
162
163    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
164        use sqllogictest::DBOutput;
165
166        if self.client.is_closed() {
167            // connection error, reset the client
168            self.reconnect().await?;
169        }
170
171        if sql.trim_start().to_lowercase().starts_with("set") {
172            self.set_stmts.push(sql);
173        }
174
175        let mut output = vec![];
176
177        let rows = self.client.simple_query(sql).await?;
178        let mut cnt = 0;
179        for row in rows {
180            let mut row_vec = vec![];
181            match row {
182                tokio_postgres::SimpleQueryMessage::Row(row) => {
183                    for i in 0..row.len() {
184                        match row.get(i) {
185                            Some(v) => {
186                                if v.is_empty() {
187                                    row_vec.push("(empty)".to_owned());
188                                } else {
189                                    row_vec.push(v.to_owned());
190                                }
191                            }
192                            None => row_vec.push("NULL".to_owned()),
193                        }
194                    }
195                }
196                tokio_postgres::SimpleQueryMessage::CommandComplete(cnt_) => {
197                    cnt = cnt_;
198                    break;
199                }
200                _ => unreachable!(),
201            }
202            output.push(row_vec);
203        }
204
205        if output.is_empty() {
206            Ok(DBOutput::StatementComplete(cnt))
207        } else {
208            Ok(DBOutput::Rows {
209                types: vec![DefaultColumnType::Any; output[0].len()],
210                rows: output,
211            })
212        }
213    }
214
215    async fn shutdown(&mut self) {}
216
217    fn engine_name(&self) -> &str {
218        "risingwave"
219    }
220
221    async fn sleep(dur: Duration) {
222        tokio::time::sleep(dur).await
223    }
224
225    async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
226        unimplemented!("spawning process is not supported in simulation mode")
227    }
228}