Skip to content

Commit bbbe5a0

Browse files
committed
Add simple testing framework
1 parent 7c00ec7 commit bbbe5a0

File tree

10 files changed

+504
-120
lines changed

10 files changed

+504
-120
lines changed

Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,4 @@ luau = ["mlua/luau"]
1616

1717
[dependencies]
1818
mlua = { version = "0.11" }
19-
20-
[dev-dependencies]
21-
mlua = { version = "0.11", features = ["macros"] }
19+
owo-colors = "4"

lua/assertions.lua

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
local opts = ... or {}
12
local assertions = {}
23

34
function assertions.assert_eq(left, right, message)
@@ -7,7 +8,8 @@ function assertions.assert_eq(left, right, message)
78
else
89
message = "assertion `left == right` failed!"
910
end
10-
error(string.format("%s\n left: %s\n right: %s", message, tostring(left), tostring(right)))
11+
local frame_level = opts.level or 2
12+
error(string.format("%s\n left: %s\n right: %s", message, tostring(left), tostring(right)), frame_level)
1113
end
1214
end
1315

@@ -18,7 +20,8 @@ function assertions.assert_ne(left, right, message)
1820
else
1921
message = "assertion `left ~= right` failed!"
2022
end
21-
error(string.format("%s\n left: %s\n right: %s", message, tostring(left), tostring(right)))
23+
local frame_level = opts.level or 2
24+
error(string.format("%s\n left: %s\n right: %s", message, tostring(left), tostring(right)), frame_level)
2225
end
2326
end
2427

@@ -80,9 +83,10 @@ function assertions.assert_same(left, right, message)
8083
else
8184
message = "assertion `left ~ right` failed!"
8285
end
83-
error(
86+
local error_msg =
8487
string.format("%s\n left%s: %s\n right%s: %s", message, level, tostring(left_v), level, tostring(right_v))
85-
)
88+
local frame_level = opts.level or 2
89+
error(error_msg, frame_level)
8690
end
8791
end
8892

