From 70cfb7ca4fad4d1d4789c3b0efc60ef0743388cf Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Fri, 25 Jul 2025 17:36:05 +0200 Subject: [PATCH 1/5] load it automatically, no import needed --- pyproject.toml | 3 ++ src/mpibackend4jax/__init__.py | 67 +++++++++++++++++----------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 83950d7..41b5641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,9 @@ dependencies = [ "jax>=0.6.0", ] +[project.entry-points.jax_plugins] +mpibackend4jax = "mpibackend4jax" + [tool.hatch.build.targets.wheel] packages = ["src/mpibackend4jax"] diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index 20973bd..e10eb3a 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -9,39 +9,40 @@ from pathlib import Path # Import the cluster to register it automatically -from .mpitrampoline_cluster import MPITrampolineLocalCluster +# from .mpitrampoline_cluster import MPITrampolineLocalCluster __version__ = "0.1.0" -# Get the package installation directory -_package_dir = Path(__file__).parent -_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" - -# Set environment variables for MPITrampoline -if _mpiwrapper_lib.exists(): - os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) - os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" - - print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") - print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") -else: - print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") - print("Please ensure the package was installed correctly.") - - -# Convenience function to check if MPITrampoline is properly configured -def is_configured(): - """Check if MPITrampoline is properly configured for JAX""" - return ( - "MPITRAMPOLINE_LIB" in os.environ - and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi" - and Path(os.environ["MPITRAMPOLINE_LIB"]).exists() - ) - - -def get_library_path(): - """Get the path to the MPIWrapper library""" - return os.environ.get("MPITRAMPOLINE_LIB") - - -__all__ = ["is_configured", "get_library_path", "MPITrampolineLocalCluster"] +def initialize(): + # Get the package installation directory + _package_dir = Path(__file__).parent + _mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" + + # Set environment variables for MPITrampoline + if _mpiwrapper_lib.exists(): + os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) + os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" + + print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") + print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") + else: + print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") + print("Please ensure the package was installed correctly.") + + +# # Convenience function to check if MPITrampoline is properly configured +# def is_configured(): +# """Check if MPITrampoline is properly configured for JAX""" +# return ( +# "MPITRAMPOLINE_LIB" in os.environ +# and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi" +# and Path(os.environ["MPITRAMPOLINE_LIB"]).exists() +# ) +# +# +# def get_library_path(): +# """Get the path to the MPIWrapper library""" +# return os.environ.get("MPITRAMPOLINE_LIB") + + +__all__ = ["is_configured", "get_library_path"]#, "MPITrampolineLocalCluster"] From 2345543bf355564f8f0bdf19e0ed5fd85b604f10 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Fri, 25 Jul 2025 17:42:08 +0200 Subject: [PATCH 2/5] Update src/mpibackend4jax/__init__.py --- src/mpibackend4jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index e10eb3a..bbe6b1e 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -45,4 +45,4 @@ def initialize(): # return os.environ.get("MPITRAMPOLINE_LIB") -__all__ = ["is_configured", "get_library_path"]#, "MPITrampolineLocalCluster"] +#__all__ = ["is_configured", "get_library_path"], "MPITrampolineLocalCluster"] From ba75fe4bb049194dcc4701029747d8f12f132651 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Wed, 30 Jul 2025 15:00:47 +0200 Subject: [PATCH 3/5] only set variables if not set previously This way its easier to override by hand --- src/mpibackend4jax/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index bbe6b1e..2586f7c 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -20,11 +20,12 @@ def initialize(): # Set environment variables for MPITrampoline if _mpiwrapper_lib.exists(): - os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) - os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" - - print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") - print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") + if "MPITRAMPOLINE_LIB" not in os.environ.keys(): + os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) + print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") + if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys() + os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" + print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") else: print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") print("Please ensure the package was installed correctly.") From 988e8954972d0d85044cff1f5de0fb12d9ec0be3 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Wed, 30 Jul 2025 15:02:34 +0200 Subject: [PATCH 4/5] move plugin to its own file --- pyproject.toml | 2 +- src/mpibackend4jax/__init__.py | 19 ------------------- src/mpibackend4jax/plugin.py | 20 ++++++++++++++++++++ 3 files changed, 21 insertions(+), 20 deletions(-) create mode 100644 src/mpibackend4jax/plugin.py diff --git a/pyproject.toml b/pyproject.toml index 41b5641..9e3216c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ ] [project.entry-points.jax_plugins] -mpibackend4jax = "mpibackend4jax" +mpibackend4jax = "mpibackend4jax.plugin" [tool.hatch.build.targets.wheel] packages = ["src/mpibackend4jax"] diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index 2586f7c..f9e2526 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -6,31 +6,12 @@ """ import os -from pathlib import Path # Import the cluster to register it automatically # from .mpitrampoline_cluster import MPITrampolineLocalCluster __version__ = "0.1.0" -def initialize(): - # Get the package installation directory - _package_dir = Path(__file__).parent - _mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" - - # Set environment variables for MPITrampoline - if _mpiwrapper_lib.exists(): - if "MPITRAMPOLINE_LIB" not in os.environ.keys(): - os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) - print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") - if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys() - os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" - print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") - else: - print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") - print("Please ensure the package was installed correctly.") - - # # Convenience function to check if MPITrampoline is properly configured # def is_configured(): # """Check if MPITrampoline is properly configured for JAX""" diff --git a/src/mpibackend4jax/plugin.py b/src/mpibackend4jax/plugin.py new file mode 100644 index 0000000..c4501bb --- /dev/null +++ b/src/mpibackend4jax/plugin.py @@ -0,0 +1,20 @@ +import os +from pathlib import Path + + +def initialize(): + # Get the package installation directory + _package_dir = Path(__file__).parent + _mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so" + + # Set environment variables for MPITrampoline + if _mpiwrapper_lib.exists(): + if "MPITRAMPOLINE_LIB" not in os.environ.keys(): + os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute()) + print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}") + if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys(): + os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" + print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") + else: + print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}") + print("Please ensure the package was installed correctly.") From 57a85cac06890e9913f33977160b4572e8449891 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Wed, 30 Jul 2025 15:05:51 +0200 Subject: [PATCH 5/5] put back original functions --- src/mpibackend4jax/__init__.py | 35 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/mpibackend4jax/__init__.py b/src/mpibackend4jax/__init__.py index f9e2526..4c5652e 100644 --- a/src/mpibackend4jax/__init__.py +++ b/src/mpibackend4jax/__init__.py @@ -8,23 +8,24 @@ import os # Import the cluster to register it automatically -# from .mpitrampoline_cluster import MPITrampolineLocalCluster +from .mpitrampoline_cluster import MPITrampolineLocalCluster __version__ = "0.1.0" -# # Convenience function to check if MPITrampoline is properly configured -# def is_configured(): -# """Check if MPITrampoline is properly configured for JAX""" -# return ( -# "MPITRAMPOLINE_LIB" in os.environ -# and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi" -# and Path(os.environ["MPITRAMPOLINE_LIB"]).exists() -# ) -# -# -# def get_library_path(): -# """Get the path to the MPIWrapper library""" -# return os.environ.get("MPITRAMPOLINE_LIB") - - -#__all__ = ["is_configured", "get_library_path"], "MPITrampolineLocalCluster"] + +# Convenience function to check if MPITrampoline is properly configured +def is_configured(): + """Check if MPITrampoline is properly configured for JAX""" + return ( + "MPITRAMPOLINE_LIB" in os.environ + and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi" + and Path(os.environ["MPITRAMPOLINE_LIB"]).exists() + ) + + +def get_library_path(): + """Get the path to the MPIWrapper library""" + return os.environ.get("MPITRAMPOLINE_LIB") + + +__all__ = ["is_configured", "get_library_path", "MPITrampolineLocalCluster"]