Skip to content

Commit 916cbbd

Browse files
committed
Improve metallib fallback discovery
1 parent d17258e commit 916cbbd

File tree

1 file changed

+76
-27
lines changed

1 file changed

+76
-27
lines changed

.github/workflows/ci.yml

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ jobs:
217217
218218
search_dirs: list[pathlib.Path] = []
219219
package_dir: Optional[pathlib.Path] = None
220+
package_paths: list[pathlib.Path] = []
221+
222+
package_file = getattr(mlx, "__file__", None)
223+
if package_file:
224+
try:
225+
package_paths.append(pathlib.Path(package_file).resolve().parent)
226+
except (TypeError, OSError):
227+
pass
228+
229+
package_path_attr = getattr(mlx, "__path__", None)
230+
if package_path_attr:
231+
for entry in package_path_attr:
232+
try:
233+
package_paths.append(pathlib.Path(entry).resolve())
234+
except (TypeError, OSError):
235+
continue
220236
221237
try:
222238
spec = importlib.util.find_spec("mlx.backend.metal.kernels")
@@ -227,31 +243,56 @@ jobs:
227243
candidate = pathlib.Path(spec.origin).resolve().parent
228244
if candidate.exists():
229245
search_dirs.append(candidate)
246+
package_paths.append(candidate)
247+
248+
def append_resource_directory(module: str, *subpath: str) -> None:
249+
try:
250+
traversable = resources.files(module)
251+
except (ModuleNotFoundError, AttributeError):
252+
return
253+
254+
for segment in subpath:
255+
traversable = traversable / segment
230256
231-
try:
232-
resource = resources.files("mlx.backend.metal") / "kernels"
233-
except (ModuleNotFoundError, AttributeError):
234-
resource = None
235-
else:
236257
try:
237-
with resources.as_file(resource) as extracted:
258+
with resources.as_file(traversable) as extracted:
238259
if extracted:
239260
extracted_path = pathlib.Path(extracted).resolve()
240261
if extracted_path.exists():
241262
search_dirs.append(extracted_path)
263+
package_paths.append(extracted_path)
242264
except (FileNotFoundError, RuntimeError):
243265
pass
244266
245-
package_file = getattr(mlx, "__file__", None)
246-
if package_file:
247-
package_dir = pathlib.Path(package_file).resolve().parent
248-
search_dirs.extend(
249-
[
250-
package_dir / "backend" / "metal" / "kernels",
251-
package_dir / "backend" / "metal",
252-
package_dir,
253-
]
254-
)
267+
append_resource_directory("mlx.backend.metal", "kernels")
268+
append_resource_directory("mlx")
269+
270+
existing_package_paths: list[pathlib.Path] = []
271+
seen_package_paths: set[pathlib.Path] = set()
272+
for path in package_paths:
273+
if not path:
274+
continue
275+
try:
276+
resolved = path.resolve()
277+
except (OSError, RuntimeError):
278+
continue
279+
if not resolved.exists():
280+
continue
281+
if resolved in seen_package_paths:
282+
continue
283+
seen_package_paths.add(resolved)
284+
existing_package_paths.append(resolved)
285+
286+
if existing_package_paths:
287+
package_dir = existing_package_paths[0]
288+
for root in existing_package_paths:
289+
search_dirs.extend(
290+
[
291+
root / "backend" / "metal" / "kernels",
292+
root / "backend" / "metal",
293+
root,
294+
]
295+
)
255296
256297
ordered_dirs: list[pathlib.Path] = []
257298
seen: set[pathlib.Path] = set()
@@ -277,17 +318,25 @@ jobs:
277318
278319
src = next(iter_metallibs(ordered_dirs), None)
279320
280-
if src is None and package_dir and package_dir.exists():
281-
for candidate in package_dir.rglob("mlx.metallib"):
282-
src = candidate
283-
print(f"::warning::Resolved metallib via recursive search under {package_dir}")
284-
break
285-
286-
if src is None and package_dir and package_dir.exists():
287-
for candidate in sorted(package_dir.rglob("*.metallib")):
288-
src = candidate
289-
print(f"::warning::Using metallib {candidate.name} discovered via package-wide search")
290-
break
321+
package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir])
322+
323+
if src is None:
324+
for root in package_roots:
325+
for candidate in root.rglob("mlx.metallib"):
326+
src = candidate
327+
print(f"::warning::Resolved metallib via recursive search under {root}")
328+
break
329+
if src is not None:
330+
break
331+
332+
if src is None:
333+
for root in package_roots:
334+
for candidate in sorted(root.rglob("*.metallib")):
335+
src = candidate
336+
print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}")
337+
break
338+
if src is not None:
339+
break
291340
292341
if src is None:
293342
print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.")

0 commit comments

Comments
 (0)