Skip to content

Commit 60dd177

Browse files
committed
Add regex module
1 parent 1d578bd commit 60dd177

File tree

5 files changed

+306
-0
lines changed

5 files changed

+306
-0
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ send = ["mlua/send"]
1717
vendored = ["mlua/vendored"]
1818

1919
json = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_json"]
20+
regex = ["dep:regex", "dep:ouroboros", "dep:quick_cache"]
2021

2122
[dependencies]
2223
mlua = { version = "0.11" }
2324
ouroboros = { version = "0.18", optional = true }
2425
serde = { version = "1.0", optional = true }
2526
serde_json = { version = "1.0", optional = true }
2627
owo-colors = "4"
28+
regex = { version = "1.0", optional = true }
29+
quick_cache = { version = "0.6", optional = true }

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ pub mod testing;
1515

1616
#[cfg(feature = "json")]
1717
pub mod json;
18+
#[cfg(feature = "regex")]
19+
pub mod regex;

src/regex.rs

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
use std::ops::Deref;
2+
use std::result::Result as StdResult;
3+
use std::sync::LazyLock;
4+
5+
use mlua::{Lua, MetaMethod, Result, String as LuaString, Table, UserData, UserDataMethods, Value, Variadic};
6+
use ouroboros::self_referencing;
7+
use quick_cache::sync::Cache;
8+
9+
// A reasonable cache size for regexes. This can be adjusted as needed.
// The value is handed to `Cache::new` when the global `CACHE` is first initialised.
10+
const REGEX_CACHE_SIZE: usize = 256;
11+
12+
/// A compiled regular expression operating on byte strings.
///
/// Newtype around [`regex::bytes::Regex`]; it is exposed to Lua as userdata and
/// dereferences to the wrapped regex.
#[derive(Clone, Debug)]
13+
pub struct Regex(regex::bytes::Regex);
14+
15+
impl Deref for Regex {
16+
type Target = regex::bytes::Regex;
17+
18+
#[inline]
19+
fn deref(&self) -> &Self::Target {
20+
&self.0
21+
}
22+
}
23+
24+
// Global cache for regexes shared across all Lua states.
// Keys are the pattern source strings; values are the compiled regexes.
// Created lazily on first use with `REGEX_CACHE_SIZE` as its capacity argument.
25+
static CACHE: LazyLock<Cache<String, Regex>> = LazyLock::new(|| Cache::new(REGEX_CACHE_SIZE));
26+
27+
impl Regex {
28+
/// Creates a new cached regex or retrieves it from the cache if it already exists.
29+
pub fn new(_: &Lua, re: &str) -> StdResult<Self, regex::Error> {
30+
if let Some(re) = CACHE.get(re) {
31+
return Ok(re);
32+
}
33+
let regex = Self(regex::bytes::Regex::new(&re)?);
34+
CACHE.insert(re.to_string(), regex.clone());
35+
Ok(regex)
36+
}
37+
}
38+
39+
impl UserData for Regex {
40+
fn register(registry: &mut mlua::UserDataRegistry<Self>) {
41+
registry.add_method("is_match", |_, this, text: LuaString| {
42+
Ok(this.0.is_match(&text.as_bytes()))
43+
});
44+
45+
registry.add_method("match", |lua, this, text: LuaString| {
46+
let text = (*text.as_bytes()).into();
47+
let caps = Captures::try_new(text, |text| this.0.captures(text).ok_or(()));
48+
match caps {
49+
Ok(caps) => Ok(Value::UserData(lua.create_userdata(caps)?)),
50+
Err(_) => Ok(Value::Nil),
51+
}
52+
});
53+
54+
// Returns low level information about raw offsets of each submatch.
55+
registry.add_method("captures_read", |lua, this, text: LuaString| {
56+
let mut locs = this.capture_locations();
57+
match this.captures_read(&mut locs, &text.as_bytes()) {
58+
Some(_) => Ok(Value::UserData(lua.create_userdata(CaptureLocations(locs))?)),
59+
None => Ok(Value::Nil),
60+
}
61+
});
62+
63+
registry.add_method("split", |lua, this, text: LuaString| {
64+
lua.create_sequence_from(this.split(&text.as_bytes()).map(LuaString::wrap))
65+
});
66+
67+
registry.add_method("splitn", |lua, this, (text, limit): (LuaString, usize)| {
68+
lua.create_sequence_from(this.splitn(&text.as_bytes(), limit).map(LuaString::wrap))
69+
});
70+
71+
registry.add_method("replace", |lua, this, (text, rep): (LuaString, LuaString)| {
72+
lua.create_string(this.replace(&text.as_bytes(), &*rep.as_bytes()))
73+
});
74+
}
75+
}
76+
77+
// Capture groups of a single match, stored together with the text they were found in.
// `caps` borrows from `text`, so the struct is self-referential and its constructor and
// accessors (`try_new`, `borrow_caps`) are generated by `ouroboros`.
#[self_referencing]
78+
struct Captures {
79+
// Owned copy of the searched bytes.
text: Box<[u8]>,
80+
81+
// Capture groups pointing into `text`.
#[borrows(text)]
82+
#[covariant]
83+
caps: regex::bytes::Captures<'this>,
84+
}
85+
86+
impl UserData for Captures {
87+
fn register(registry: &mut mlua::UserDataRegistry<Self>) {
88+
registry.add_meta_method(MetaMethod::Index, |lua, this, key: Value| match key {
89+
Value::String(s) => {
90+
let name = s.to_string_lossy();
91+
this.borrow_caps()
92+
.name(&name)
93+
.map(|v| lua.create_string(v.as_bytes()))
94+
.transpose()
95+
}
96+
Value::Integer(i) => this
97+
.borrow_caps()
98+
.get(i as usize)
99+
.map(|v| lua.create_string(v.as_bytes()))
100+
.transpose(),
101+
_ => Ok(None),
102+
})
103+
}
104+
}
105+
106+
/// Userdata wrapper around [`regex::bytes::CaptureLocations`], the raw byte offsets of each
/// capture group as filled in by `captures_read`.
struct CaptureLocations(regex::bytes::CaptureLocations);
107+
108+
impl UserData for CaptureLocations {
109+
fn register(registry: &mut mlua::UserDataRegistry<Self>) {
110+
// Returns the total number of capture groups.
111+
registry.add_method("len", |_, this, ()| Ok(this.0.len()));
112+
113+
// Returns the start and end positions of the Nth capture group.
114+
registry.add_method("get", |_, this, i: usize| match this.0.get(i) {
115+
// We add 1 to the start position because Lua is 1-indexed.
116+
// End position is non-inclusive, so we don't need to add 1.
117+
Some((start, end)) => Ok(Variadic::from_iter([start + 1, end])),
118+
None => Ok(Variadic::new()),
119+
});
120+
}
121+
}
122+
123+
/// Userdata wrapper around [`regex::bytes::RegexSet`], a set of patterns matched together
/// against one haystack.
struct RegexSet(regex::bytes::RegexSet);
124+
125+
impl Deref for RegexSet {
126+
type Target = regex::bytes::RegexSet;
127+
128+
#[inline]
129+
fn deref(&self) -> &Self::Target {
130+
&self.0
131+
}
132+
}
133+
134+
impl UserData for RegexSet {
135+
fn register(registry: &mut mlua::UserDataRegistry<Self>) {
136+
registry.add_function("new", |_, patterns: Vec<String>| {
137+
let set = lua_try!(regex::bytes::RegexSet::new(patterns).map(RegexSet));
138+
Ok(Ok(set))
139+
});
140+
141+
registry.add_method("is_match", |_, this, text: LuaString| {
142+
Ok(this.is_match(&text.as_bytes()))
143+
});
144+
145+
registry.add_method("len", |_, this, ()| Ok(this.len()));
146+
147+
registry.add_method("matches", |_, this, text: LuaString| {
148+
Ok(this
149+
.matches(&text.as_bytes())
150+
.iter()
151+
.map(|i| i + 1)
152+
.collect::<Vec<_>>())
153+
});
154+
}
155+
}
156+
157+
/// Compiles a regular expression.
158+
///
159+
/// Once compiled, it can be used repeatedly to search, split or replace substrings in a text.
160+
fn regex_new(lua: &Lua, re: LuaString) -> Result<StdResult<Regex, String>> {
161+
let re = re.to_str()?;
162+
Ok(Ok(lua_try!(Regex::new(lua, &re))))
163+
}
164+
165+
/// Escapes a string so that it can be used as a literal in a regular expression.
166+
fn regex_escape(_: &Lua, text: LuaString) -> Result<String> {
167+
Ok(regex::escape(&text.to_str()?))
168+
}
169+
170+
/// Returns true if there is a match for the regex anywhere in the given text.
171+
fn regex_is_match(lua: &Lua, (re, text): (LuaString, LuaString)) -> Result<StdResult<bool, String>> {
172+
let re = re.to_str()?;
173+
let re = lua_try!(Regex::new(lua, &re));
174+
Ok(Ok(re.is_match(&text.as_bytes())))
175+
}
176+
177+
/// Returns all matches of the regex in the given text or nil if there is no match.
178+
fn regex_match(lua: &Lua, (re, text): (LuaString, LuaString)) -> Result<StdResult<Value, String>> {
179+
let re = re.to_str()?;
180+
let re = lua_try!(Regex::new(lua, &re));
181+
match re.captures(&text.as_bytes()) {
182+
Some(caps) => {
183+
let mut it = caps.iter().map(|om| om.map(|m| LuaString::wrap(m.as_bytes())));
184+
let first = it.next().unwrap();
185+
let table = lua.create_sequence_from(it)?;
186+
table.raw_set(0, first)?;
187+
Ok(Ok(Value::Table(table)))
188+
}
189+
None => Ok(Ok(Value::Nil)),
190+
}
191+
}
192+
193+
/// A loader for the `regex` module.
194+
fn loader(lua: &Lua) -> Result<Table> {
195+
let t = lua.create_table()?;
196+
t.set("new", lua.create_function(regex_new)?)?;
197+
t.set("escape", lua.create_function(regex_escape)?)?;
198+
t.set("is_match", lua.create_function(regex_is_match)?)?;
199+
t.set("match", lua.create_function(regex_match)?)?;
200+
t.set("RegexSet", lua.create_proxy::<RegexSet>()?)?;
201+
Ok(t)
202+
}
203+
204+
/// Registers the `regex` module in the given Lua state.
205+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
206+
let name = name.unwrap_or("@regex");
207+
let value = loader(lua)?;
208+
lua.register_module(name, &value)?;
209+
Ok(value)
210+
}

