Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .lightning/workflows/pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ run: |

echo "Install package"
extra=$(python -c "print({'lightning': 'pytorch-'}.get('${PACKAGE_NAME}', ''))")
uv pip install -e ".[${extra}dev]" --upgrade

# Pass --find-links so that CUDA-specific packages from the PyTorch index (selected by UV_TORCH_BACKEND) are preferred
uv pip install -e ".[${extra}dev]" --upgrade \
--find-links="https://download.pytorch.org/whl/${UV_TORCH_BACKEND}"
uv pip list

echo "Ensure only a single package is installed"
if [ "${PACKAGE_NAME}" == "pytorch" ]; then
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound on each package version is set only for CI stability and is dropped when this package is installed.
# To preserve/enforce a restriction on the latest compatible version, add "strict" as an in-line comment.

torch >=2.1.0, <2.9.0
torch >=2.1.0, <2.10.0
fsspec[http] >=2022.5.0, <2025.11.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound on each package version is set only for CI stability and is dropped when this package is installed.
# To preserve/enforce a restriction on the latest compatible version, add "strict" as an in-line comment.

torch >=2.1.0, <2.9.0
torch >=2.1.0, <2.10.0
tqdm >=4.57.0, <4.68.0
PyYAML >5.4, <6.1.0
fsspec[http] >=2022.5.0, <2025.11.0
Expand Down
3 changes: 2 additions & 1 deletion requirements/pytorch/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in

tensorboard >=2.11, <2.21.0 # for `TensorBoardLogger`

torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"
# TODO: resolve GPU test failures for TensorRT due to defaulting to cu13 installations
torch-tensorrt<2.9.0; platform_system == "Linux" and python_version >= "3.12"
huggingface-hub
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mypy==1.18.2
torch==2.8.0
torch==2.9.0

types-Markdown
types-PyYAML
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/utilities/spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
raise TrainingSpikeException(batch_idx=batch_idx)

def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))
return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol)) # type: ignore

def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))
return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b)) # type: ignore

def _is_better(self, diff_val: torch.Tensor) -> bool:
if self.mode == "min":
Expand Down
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def thread_police_duuu_daaa_duuu_daaa():
sys.version_info >= (3, 9)
and isinstance(thread, _ExecutorManagerThread)
or "ThreadPoolExecutor-" in thread.name
or thread.name == "InductorSubproc" # torch.compile
):
# probably `torch.compile`, can't narrow it down further
continue
Expand Down
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def thread_police_duuu_daaa_duuu_daaa():
sys.version_info >= (3, 9)
and isinstance(thread, _ExecutorManagerThread)
or "ThreadPoolExecutor-" in thread.name
or thread.name == "InductorSubproc" # torch.compile
):
# probably `torch.compile`, can't narrow it down further
continue
Expand Down
Loading