lua/testing.lua

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
local deps = ...
2+
local assertions = deps.assertions
3+
4+
local println, style = deps.println, deps.style
5+
local instant = deps.instant
6+
7+
local Testing = {}
8+
Testing.__index = Testing
9+
10+
function Testing.new(name)
11+
local self = setmetatable({}, Testing)
12+
self._name = name
13+
self._tests = {}
14+
self._hooks = {
15+
before_all = {},
16+
after_all = {},
17+
before_each = {},
18+
after_each = {},
19+
}
20+
self._results = {}
21+
return self
22+
end
23+
24+
-- Test context passed to each test function
25+
local TestContext = {}
26+
TestContext.__index = TestContext
27+
28+
function TestContext.new(name)
29+
local self = setmetatable({}, TestContext)
30+
self.name = name
31+
return self
32+
end
33+
34+
-- Forward assertion methods to the assertions module
35+
function TestContext.assert_eq(a, b, msg)
36+
assertions.assert_eq(a, b, msg)
37+
end
38+
39+
function TestContext.assert_same(a, b, msg)
40+
assertions.assert_same(a, b, msg)
41+
end
42+
43+
function TestContext.assert(cond, msg)
44+
if not cond then
45+
if msg ~= nil then
46+
error("assertion failed: " .. tostring(msg), 2)
47+
else
48+
error("assertion failed!", 2)
49+
end
50+
end
51+
end
52+
53+
-- Add convenience methods
54+
function TestContext.skip(reason)
55+
error("__SKIP__: " .. (reason or "skipped"), 0)
56+
end
57+
58+
-- Hooks registration
59+
function Testing:before_all(func)
60+
table.insert(self._hooks.before_all, func)
61+
end
62+
63+
function Testing:after_all(func)
64+
table.insert(self._hooks.after_all, func)
65+
end
66+
67+
function Testing:before_each(func)
68+
table.insert(self._hooks.before_each, func)
69+
end
70+
71+
function Testing:after_each(func)
72+
table.insert(self._hooks.after_each, func)
73+
end
74+
75+
-- Tests registration
76+
function Testing:test(name, func)
77+
table.insert(self._tests, { name = name, func = func })
78+
end
79+
80+
-- Run a single test
81+
function Testing:_run_single_test(test)
82+
local ctx = TestContext.new(test.name)
83+
local start_time = instant()
84+
local success, err = true, nil
85+
86+
-- Run before_each hooks
87+
for _, func in ipairs(self._hooks.before_each) do
88+
local ok, hook_err = pcall(func)
89+
if not ok then
90+
return {
91+
name = test.name,
92+
passed = false,
93+
skipped = false,
94+
error = "before_each failed: " .. tostring(hook_err),
95+
duration = start_time:elapsed(),
96+
}
97+
end
98+
end
99+
100+
-- Run the test
101+
local test_ok, test_err = pcall(test.func, ctx)
102+
if not test_ok then
103+
if test_err and test_err:match("^__SKIP__:") then
104+
success, err = "skip", test_err:match("^__SKIP__: (.*)")
105+
else
106+
success, err = false, test_err
107+
end
108+
end
109+
110+
-- Run after_each hooks (even if test failed)
111+
for _, func in ipairs(self._hooks.after_each) do
112+
local ok, hook_err = pcall(func)
113+
if not ok then
114+
return {
115+
name = test.name,
116+
passed = false,
117+
skipped = false,
118+
error = "after_each failed: " .. tostring(hook_err),
119+
duration = start_time:elapsed(),
120+
}
121+
end
122+
end
123+
124+
return {
125+
name = test.name,
126+
passed = success == true,
127+
skipped = success == "skip",
128+
error = err,
129+
duration = start_time:elapsed(),
130+
}
131+
end
132+
133+
-- Run all tests
134+
function Testing:run(opts)
135+
opts = opts or {}
136+
local pattern = opts.pattern
137+
self._results = {}
138+
local start_time = instant()
139+
140+
-- Run before_all hooks
141+
for _, func in ipairs(self._hooks.before_all) do
142+
func()
143+
end
144+
145+
-- Run tests
146+
for _, test in ipairs(self._tests) do
147+
if not pattern or test.name:find(pattern) then
148+
local result = self:_run_single_test(test)
149+
table.insert(self._results, result)
150+
end
151+
end
152+
153+
-- Run after_all hooks
154+
for _, func in ipairs(self._hooks.after_all) do
155+
func()
156+
end
157+
158+
self._results.duration = start_time:elapsed()
159+
160+
-- Print results unless quiet
161+
if not opts.quiet then
162+
self:_print_results()
163+
end
164+
165+
-- Return success status
166+
local failed = 0
167+
for _, result in ipairs(self._results) do
168+
if not result.passed and not result.skipped then
169+
failed = failed + 1
170+
end
171+
end
172+
173+
return failed == 0, self._results
174+
end
175+
176+
function Testing:_print_results()
177+
local passed, failed, skipped = 0, 0, 0
178+
179+
for _, result in ipairs(self._results) do
180+
local status = style(result.passed and "" or (result.skipped and "" or ""))
181+
status:color(result.passed and "green" or (result.skipped and "yellow" or "red"))
182+
183+
println(status, result.name)
184+
if result.error then
185+
println(tostring(result.error))
186+
end
187+
188+
if result.passed then
189+
passed = passed + 1
190+
elseif result.skipped then
191+
skipped = skipped + 1
192+
else
193+
failed = failed + 1
194+
end
195+
end
196+
197+
local total = passed + failed + skipped
198+
if total == 0 then
199+
-- No tests were run
200+
return
201+
end
202+
203+
local prefix = "test results:"
204+
if self._name then
205+
prefix = string.format("`%s` %s", self._name, prefix)
206+
end
207+
local duration = self._results.duration
208+
local stats = string.format(
209+
"%d passed, %d failed, %d skipped (%d total finished in %s)",
210+
passed,
211+
failed,
212+
skipped,
213+
total,
214+
duration
215+
)
216+
println()
217+
println(prefix, stats)
218+
end
219+
220+
-- Get results for the Rust integration
221+
function Testing:results()
222+
return self._results
223+
end
224+
225+
return Testing

src/assertions.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use mlua::{Lua, Result, Table};
22