tests/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ fn run_file(modname: &str) -> Result<()> {
1313

1414
#[cfg(feature = "json")]
1515
mlua_stdlib::json::register(&lua, None)?;
16+
#[cfg(feature = "regex")]
17+
mlua_stdlib::regex::register(&lua, None)?;
1618

1719
// Add `testing` global variable (an instance of the testing framework)
1820
let testing = testing.call_function::<Table>("new", modname)?;
@@ -50,3 +52,5 @@ include_tests! {
5052

5153
#[cfg(feature = "json")]
5254
include_tests!(json);
55+
#[cfg(feature = "regex")]
56+
include_tests!(regex);

tests/lua/regex_tests.lua

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
local regex = require("@regex")
2+
3+
-- Exercises the compiled-regex object: is_match, match (numeric, named and invalid keys),
-- split, splitn, the error returned for an invalid pattern, and replace.
testing:test("regex_basic", function(t)
    local pattern = regex.new(".*(?P<gr1>abc)")

    t.assert(pattern:is_match("123abc321"), "is_match() should have matches")
    t.assert(not pattern:is_match("123"), "is_match() should not have matches")

    local caps = pattern:match("123abc321")
    t.assert_eq(caps[0], "123abc", "zero capture group should match the whole text")
    t.assert_eq(caps[1], "abc", "first capture group should match `abc`")
    t.assert_eq(caps["gr1"], "abc", "named capture group should match `abc`")
    t.assert_eq(caps[true], nil, "bad key should have no match")

    -- split on every separator
    local separator = regex.new("[,.]")
    local parts = separator:split("abc.qwe,rty.asd")
    t.assert_eq(#parts, 4, "vec len should be 4")
    t.assert(
        parts[1] == "abc" and parts[2] == "qwe" and parts[3] == "rty" and parts[4] == "asd",
        "vec must be 'abc', 'qwe', 'rty', 'asd'"
    )

    -- split with a limit on the number of pieces
    parts = separator:splitn("abc,bcd,cde", 2)
    t.assert_eq(#parts, 2, "vec len should be 2")
    t.assert(parts[1] == "abc" and parts[2] == "bcd,cde", "vec must be 'abc', 'bcd,cde'")

    -- an invalid pattern yields nil plus an error message
    local broken, message = regex.new("(")
    t.assert_eq(broken, nil, "re is not nil")
    t.assert(string.find(message, "regex parse error"), "err must contain 'regex parse error'")

    -- replace with named group references
    local name_swap = regex.new("(?P<last>[^,\\s]+),\\s+(?P<first>\\S+)")
    local swapped = name_swap:replace("Smith, John", "$first $last")
    t.assert_eq(swapped, "John Smith", "str must be 'John Smith'")
end)
39+
40+
-- Exercises the module-level shortcuts that take the pattern as a string:
-- escape, is_match and match, including their error results for an invalid pattern.
testing:test("regex_shortcuts", function(t)
    -- escape
    t.assert_eq(regex.escape("a*b"), "a\\*b", "escaped regex must be 'a\\*b'")

    -- is_match
    local found = regex.is_match("\\b\\w{13}\\b", "I categorically deny having ...")
    t.assert(found, "is_match should have matches")
    local missing = regex.is_match("abc", "bca")
    t.assert(not missing, "is_match should not have matches")
    local result, message = regex.is_match("(", "")
    t.assert(result == nil and string.find(message, "regex parse error") ~= nil, "is_match should return error")

    -- match
    local groups = regex.match("^(\\d{4})-(\\d{2})-(\\d{2})$", "2014-05-01")
    t.assert_eq(groups[0], "2014-05-01", "zero capture group should match the whole text")
    t.assert_eq(groups[1], "2014", "first capture group should match year")
    t.assert_eq(groups[2], "05", "second capture group should match month")
    t.assert_eq(groups[3], "01", "third capture group should match day")
    groups, message = regex.match("(", "")
    t.assert(groups == nil and string.find(message, "regex parse error") ~= nil, "match should return error")
end)
60+
61+
-- Exercises RegexSet: construction from a pattern list, len, is_match and the
-- 1-based indices returned by matches.
testing:test("regex_set", function(t)
    local patterns = { "\\w+", "\\d+", "\\pL+", "foo", "bar", "barfoo", "foobar" }
    local set = regex.RegexSet.new(patterns)

    t.assert_eq(set:len(), 7, "len should be 7")
    t.assert(set:is_match("foobar"), "is_match should have matches")

    local matched = table.concat(set:matches("foobar"), ",")
    t.assert_eq(matched, "1,3,4,5,7", "matches should return 1,3,4,5,7")
end)
68+
69+
-- Test capture locations: `captures_read` returns the raw offsets of each group,
-- or nil when the regex does not match.
testing:test("capture_locations", function(t)
    local re = regex.new("\\d+(abc)\\d+")

    local str = "123abc321"
    local locs = re:captures_read(str)
    t.assert(locs, "locs is nil")
    t.assert_eq(locs:len(), 2, "locs len is not 2")
    -- Offsets are 1-based with an inclusive end, so they can be passed straight to `str:sub`.
    local i, j = locs:get(0)
    t.assert(i == 1 and j == 9, "locs:get(0) is not 1, 9")
    i, j = locs:get(1)
    t.assert(i == 4 and j == 6, "locs:get(1) is not 4, 6")
    t.assert_eq(str:sub(i, j), "abc", "str:sub(i, j) is not 'abc'")
    -- Out-of-range group: the message describes the failure, like the other assertions here.
    t.assert_eq(locs:get(2), nil, "locs:get(2) is not nil")

    -- Test no match
    locs = re:captures_read("123")
    t.assert_eq(locs, nil, "locs is not nil")
end)

0 commit comments

Comments
 (0)