Skip to content

Commit cb153a5

Browse files
committed
Make Luau registered aliases case-insensitive
Executing `require("@my_module")` or `require("@My_Module")` should give the same result and use case-insensitive name. See #620 for details
1 parent b1c69d3 commit cb153a5

File tree

5 files changed

+33
-3
lines changed

5 files changed

+33
-3
lines changed

src/luau/require.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,20 @@ pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
567567
1
568568
}
569569

570-
let (error, r#type) = unsafe {
571-
lua.exec_raw::<(Function, Function)>((), move |state| {
570+
unsafe extern "C-unwind" fn to_lowercase(state: *mut ffi::lua_State) -> c_int {
571+
let s = ffi::luaL_checkstring(state, 1);
572+
let s = CStr::from_ptr(s);
573+
callback_error_ext(state, ptr::null_mut(), true, |extra, _| {
574+
let s = s.to_string_lossy().to_lowercase();
575+
(*extra).raw_lua().push(s).map(|_| 1)
576+
})
577+
}
578+
579+
let (error, r#type, to_lowercase) = unsafe {
580+
lua.exec_raw::<(Function, Function, Function)>((), move |state| {
572581
ffi::lua_pushcfunctiond(state, error, cstr!("error"));
573582
ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
583+
ffi::lua_pushcfunctiond(state, to_lowercase, cstr!("to_lowercase"));
574584
})
575585
}?;
576586

@@ -583,6 +593,7 @@ pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
583593
env.raw_set("LOADER_CACHE", loader_cache)?;
584594
env.raw_set("error", error)?;
585595
env.raw_set("type", r#type)?;
596+
env.raw_set("to_lowercase", to_lowercase)?;
586597

587598
lua.load(
588599
r#"
@@ -592,7 +603,7 @@ pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
592603
end
593604
594605
-- Check if the module (path) is explicitly registered
595-
local maybe_result = REGISTERED_MODULES[path]
606+
local maybe_result = REGISTERED_MODULES[to_lowercase(path)]
596607
if maybe_result ~= nil then
597608
return maybe_result
598609
end

src/state.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ impl Lua {
358358
if cfg!(feature = "luau") && !modname.starts_with('@') {
359359
return Err(Error::runtime("module name must begin with '@'"));
360360
}
361+
#[cfg(feature = "luau")]
362+
let modname = modname.to_lowercase();
361363
unsafe {
362364
self.exec_raw::<()>(value, |state| {
363365
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, LOADED_MODULES_KEY);

tests/luau/require.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ fn test_require_with_config() {
179179
let res = run_require(&lua, "./tests/luau/require/with_config/src/alias_requirer").unwrap();
180180
assert_eq!("result from dependency", get_str(&res, 1));
181181

182+
// RequirePathWithAlias (case-insensitive)
183+
let res2 = run_require(&lua, "./tests/luau/require/with_config/src/alias_requirer_uc").unwrap();
184+
assert_eq!("result from dependency", get_str(&res2, 1));
185+
assert_eq!(res.to_pointer(), res2.to_pointer());
186+
182187
// RequirePathWithParentAlias
183188
let res = run_require(&lua, "./tests/luau/require/with_config/src/parent_alias_requirer").unwrap();
184189
assert_eq!("result from other_dependency", get_str(&res, 1));
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
return require("@DeP")

tests/tests.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,17 @@ fn test_register_module() -> Result<()> {
12421242
res.unwrap_err().to_string(),
12431243
"runtime error: module name must begin with '@'"
12441244
);
1245+
1246+
// Luau registered modules (aliases) are case-insensitive
1247+
let res = lua.register_module("@My_Module", &t);
1248+
assert!(res.is_ok());
1249+
lua.load(
1250+
r#"
1251+
local my_module = require("@MY_MODule")
1252+
assert(my_module.name == "my_module")
1253+
"#,
1254+
)
1255+
.exec()?;
12451256
}
12461257

12471258
Ok(())

0 commit comments

Comments
 (0)