diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml index 15dfc4a1f9064..e6db4eeb33cd4 100644 --- a/.lightning/workflows/pytorch.yml +++ b/.lightning/workflows/pytorch.yml @@ -121,7 +121,11 @@ run: | echo "Install package" extra=$(python -c "print({'lightning': 'pytorch-'}.get('${PACKAGE_NAME}', ''))") - uv pip install -e ".[${extra}dev]" --upgrade + + # Use find-links to prefer CUDA-specific packages from PyTorch index + uv pip install -e ".[${extra}dev]" --upgrade \ + --find-links="https://download.pytorch.org/whl/${UV_TORCH_BACKEND}" + uv pip list echo "Ensure only a single package is installed" if [ "${PACKAGE_NAME}" == "pytorch" ]; then diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index fdc814c7f7fa1..b31f3dacf5407 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.9.0 +torch >=2.1.0, <2.10.0 fsspec[http] >=2022.5.0, <2025.11.0 packaging >=20.0, <=25.0 typing-extensions >4.5.0, <4.16.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 014b223b1f012..c44e348ffa6cf 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.9.0 +torch >=2.1.0, <2.10.0 tqdm >=4.57.0, <4.68.0 PyYAML >5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.11.0 diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 9a315c25bfa21..b22b2a3679946 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -21,5 +21,6 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.11, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt; platform_system == "Linux" and python_version >= "3.12" +# TODO: resolve GPU test failures for TensorRT due to defaulting to cu13 installations +torch-tensorrt<2.9.0; platform_system == "Linux" and python_version >= "3.12" huggingface-hub diff --git a/requirements/typing.txt b/requirements/typing.txt index dc848c55e583d..8c5ad38fb7825 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.18.2 -torch==2.8.0 +torch==2.9.0 types-Markdown types-PyYAML diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 9c1b0a2a00572..cd2e05309e087 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -126,10 +126,10 @@ def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None: raise TrainingSpikeException(batch_idx=batch_idx) def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool: - return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol)) + return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol)) # type: ignore def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool: - return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b)) + return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b)) # type: ignore def _is_better(self, diff_val: torch.Tensor) -> bool: if self.mode == "min": diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 9d4a0b9462f2e..a11a9f93e569d 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -111,6 +111,7 @@ def thread_police_duuu_daaa_duuu_daaa(): sys.version_info >= (3, 9) and isinstance(thread, _ExecutorManagerThread) or "ThreadPoolExecutor-" in thread.name + or thread.name == "InductorSubproc" # torch.compile ): # probably `torch.compile`, can't narrow it down further continue diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 878298c6bfd94..da48878c7f670 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -170,6 +170,7 @@ def thread_police_duuu_daaa_duuu_daaa(): sys.version_info >= (3, 9) and isinstance(thread, _ExecutorManagerThread) or "ThreadPoolExecutor-" in thread.name + or thread.name == "InductorSubproc" # torch.compile ): # probably `torch.compile`, can't narrow it down further continue