risingwave_test_runner/
test_runner.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0.
16use std::cell::RefCell;
17use std::env;
18
19use crate::test::*;
20/// general tests function
21pub fn run_general_test(cases: &[&TestDescAndFn]) {
22    run_test_inner(cases, TestGeneralStates)
23}
/// Stateless hook used for general tests.
///
/// It is cloned once per test case by the runner, hence the `Clone` derive.
#[derive(Clone)]
struct TestGeneralStates;
26
/// Lifecycle callbacks that the runner places around a single test case.
///
/// `TestWatcher` calls `setup` when it is constructed and `teardown` when it is dropped.
pub trait TestHook {
    /// Called once before the test body starts.
    fn setup(&mut self);

    /// Called once when the watcher guarding the test body is dropped.
    fn teardown(&mut self);
}
31
32impl TestHook for TestGeneralStates {
33    fn setup(&mut self) {}
34
35    fn teardown(&mut self) {}
36}
37
38struct TestWatcher<T: TestHook> {
39    name: String,
40    hook: T,
41}
42
43impl<H: TestHook + 'static> TestWatcher<H> {
44    fn new(name: String, mut hook: H) -> TestWatcher<H> {
45        println!("test is runner,{}", name);
46        hook.setup();
47        TestWatcher { name, hook }
48    }
49}
50
51impl<H: TestHook> Drop for TestWatcher<H> {
52    fn drop(&mut self) {
53        self.hook.teardown();
54        println!("test is drop,{}", self.name);
55    }
56}
57
58pub fn run_test_inner(cases: &[&TestDescAndFn], hook: impl TestHook + 'static + Clone + Send) {
59    let cases = cases
60        .iter()
61        .map(|case| {
62            let name = case.desc.name.as_slice().to_owned();
63            let h = hook.clone();
64            let f = match case.testfn {
65                TestFn::StaticTestFn(f) => TestFn::DynTestFn(Box::new(move || {
66                    let _watcher = TestWatcher::new(name, h);
67                    f()
68                })),
69                TestFn::StaticBenchFn(f) => TestFn::DynTestFn(Box::new(move || {
70                    let _watcher = TestWatcher::new(name, h);
71                    bench::run_once(f)
72                })),
73                ref f => panic!("unexpected testfn {:?}", f),
74            };
75            TestDescAndFn {
76                desc: case.desc.clone(),
77                testfn: f,
78            }
79        })
80        .collect();
81    let args = env::args().collect::<Vec<_>>();
82    test_main(&args, cases, None)
83}
84
85thread_local!(static FS: RefCell<Option<fail::FailScenario<'static>>> = const { RefCell::new(None) });
86#[derive(Clone)]
87struct FailPointHook;
88
89impl TestHook for FailPointHook {
90    fn setup(&mut self) {
91        FS.with(|s| {
92            s.borrow_mut().take();
93            *s.borrow_mut() = Some(fail::FailScenario::setup())
94        })
95    }
96
97    fn teardown(&mut self) {
98        FS.with(|s| {
99            s.borrow_mut().take();
100        })
101    }
102}
103
104#[derive(Clone)]
105struct SyncPointHook;
106
107impl TestHook for SyncPointHook {
108    fn setup(&mut self) {
109        sync_point::reset();
110    }
111
112    fn teardown(&mut self) {
113        sync_point::reset();
114    }
115}
116
117// End Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0.
118pub fn run_failpont_tests(cases: &[&TestDescAndFn]) {
119    let mut cases1 = vec![];
120    let mut cases2 = vec![];
121    let mut cases3 = vec![];
122    cases.iter().for_each(|case| {
123        if case.desc.name.as_slice().contains("test_syncpoints") {
124            // sync_point tests should specify #[serial], because sync_point lib doesn't implement
125            // an implicit global lock to order tests like fail-rs.
126            cases1.push(*case);
127        } else if case.desc.name.as_slice().contains("test_failpoints") {
128            cases2.push(*case);
129        } else {
130            cases3.push(*case);
131        }
132    });
133    if !cases1.is_empty() {
134        run_test_inner(cases1.as_slice(), SyncPointHook);
135    }
136    if !cases2.is_empty() {
137        run_test_inner(cases2.as_slice(), FailPointHook);
138    }
139    if !cases3.is_empty() {
140        run_general_test(cases3.as_slice());
141    }
142}