33
/// A loader for the `assertions` module.
4-
pub fn loader(lua: &Lua, name: String) -> Result<Table> {
4+
fn loader(lua: &Lua) -> Result<Table> {
55
lua.load(include_str!("../lua/assertions.lua"))
6-
.set_name(format!("={name}"))
6+
.set_name(format!("@mlua-stdlib/assertions.lua"))
77
.call(())
88
}
99

1010
/// Registers the `assertions` module in the given Lua state.
11-
pub fn register(lua: &Lua, name: Option<&str>) -> Result<()> {
11+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
1212
let name = name.unwrap_or("@assertions");
13-
lua.register_module(name, loader(lua, name.to_string())?)
13+
let value = loader(lua)?;
14+
lua.register_module(name, &value)?;
15+
Ok(value)
1416
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
pub(crate) mod terminal;
2+
pub(crate) mod time;
3+
14
pub mod assertions;
5+
pub mod testing;

src/terminal.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use std::fmt;
2+
3+
use mlua::{
4+
AnyUserData, Lua, MetaMethod, MultiValue, Result, UserData, UserDataMethods, UserDataRegistry, Value,
5+
};
6+
use owo_colors::{AnsiColors, DynColor};
7+
8+
pub(crate) struct Style {
9+
text: String,
10+
style: owo_colors::Style,
11+
}
12+
13+
impl UserData for Style {
14+
fn register(registry: &mut UserDataRegistry<Self>) {
15+
// Sets the color for the text
16+
registry.add_function("color", |_, (ud, color): (AnyUserData, String)| {
17+
let mut this = ud.borrow_mut::<Self>()?;
18+
this.style = this.style.color(str2color(&color));
19+
Ok(ud)
20+
});
21+
22+
// Sets the background color for the text
23+
registry.add_function("on", |_, (ud, color): (AnyUserData, String)| {
24+
let mut this = ud.borrow_mut::<Self>()?;
25+
this.style = this.style.on_color(str2color(&color));
26+
Ok(ud)
27+
});
28+
29+
registry.add_meta_method(MetaMethod::ToString, |_, this, ()| Ok(format!("{this}")));
30+
}
31+
}
32+
33+
impl fmt::Display for Style {
34+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35+
self.style.style(&self.text).fmt(f)
36+
}
37+
}
38+
39+
pub(crate) fn style(_: &Lua, value: Value) -> Result<Style> {
40+
let text = value.to_string()?;
41+
Ok(Style {
42+
text,
43+
style: owo_colors::Style::new(),
44+
})
45+
}
46+
47+
pub(crate) fn print(_: &Lua, values: MultiValue) -> Result<()> {
48+
let mut first = true;
49+
for value in values {
50+
if !first {
51+
print!(" ");
52+
}
53+
first = false;
54+
print!("{}", value.to_string()?);
55+
}
56+
Ok(())
57+
}
58+
59+
pub(crate) fn println(lua: &Lua, values: MultiValue) -> Result<()> {
60+
print(lua, values)?;
61+
println!();
62+
Ok(())
63+
}
64+
65+
fn str2color(s: &str) -> impl DynColor {
66+
match s.to_ascii_lowercase().as_str() {
67+
"black" => AnsiColors::Black,
68+
"red" => AnsiColors::Red,
69+
"green" => AnsiColors::Green,
70+
"yellow" => AnsiColors::Yellow,
71+
"blue" => AnsiColors::Blue,
72+
"magenta" => AnsiColors::Magenta,
73+
"cyan" => AnsiColors::Cyan,
74+
"white" => AnsiColors::White,
75+
"bright_black" => AnsiColors::BrightBlack,
76+
"bright_red" => AnsiColors::BrightRed,
77+
"bright_green" => AnsiColors::BrightGreen,
78+
"bright_yellow" => AnsiColors::BrightYellow,
79+
"bright_blue" => AnsiColors::BrightBlue,
80+
"bright_magenta" => AnsiColors::BrightMagenta,
81+
"bright_cyan" => AnsiColors::BrightCyan,
82+
"bright_white" => AnsiColors::BrightWhite,
83+
_ => AnsiColors::Default,
84+
}
85+
}

0 commit comments

Comments
 (0)