@@ -217,6 +217,22 @@ jobs:
217217
218218 search_dirs: list[pathlib.Path] = []
219219 package_dir: Optional[pathlib.Path] = None
220+ package_paths: list[pathlib.Path] = []
221+
222+ package_file = getattr(mlx, "__file__", None)
223+ if package_file:
224+ try:
225+ package_paths.append(pathlib.Path(package_file).resolve().parent)
226+ except (TypeError, OSError):
227+ pass
228+
229+ package_path_attr = getattr(mlx, "__path__", None)
230+ if package_path_attr:
231+ for entry in package_path_attr:
232+ try:
233+ package_paths.append(pathlib.Path(entry).resolve())
234+ except (TypeError, OSError):
235+ continue
220236
221237 try:
222238 spec = importlib.util.find_spec("mlx.backend.metal.kernels")
@@ -227,31 +243,56 @@ jobs:
227243 candidate = pathlib.Path(spec.origin).resolve().parent
228244 if candidate.exists():
229245 search_dirs.append(candidate)
246+ package_paths.append(candidate)
247+
248+ def append_resource_directory(module: str, *subpath: str) -> None:
249+ try:
250+ traversable = resources.files(module)
251+ except (ModuleNotFoundError, AttributeError):
252+ return
253+
254+ for segment in subpath:
255+ traversable = traversable / segment
230256
231- try:
232- resource = resources.files("mlx.backend.metal") / "kernels"
233- except (ModuleNotFoundError, AttributeError):
234- resource = None
235- else:
236257 try:
237- with resources.as_file(resource ) as extracted:
258+ with resources.as_file(traversable ) as extracted:
238259 if extracted:
239260 extracted_path = pathlib.Path(extracted).resolve()
240261 if extracted_path.exists():
241262 search_dirs.append(extracted_path)
263+ package_paths.append(extracted_path)
242264 except (FileNotFoundError, RuntimeError):
243265 pass
244266
245- package_file = getattr(mlx, "__file__", None)
246- if package_file:
247- package_dir = pathlib.Path(package_file).resolve().parent
248- search_dirs.extend(
249- [
250- package_dir / "backend" / "metal" / "kernels",
251- package_dir / "backend" / "metal",
252- package_dir,
253- ]
254- )
267+ append_resource_directory("mlx.backend.metal", "kernels")
268+ append_resource_directory("mlx")
269+
270+ existing_package_paths: list[pathlib.Path] = []
271+ seen_package_paths: set[pathlib.Path] = set()
272+ for path in package_paths:
273+ if not path:
274+ continue
275+ try:
276+ resolved = path.resolve()
277+ except (OSError, RuntimeError):
278+ continue
279+ if not resolved.exists():
280+ continue
281+ if resolved in seen_package_paths:
282+ continue
283+ seen_package_paths.add(resolved)
284+ existing_package_paths.append(resolved)
285+
286+ if existing_package_paths:
287+ package_dir = existing_package_paths[0]
288+ for root in existing_package_paths:
289+ search_dirs.extend(
290+ [
291+ root / "backend" / "metal" / "kernels",
292+ root / "backend" / "metal",
293+ root,
294+ ]
295+ )
255296
256297 ordered_dirs: list[pathlib.Path] = []
257298 seen: set[pathlib.Path] = set()
@@ -277,17 +318,25 @@ jobs:
277318
278319 src = next(iter_metallibs(ordered_dirs), None)
279320
280- if src is None and package_dir and package_dir.exists():
281- for candidate in package_dir.rglob("mlx.metallib"):
282- src = candidate
283- print(f"::warning::Resolved metallib via recursive search under {package_dir}")
284- break
285-
286- if src is None and package_dir and package_dir.exists():
287- for candidate in sorted(package_dir.rglob("*.metallib")):
288- src = candidate
289- print(f"::warning::Using metallib {candidate.name} discovered via package-wide search")
290- break
321+ package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir])
322+
323+ if src is None:
324+ for root in package_roots:
325+ for candidate in root.rglob("mlx.metallib"):
326+ src = candidate
327+ print(f"::warning::Resolved metallib via recursive search under {root}")
328+ break
329+ if src is not None:
330+ break
331+
332+ if src is None:
333+ for root in package_roots:
334+ for candidate in sorted(root.rglob("*.metallib")):
335+ src = candidate
336+ print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}")
337+ break
338+ if src is not None:
339+ break
291340
292341 if src is None:
293342 print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.")
0 commit comments