diff --git a/monarch_extension/Cargo.toml b/monarch_extension/Cargo.toml index 7f7740d93..f22221829 100644 --- a/monarch_extension/Cargo.toml +++ b/monarch_extension/Cargo.toml @@ -34,7 +34,7 @@ monarch_tensor_worker = { version = "0.0.0", path = "../monarch_tensor_worker", monarch_types = { version = "0.0.0", path = "../monarch_types" } nccl-sys = { path = "../nccl-sys", optional = true } ndslice = { version = "0.0.0", path = "../ndslice" } -pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods", "py-clone"] } +pyo3 = { version = "0.24", features = ["anyhow", "extension-module", "multiple-pymethods", "py-clone"] } rdmaxcel-sys = { path = "../rdmaxcel-sys", optional = true } serde = { version = "1.0.219", features = ["derive", "rc"] } tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } diff --git a/monarch_types/Cargo.toml b/monarch_types/Cargo.toml index 01be830bb..862c7ea09 100644 --- a/monarch_types/Cargo.toml +++ b/monarch_types/Cargo.toml @@ -10,7 +10,7 @@ license = "BSD-3-Clause" [dependencies] derive_more = { version = "1.0.0", features = ["full"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } -pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods", "py-clone"] } +pyo3 = { version = "0.24", features = ["anyhow", "extension-module", "multiple-pymethods", "py-clone"] } serde = { version = "1.0.219", features = ["derive", "rc"] } serde_bytes = "0.11" diff --git a/setup.py b/setup.py index 4bd29a4e9..d0d8dfcbb 100644 --- a/setup.py +++ b/setup.py @@ -121,22 +121,10 @@ def run(self): readme = f.read() if sys.platform.startswith("linux"): - # Always include the active env's lib (Conda-safe) + # With extension-module, we don't link libpython, but we still need + # RPATH for finding libtorch and other PyTorch libraries conda_lib = os.path.join(sys.prefix, "lib") - # Only use LIBDIR if it actually contains the current libpython - ldlib = sysconfig.get_config_var("LDLIBRARY") or "" - libdir = sysconfig.get_config_var("LIBDIR") or "" - py_lib = "" - if libdir and ldlib: - cand = os.path.join(libdir, ldlib) - if os.path.exists(cand) and os.path.realpath(libdir) != os.path.realpath( - conda_lib - ): - py_lib = libdir - - # Prefer sidecar .so next to the extension; then the conda env; - # then (optionally) py_lib flags = [ "-C", "link-arg=-Wl,--enable-new-dtags", @@ -145,14 +133,8 @@ def run(self): "-C", "link-arg=-Wl,-rpath,$ORIGIN", "-C", - "link-arg=-Wl,-rpath,$ORIGIN/..", - "-C", - "link-arg=-Wl,-rpath,$ORIGIN/../../..", - "-C", - "link-arg=-Wl,-rpath," + conda_lib, + "link-arg=-Wl,-rpath," + conda_lib, # For libtorch ] - if py_lib: - flags += ["-C", "link-arg=-Wl,-rpath," + py_lib] cur = os.environ.get("RUSTFLAGS", "") os.environ["RUSTFLAGS"] = (cur + " " + " ".join(flags)).strip() diff --git a/torch-sys/Cargo.toml b/torch-sys/Cargo.toml index 1b3f27265..dd06b2ce4 100644 --- a/torch-sys/Cargo.toml +++ b/torch-sys/Cargo.toml @@ -19,7 +19,7 @@ hyperactor = { version = "0.0.0", path = "../hyperactor" } monarch_types = { version = "0.0.0", path = "../monarch_types" } nccl-sys = { path = "../nccl-sys", optional = true } paste = "1.0.14" -pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods", "py-clone"] } +pyo3 = { version = "0.24", features = ["anyhow", "extension-module", "multiple-pymethods", "py-clone"] } regex = "1.11.1" serde = { version = "1.0.219", features = ["derive", "rc"] } thiserror = "2.0.12"