diff --git a/pyproject.toml b/pyproject.toml index 83950d7..9e3216c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,9 @@ dependencies = [ "jax>=0.6.0", ] +[project.entry-points.jax_plugins] +mpibackend4jax = "mpibackend4jax.plugin" + [tool.hatch.build.targets.wheel] packages = ["src/mpibackend4jax"] diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index 20973bd..4c5652e 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -6,28 +6,12 @@ """ import os -from pathlib import Path # Import the cluster to register it automatically from .mpitrampoline_cluster import MPITrampolineLocalCluster __version__ = "0.1.0" -# Get the package installation directory -_package_dir = Path(__file__).parent -_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" - -# Set environment variables for MPITrampoline -if _mpiwrapper_lib.exists(): - os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) - os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" - - print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") - print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") -else: - print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") - print("Please ensure the package was installed correctly.") - # Convenience function to check if MPITrampoline is properly configured def is_configured(): diff --git a/src/mpibackend4jax/plugin.py b/src/mpibackend4jax/plugin.py new file mode 100644 index 0000000..c4501bb --- /dev/null +++ b/src/mpibackend4jax/plugin.py @@ -0,0 +1,20 @@ +import os +from pathlib import Path + + +def initialize(): + # Get the package installation directory + _package_dir = Path(__file__).parent + _mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" + + # Set environment variables for MPITrampoline + if _mpiwrapper_lib.exists(): + if "MPITRAMPOLINE_LIB" not in os.environ.keys(): + os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) + print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") + if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys(): + os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" + print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") + else: + print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") + print("Please ensure the package was installed correctly.")