diff --git a/.actions/assistant.py b/.actions/assistant.py index a278b1dfcd3b3..329d20c16e024 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -22,7 +22,7 @@ from itertools import chain from os.path import dirname, isfile from pathlib import Path -from typing import Any, Optional +from typing import Any from packaging.requirements import Requirement from packaging.version import Version @@ -48,7 +48,7 @@ class _RequirementWithComment(Requirement): strict_cmd = "strict" - def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: + def __init__(self, *args: Any, comment: str = "", pip_argument: str | None = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.comment = comment assert pip_argument is None or pip_argument # sanity check that it's not an empty str @@ -285,7 +285,7 @@ def copy_replace_imports( source_dir: str, source_imports: Sequence[str], target_imports: Sequence[str], - target_dir: Optional[str] = None, + target_dir: str | None = None, lightning_by: str = "", ) -> None: """Copy package content with import adjustments.""" @@ -347,7 +347,7 @@ def copy_replace_imports( source_dir: str, source_import: str, target_import: str, - target_dir: Optional[str] = None, + target_dir: str | None = None, lightning_by: str = "", ) -> None: """Copy package content with import adjustments.""" @@ -363,7 +363,7 @@ def pull_docs_files( target_dir: str = "docs/source-pytorch/XXX", checkout: str = "refs/tags/1.0.0", source_dir: str = "docs/source", - single_page: Optional[str] = None, + single_page: str | None = None, as_orphan: bool = False, ) -> None: """Pull docs pages from external source and append to local docs. diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index c9f0740152445..e4ea6a865f188 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -1,7 +1,7 @@ import os from collections.abc import Iterable, Mapping from functools import partial -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast import torch from lightning_utilities import apply_to_collection @@ -18,18 +18,18 @@ class MyCustomTrainer: def __init__( self, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", - precision: Union[str, int] = "32-true", - plugins: Optional[Union[str, Any]] = None, - callbacks: Optional[Union[list[Any], Any]] = None, - loggers: Optional[Union[Logger, list[Logger]]] = None, - max_epochs: Optional[int] = 1000, - max_steps: Optional[int] = None, + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + devices: list[int] | str | int = "auto", + precision: str | int = "32-true", + plugins: str | Any | None = None, + callbacks: list[Any] | Any | None = None, + loggers: Logger | list[Logger] | None = None, + max_epochs: int | None = 1000, + max_steps: int | None = None, grad_accum_steps: int = 1, - limit_train_batches: Union[int, float] = float("inf"), - limit_val_batches: Union[int, float] = float("inf"), + limit_train_batches: int | float = float("inf"), + limit_val_batches: int | float = float("inf"), validation_frequency: int = 1, use_distributed_sampler: bool = True, checkpoint_dir: str = "./checkpoints", @@ -115,8 +115,8 @@ def __init__( self.limit_val_batches = limit_val_batches self.validation_frequency = validation_frequency self.use_distributed_sampler = use_distributed_sampler - self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {} - self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {} + self._current_train_return: torch.Tensor | Mapping[str, Any] = {} + self._current_val_return: torch.Tensor | Mapping[str, Any] | None = {} self.checkpoint_dir = checkpoint_dir self.checkpoint_frequency = checkpoint_frequency @@ -126,7 +126,7 @@ def fit( model: L.LightningModule, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, - ckpt_path: Optional[str] = None, + ckpt_path: str | None = None, ): """The main entrypoint of the trainer, triggering the actual training. @@ -196,8 +196,8 @@ def train_loop( model: L.LightningModule, optimizer: torch.optim.Optimizer, train_loader: torch.utils.data.DataLoader, - limit_batches: Union[int, float] = float("inf"), - scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, + limit_batches: int | float = float("inf"), + scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None = None, ): """The training loop running a single training epoch. @@ -262,8 +262,8 @@ def train_loop( def val_loop( self, model: L.LightningModule, - val_loader: Optional[torch.utils.data.DataLoader], - limit_batches: Union[int, float] = float("inf"), + val_loader: torch.utils.data.DataLoader | None, + limit_batches: int | float = float("inf"), ): """The validation loop running a single validation epoch. @@ -331,7 +331,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> batch_idx: index of the current batch w.r.t the current epoch """ - outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx) + outputs: torch.Tensor | Mapping[str, Any] = model.training_step(batch, batch_idx=batch_idx) loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"] @@ -347,7 +347,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> def step_scheduler( self, model: L.LightningModule, - scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]], + scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None, level: Literal["step", "epoch"], current_value: int, ) -> None: @@ -387,7 +387,7 @@ def step_scheduler( possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()}) try: - monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])] + monitor = possible_monitor_vals[cast(str | None, scheduler_cfg["monitor"])] except KeyError as ex: possible_keys = list(possible_monitor_vals.keys()) raise KeyError( @@ -414,7 +414,7 @@ def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any): return tqdm(iterable, total=total, **kwargs) return iterable - def load(self, state: Optional[Mapping], path: str) -> None: + def load(self, state: Mapping | None, path: str) -> None: """Loads a checkpoint from a given file into state. Args: @@ -432,7 +432,7 @@ def load(self, state: Optional[Mapping], path: str) -> None: if remainder: raise RuntimeError(f"Unused Checkpoint Values: {remainder}") - def save(self, state: Optional[Mapping]) -> None: + def save(self, state: Mapping | None) -> None: """Saves a checkpoint to the ``checkpoint_dir`` Args: @@ -447,7 +447,7 @@ def save(self, state: Optional[Mapping]) -> None: self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state) @staticmethod - def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]: + def get_latest_checkpoint(checkpoint_dir: str) -> str | None: """Returns the latest checkpoint from the ``checkpoint_dir`` Args: @@ -467,8 +467,8 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]: def _parse_optimizers_schedulers( self, configure_optim_output ) -> tuple[ - Optional[L.fabric.utilities.types.Optimizable], - Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]], + L.fabric.utilities.types.Optimizable | None, + Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None, ]: """Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`. @@ -521,7 +521,7 @@ def _parse_optimizers_schedulers( @staticmethod def _format_iterable( - prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str + prog_bar, candidates: torch.Tensor | Mapping[str, torch.Tensor | float | int] | None, prefix: str ): """Adds values as postfix string to progressbar. diff --git a/examples/fabric/reinforcement_learning/rl/utils.py b/examples/fabric/reinforcement_learning/rl/utils.py index 4c5a8066b359f..f5f0bfc021c85 100644 --- a/examples/fabric/reinforcement_learning/rl/utils.py +++ b/examples/fabric/reinforcement_learning/rl/utils.py @@ -1,7 +1,7 @@ import argparse import math import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Union import gymnasium as gym import torch @@ -160,7 +160,7 @@ def linear_annealing(optimizer: torch.optim.Optimizer, update: int, num_updates: pg["lr"] = lrnow -def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: Optional[str] = None, prefix: str = ""): +def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str | None = None, prefix: str = ""): def thunk(): env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordEpisodeStatistics(env) diff --git a/examples/fabric/tensor_parallel/model.py b/examples/fabric/tensor_parallel/model.py index 71f2634867e9b..23b72de598832 100644 --- a/examples/fabric/tensor_parallel/model.py +++ b/examples/fabric/tensor_parallel/model.py @@ -9,7 +9,6 @@ from dataclasses import dataclass -from typing import Optional import torch import torch.nn.functional as F @@ -21,10 +20,10 @@ class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 - n_kv_heads: Optional[int] = None + n_kv_heads: int | None = None vocab_size: int = -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 @@ -248,7 +247,7 @@ def __init__( dim: int, hidden_dim: int, multiple_of: int, - ffn_dim_multiplier: Optional[float], + ffn_dim_multiplier: float | None, ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index 332c9a811e3e4..c7c5015f0ec71 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -18,7 +18,6 @@ """ from os import path -from typing import Optional import torch import torch.nn.functional as F @@ -46,7 +45,7 @@ def __init__( nrow: int = 8, padding: int = 2, normalize: bool = True, - value_range: Optional[tuple[int, int]] = None, + value_range: tuple[int, int] | None = None, scale_each: bool = False, pad_value: int = 0, ) -> None: diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index 965f636d7fc0b..fa69b9dc24caa 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -18,7 +18,6 @@ """ from os import path -from typing import Optional import torch from torch.nn import functional as F @@ -63,7 +62,7 @@ class LitClassifier(LightningModule): ) """ - def __init__(self, backbone: Optional[Backbone] = None, learning_rate: float = 0.0001): + def __init__(self, backbone: Backbone | None = None, learning_rate: float = 0.0001): super().__init__() self.save_hyperparameters(ignore=["backbone"]) if backbone is None: diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index 69721214748ee..bec8ad1cf01fe 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -42,7 +42,6 @@ import logging from pathlib import Path -from typing import Union import torch import torch.nn.functional as F @@ -91,7 +90,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O class CatDogImageDataModule(LightningDataModule): - def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8): + def __init__(self, dl_path: str | Path = "data", num_workers: int = 0, batch_size: int = 8): """CatDogImageDataModule. Args: diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index fd2050e2ed38b..4610db889eadb 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -32,7 +32,6 @@ """ import os -from typing import Optional import torch import torch.nn.functional as F @@ -65,7 +64,7 @@ def __init__( self, data_path: str, arch: str = "resnet18", - weights: Optional[str] = None, + weights: str | None = None, lr: float = 0.1, momentum: float = 0.9, weight_decay: float = 1e-4, @@ -82,8 +81,8 @@ def __init__( self.batch_size = batch_size self.workers = workers self.model = get_torchvision_model(self.arch, weights=self.weights) - self.train_dataset: Optional[Dataset] = None - self.eval_dataset: Optional[Dataset] = None + self.train_dataset: Dataset | None = None + self.eval_dataset: Dataset | None = None # ToDo: this number of classes hall be parsed when the dataset is loaded from folder self.train_acc1 = Accuracy(task="multiclass", num_classes=1000, top_k=1) self.train_acc5 = Accuracy(task="multiclass", num_classes=1000, top_k=5) diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index 55581c1b68088..19572c3d4a4f5 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -30,8 +30,7 @@ """ import argparse -from collections.abc import Iterator -from typing import Callable +from collections.abc import Callable, Iterator import gym import torch diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index 854ff1176b619..3ed48cf9c091a 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from io import BytesIO from os import path -from typing import Optional import numpy as np import torch @@ -56,8 +55,8 @@ def val_dataloader(self, *args, **kwargs): @dataclass(unsafe_hash=True) class Image: - height: Optional[int] = None - width: Optional[int] = None + height: int | None = None + width: int | None = None extension: str = "JPEG" mode: str = "RGB" channel_first: bool = False diff --git a/examples/pytorch/tensor_parallel/model.py b/examples/pytorch/tensor_parallel/model.py index 71f2634867e9b..23b72de598832 100644 --- a/examples/pytorch/tensor_parallel/model.py +++ b/examples/pytorch/tensor_parallel/model.py @@ -9,7 +9,6 @@ from dataclasses import dataclass -from typing import Optional import torch import torch.nn.functional as F @@ -21,10 +20,10 @@ class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 - n_kv_heads: Optional[int] = None + n_kv_heads: int | None = None vocab_size: int = -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 @@ -248,7 +247,7 @@ def __init__( dim: int, hidden_dim: int, multiple_of: int, - ffn_dim_multiplier: Optional[float], + ffn_dim_multiplier: float | None, ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) diff --git a/pyproject.toml b/pyproject.toml index 078738d21111d..a530070c29eb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime" [tool.ruff] line-length = 120 -target-version = "py39" +target-version = "py310" # Exclude a variety of commonly ignored directories. exclude = [ ".git", diff --git a/setup.py b/setup.py index fffb38f9a578b..88925369dfbfb 100755 --- a/setup.py +++ b/setup.py @@ -48,7 +48,6 @@ from collections.abc import Generator, Mapping from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from typing import Optional import setuptools import setuptools.command.egg_info @@ -76,7 +75,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: return py -def _named_temporary_file(directory: Optional[str] = None) -> str: +def _named_temporary_file(directory: str | None = None) -> str: # `tempfile.NamedTemporaryFile` has issues in Windows # https://github.com/deepchem/deepchem/issues/707#issuecomment-556002823 if directory is None: diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 155174313d45b..8df3db63683f4 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union import torch from typing_extensions import override @@ -39,13 +38,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str]) -> int: + def parse_devices(devices: int | str) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: + def get_parallel_devices(devices: int | str) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -77,7 +76,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int: +def _parse_cpu_cores(cpu_cores: int | str) -> int: """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 562dcfc9cd744..a98f60295567f 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import lru_cache -from typing import Optional, Union import torch from typing_extensions import override @@ -43,7 +42,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -156,7 +155,7 @@ def is_cuda_available() -> bool: return torch.cuda.is_available() -def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool: +def _is_ampere_or_later(device: torch.device | None = None) -> bool: major, _ = torch.cuda.get_device_capability(device) return major >= 8 # Ampere and later leverage tensor cores, where this setting becomes useful diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index f3a9cf8b9d415..8f6e3dbd6f2e7 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -14,7 +14,6 @@ import os import platform from functools import lru_cache -from typing import Optional, Union import torch from typing_extensions import override @@ -46,7 +45,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -54,7 +53,7 @@ def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: @staticmethod @override - def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 539b7aa8a01dc..7fb73bf183778 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from typing_extensions import override @@ -47,7 +48,7 @@ def __init__(self, a, b): def register( self, name: str, - accelerator: Optional[Callable] = None, + accelerator: Callable | None = None, description: str = "", override: bool = False, **init_params: Any, @@ -85,7 +86,7 @@ def do_register(accelerator: Callable) -> Callable: return do_register @override - def get(self, name: str, default: Optional[Any] = None) -> Any: + def get(self, name: str, default: Any | None = None) -> Any: """Calls the registered accelerator with the required parameters and returns the accelerator object. Args: diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index db2cf2586e1ba..37fa2c79d2d51 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Union +from typing import Any import torch from lightning_utilities.core.imports import RequirementCache @@ -47,13 +47,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: + def parse_devices(devices: int | str | list[int]) -> int | list[int]: """Accelerator device parsing logic.""" return _parse_tpu_devices(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: int | list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_tpu_devices(devices) if isinstance(devices, int): @@ -131,7 +131,7 @@ def _using_pjrt() -> bool: return pjrt.using_pjrt() -def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: +def _parse_tpu_devices(devices: int | str | list[int]) -> int | list[int]: """Parses the TPU devices given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`. @@ -168,7 +168,7 @@ def _check_tpu_devices_valid(devices: object) -> None: ) -def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]: +def _parse_tpu_devices_str(devices: str) -> int | list[int]: devices = devices.strip() try: return int(devices) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 594bb46f4b362..004eabc89af88 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -15,11 +15,10 @@ import os import re from argparse import Namespace -from typing import Any, Optional +from typing import Any, get_args import torch from lightning_utilities.core.imports import RequirementCache -from typing_extensions import get_args from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS @@ -156,7 +155,7 @@ def _run(**kwargs: Any) -> None: " and a '.consolidated' suffix." ), ) - def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None: + def _consolidate(checkpoint_folder: str, output_file: str | None) -> None: """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`. Only supports FSDP sharded checkpoints at the moment. @@ -229,7 +228,7 @@ def _torchrun_launch(args: Namespace, script_args: list[str]) -> None: torchrun.main(torchrun_args) -def main(args: Namespace, script_args: Optional[list[str]] = None) -> None: +def main(args: Namespace, script_args: list[str] | None = None) -> None: _set_env_variables(args) _torchrun_launch(args, script_args or []) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index b3289debbd522..7228b1cbc6f0e 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -14,10 +14,9 @@ import os from collections import Counter from collections.abc import Iterable -from typing import Any, Optional, Union, cast +from typing import Any, cast, get_args import torch -from typing_extensions import get_args from lightning.fabric.accelerators import ACCELERATOR_REGISTRY from lightning.fabric.accelerators.accelerator import Accelerator @@ -68,7 +67,7 @@ from lightning.fabric.utilities.device_parser import _determine_root_gpu_device from lightning.fabric.utilities.imports import _IS_INTERACTIVE -_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO] +_PLUGIN_INPUT = Precision | ClusterEnvironment | CheckpointIO class _Connector: @@ -98,12 +97,12 @@ class _Connector: def __init__( self, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, - precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, + precision: _PRECISION_INPUT | None = None, + plugins: _PLUGIN_INPUT | Iterable[_PLUGIN_INPUT] | None = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") @@ -120,13 +119,13 @@ def __init__( # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation # For devices: Assign gpus, etc. to the accelerator flag and devices flag - self._strategy_flag: Union[Strategy, str] = "auto" - self._accelerator_flag: Union[Accelerator, str] = "auto" + self._strategy_flag: Strategy | str = "auto" + self._accelerator_flag: Accelerator | str = "auto" self._precision_input: _PRECISION_INPUT_STR = "32-true" - self._precision_instance: Optional[Precision] = None - self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: list[Union[int, torch.device, str]] = [] - self.checkpoint_io: Optional[CheckpointIO] = None + self._precision_instance: Precision | None = None + self._cluster_environment_flag: ClusterEnvironment | str | None = None + self._parallel_devices: list[int | torch.device | str] = [] + self.checkpoint_io: CheckpointIO | None = None self._check_config_and_set_final_flags( strategy=strategy, @@ -163,10 +162,10 @@ def __init__( def _check_config_and_set_final_flags( self, - strategy: Union[str, Strategy], - accelerator: Union[str, Accelerator], - precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], + strategy: str | Strategy, + accelerator: str | Accelerator, + precision: _PRECISION_INPUT | None, + plugins: _PLUGIN_INPUT | Iterable[_PLUGIN_INPUT] | None, ) -> None: """This method checks: @@ -295,7 +294,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: list[int] | str | int, num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") @@ -391,7 +390,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: return env_type() return LightningEnvironment() - def _choose_strategy(self) -> Union[Strategy, str]: + def _choose_strategy(self) -> Strategy | str: if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator): if self._parallel_devices and len(self._parallel_devices) > 1: return "xla" @@ -540,7 +539,7 @@ def _lazy_init_strategy(self) -> None: @staticmethod def _argument_from_env(name: str, current: Any, default: Any) -> Any: - env_value: Optional[str] = os.environ.get("LT_" + name.upper()) + env_value: str | None = os.environ.get("LT_" + name.upper()) if env_value is None: return current @@ -554,7 +553,7 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any: return env_value -def _convert_precision_to_unified_args(precision: Optional[_PRECISION_INPUT]) -> Optional[_PRECISION_INPUT_STR]: +def _convert_precision_to_unified_args(precision: _PRECISION_INPUT | None) -> _PRECISION_INPUT_STR | None: if precision is None: return None diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 288c355a4ebf2..d2c37d6b7bcd1 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -13,16 +13,14 @@ # limitations under the License. import inspect import os -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, Optional, - Union, cast, overload, ) @@ -134,14 +132,14 @@ class Fabric: def __init__( self, *, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, - precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, - callbacks: Optional[Union[list[Any], Any]] = None, - loggers: Optional[Union[Logger, list[Logger]]] = None, + precision: _PRECISION_INPUT | None = None, + plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None = None, + callbacks: list[Any] | Any | None = None, + loggers: Logger | list[Logger] | None = None, ) -> None: self._connector = _Connector( accelerator=accelerator, @@ -373,7 +371,7 @@ def setup_module( self._models_setup += 1 return module - def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]: + def setup_optimizers(self, *optimizers: Optimizer) -> _FabricOptimizer | tuple[_FabricOptimizer, ...]: r"""Set up one or more optimizers for accelerated training. Some strategies do not allow setting up model and optimizer independently. For them, you should call @@ -411,7 +409,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tu def setup_dataloaders( self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True - ) -> Union[DataLoader, list[DataLoader]]: + ) -> DataLoader | list[DataLoader]: r"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -479,7 +477,7 @@ def _setup_dataloader( fabric_dataloader = cast(DataLoader, fabric_dataloader) return fabric_dataloader - def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None: + def backward(self, tensor: Tensor, *args: Any, model: _FabricModule | None = None, **kwargs: Any) -> None: r"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you. Args: @@ -526,13 +524,13 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = def clip_gradients( self, - module: Union[torch.nn.Module, _FabricModule], - optimizer: Union[Optimizer, _FabricOptimizer], - clip_val: Optional[Union[float, int]] = None, - max_norm: Optional[Union[float, int]] = None, - norm_type: Union[float, int] = 2.0, + module: torch.nn.Module | _FabricModule, + optimizer: Optimizer | _FabricOptimizer, + clip_val: float | int | None = None, + max_norm: float | int | None = None, + norm_type: float | int = 2.0, error_if_nonfinite: bool = True, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Clip the gradients of the model to a given max value or max norm. Args: @@ -598,7 +596,7 @@ def to_device(self, obj: Tensor) -> Tensor: ... @overload def to_device(self, obj: Any) -> Any: ... - def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: + def to_device(self, obj: nn.Module | Tensor | Any) -> nn.Module | Tensor | Any: r"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on that device. @@ -626,7 +624,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: if self.local_rank == 0: print(*args, **kwargs) - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: """Wait for all processes to enter this call. Use this to synchronize all parallel processes, but only if necessary, otherwise the overhead of synchronization @@ -655,8 +653,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return self._strategy.broadcast(obj, src=src) def all_gather( - self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, dict, list, tuple]: + self, data: Tensor | dict | list | tuple, group: Any | None = None, sync_grads: bool = False + ) -> Tensor | dict | list | tuple: """Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -680,10 +678,10 @@ def all_gather( def all_reduce( self, - data: Union[Tensor, dict, list, tuple], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, dict, list, tuple]: + data: Tensor | dict | list | tuple, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", + ) -> Tensor | dict | list | tuple: """Reduce tensors or collections of tensors from multiple processes. The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. @@ -802,7 +800,7 @@ def init_tensor(self) -> AbstractContextManager: the right data type depending on the precision setting in Fabric.""" return self._strategy.tensor_init_context() - def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def init_module(self, empty_init: bool | None = None) -> AbstractContextManager: """Instantiate the model and its parameters under this context manager to reduce peak memory usage. The parameters get created on the device and with the right data type right away without wasting memory being @@ -819,9 +817,9 @@ def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManag def save( self, - path: Union[str, Path], - state: dict[str, Union[nn.Module, Optimizer, Any]], - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + path: str | Path, + state: dict[str, nn.Module | Optimizer | Any], + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: r"""Save checkpoint contents to a file. @@ -868,8 +866,8 @@ def param_filter(name, param): def load( self, - path: Union[str, Path], - state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None, + path: str | Path, + state: dict[str, nn.Module | Optimizer | Any] | None = None, strict: bool = True, ) -> dict[str, Any]: """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.) @@ -911,7 +909,7 @@ def load( state[k] = unwrapped_state[k] return remainder - def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None: + def load_raw(self, path: str | Path, obj: nn.Module | Optimizer, strict: bool = True) -> None: """Load the state of a module or optimizer from a single state-dict file. Use this for loading a raw PyTorch model checkpoint created without Fabric. @@ -1024,7 +1022,7 @@ def on_train_epoch_end(self, results): # method(self, *args, y=1) # method(self, *args, **kwargs) - def log(self, name: str, value: Any, step: Optional[int] = None) -> None: + def log(self, name: str, value: Any, step: int | None = None) -> None: """Log a scalar to all loggers that were added to Fabric. Args: @@ -1037,7 +1035,7 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None: """ self.log_dict(metrics={name: value}, step=step) - def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> None: + def log_dict(self, metrics: Mapping[str, Any], step: int | None = None) -> None: """Log multiple scalars at once to all loggers that were added to Fabric. Args: @@ -1052,7 +1050,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No logger.log_metrics(metrics=metrics, step=step) @staticmethod - def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None, verbose: bool = True) -> int: + def seed_everything(seed: int | None = None, workers: bool | None = None, verbose: bool = True) -> int: r"""Helper function to seed everything without explicitly importing Lightning. See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details. @@ -1204,7 +1202,7 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") @staticmethod - def _configure_callbacks(callbacks: Optional[Union[list[Any], Any]]) -> list[Any]: + def _configure_callbacks(callbacks: list[Any] | Any | None) -> list[Any]: callbacks = callbacks if callbacks is not None else [] callbacks = callbacks if isinstance(callbacks, list) else [callbacks] callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index dd7dfc63671f0..0e217bceef9a4 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -16,7 +16,7 @@ import logging import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any from torch import Tensor from typing_extensions import override @@ -61,8 +61,8 @@ class CSVLogger(Logger): def __init__( self, root_dir: _PATH, - name: Optional[str] = "lightning_logs", - version: Optional[Union[int, str]] = None, + name: str | None = "lightning_logs", + version: int | str | None = None, prefix: str = "", flush_logs_every_n_steps: int = 100, ): @@ -73,7 +73,7 @@ def __init__( self._version = version self._prefix = prefix self._fs = get_filesystem(root_dir) - self._experiment: Optional[_ExperimentWriter] = None + self._experiment: _ExperimentWriter | None = None self._flush_logs_every_n_steps = flush_logs_every_n_steps @property @@ -89,7 +89,7 @@ def name(self) -> str: @property @override - def version(self) -> Union[int, str]: + def version(self) -> int | str: """Gets the version of the experiment. Returns: @@ -138,13 +138,13 @@ def experiment(self) -> "_ExperimentWriter": @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.") @override @rank_zero_only def log_metrics( # type: ignore[override] - self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None + self, metrics: dict[str, Tensor | float], step: int | None = None ) -> None: metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) if step is None: @@ -210,10 +210,10 @@ def __init__(self, log_dir: str) -> None: self._check_log_dir_exists() self._fs.makedirs(self.log_dir, exist_ok=True) - def log_metrics(self, metrics_dict: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics_dict: dict[str, float], step: int | None = None) -> None: """Record metrics.""" - def _handle_value(value: Union[Tensor, Any]) -> Any: + def _handle_value(value: Tensor | Any) -> Any: if isinstance(value, Tensor): return value.item() return value diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py index 39a9fa06a08d0..0a79b484fb3da 100644 --- a/src/lightning/fabric/loggers/logger.py +++ b/src/lightning/fabric/loggers/logger.py @@ -15,8 +15,9 @@ from abc import ABC, abstractmethod from argparse import Namespace +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any from torch import Tensor from torch.nn import Module @@ -29,22 +30,22 @@ class Logger(ABC): @property @abstractmethod - def name(self) -> Optional[str]: + def name(self) -> str | None: """Return the experiment name.""" @property @abstractmethod - def version(self) -> Optional[Union[int, str]]: + def version(self) -> int | str | None: """Return the experiment version.""" @property - def root_dir(self) -> Optional[str]: + def root_dir(self) -> str | None: """Return the root directory where all versions of an experiment get saved, or `None` if the logger does not save data locally.""" return None @property - def log_dir(self) -> Optional[str]: + def log_dir(self) -> str | None: """Return directory the current version of the experiment gets saved, or `None` if the logger does not save data locally.""" return None @@ -55,7 +56,7 @@ def group_separator(self) -> str: return "/" @abstractmethod - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: """Records metrics. This method logs metrics as soon as it received them. Args: @@ -66,7 +67,7 @@ def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: Union[dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace, *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: @@ -76,7 +77,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace], *args: Any, """ - def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: Module, input_array: Tensor | None = None) -> None: """Record model graph. Args: @@ -103,7 +104,7 @@ def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the _DummyExperiment.""" @wraps(fn) - def experiment(self: Logger) -> Union[Any, _DummyExperiment]: + def experiment(self: Logger) -> Any | _DummyExperiment: """ Note: ``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 208244dc38cd3..9ae31645282a0 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -15,7 +15,7 @@ import os from argparse import Namespace from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -82,11 +82,11 @@ class TensorBoardLogger(Logger): def __init__( self, root_dir: _PATH, - name: Optional[str] = "lightning_logs", - version: Optional[Union[int, str]] = None, + name: str | None = "lightning_logs", + version: int | str | None = None, default_hp_metric: bool = True, prefix: str = "", - sub_dir: Optional[_PATH] = None, + sub_dir: _PATH | None = None, **kwargs: Any, ): if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE: @@ -105,7 +105,7 @@ def __init__( self._prefix = prefix self._fs = get_filesystem(root_dir) - self._experiment: Optional[SummaryWriter] = None + self._experiment: SummaryWriter | None = None self._kwargs = kwargs @property @@ -121,7 +121,7 @@ def name(self) -> str: @property @override - def version(self) -> Union[int, str]: + def version(self) -> int | str: """Get the experiment version. Returns: @@ -161,7 +161,7 @@ def log_dir(self) -> str: return log_dir @property - def sub_dir(self) -> Optional[str]: + def sub_dir(self) -> str | None: """Gets the sub directory where the TensorBoard experiments are saved. Returns: @@ -197,7 +197,7 @@ def experiment(self) -> "SummaryWriter": @override @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) @@ -221,9 +221,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) @rank_zero_only def log_hyperparams( self, - params: Union[dict[str, Any], Namespace], - metrics: Optional[dict[str, Any]] = None, - step: Optional[int] = None, + params: dict[str, Any] | Namespace, + metrics: dict[str, Any] | None = None, + step: int | None = None, ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to @@ -263,7 +263,7 @@ def log_hyperparams( @override @rank_zero_only - def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: Module, input_array: Tensor | None = None) -> None: model_example_input = getattr(model, "example_input_array", None) input_array = model_example_input if input_array is None else input_array model = _unwrap_objects(model) diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 9408fd87da400..f816f06200743 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from torch import Tensor from typing_extensions import Self @@ -17,7 +17,7 @@ class Collective(ABC): """ def __init__(self) -> None: - self._group: Optional[CollectibleGroup] = None + self._group: CollectibleGroup | None = None @property @abstractmethod @@ -65,10 +65,10 @@ def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[T def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... @abstractmethod - def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ... + def recv(self, tensor: Tensor, src: int | None = None, tag: int = 0) -> Tensor: ... @abstractmethod - def barrier(self, device_ids: Optional[list[int]] = None) -> None: ... + def barrier(self, device_ids: list[int] | None = None) -> None: ... @classmethod @abstractmethod diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 182e75f4583ef..7c2bca6dbc8d5 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, Optional, Union +from typing import Any import torch import torch.distributed as dist @@ -56,13 +56,13 @@ def broadcast(self, tensor: Tensor, src: int) -> Tensor: return tensor @override - def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor: + def all_reduce(self, tensor: Tensor, op: str | ReduceOp | RedOpType = "sum") -> Tensor: op = self._convert_to_native_op(op) dist.all_reduce(tensor, op=op, group=self.group) return tensor @override - def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor: + def reduce(self, tensor: Tensor, dst: int, op: str | ReduceOp | RedOpType = "sum") -> Tensor: op = self._convert_to_native_op(op) dist.reduce(tensor, dst, op=op, group=self.group) # type: ignore[arg-type] return tensor @@ -84,7 +84,7 @@ def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> T @override def reduce_scatter( - self, output: Tensor, input_list: list[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" + self, output: Tensor, input_list: list[Tensor], op: str | ReduceOp | RedOpType = "sum" ) -> Tensor: op = self._convert_to_native_op(op) dist.reduce_scatter(output, input_list, op=op, group=self.group) @@ -100,7 +100,7 @@ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: dist.send(tensor, dst, tag=tag, group=self.group) # type: ignore[arg-type] @override - def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: + def recv(self, tensor: Tensor, src: int | None = None, tag: int = 0) -> Tensor: dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type] return tensor @@ -108,9 +108,7 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]: dist.all_gather_object(object_list, obj, group=self.group) return object_list - def broadcast_object_list( - self, object_list: list[Any], src: int, device: Optional[torch.device] = None - ) -> list[Any]: + def broadcast_object_list(self, object_list: list[Any], src: int, device: torch.device | None = None) -> list[Any]: dist.broadcast_object_list(object_list, src, group=self.group, device=device) # type: ignore[arg-type] return object_list @@ -125,16 +123,16 @@ def scatter_object_list( return scatter_object_output_list @override - def barrier(self, device_ids: Optional[list[int]] = None) -> None: + def barrier(self, device_ids: list[int] | None = None) -> None: if self.group == dist.GroupMember.NON_GROUP_MEMBER: return dist.barrier(group=self.group, device_ids=device_ids) # type: ignore[arg-type] - def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None: + def monitored_barrier(self, timeout: datetime.timedelta | None = None, wait_all_ranks: bool = False) -> None: dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks) # type: ignore[arg-type] @override - def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self: + def setup(self, main_address: str | None = None, main_port: str | None = None, **kwargs: Any) -> Self: if self.is_initialized(): return self # maybe set addr @@ -203,7 +201,7 @@ def destroy_group(cls, group: CollectibleGroup) -> None: @classmethod @override - def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]: + def _convert_to_native_op(cls, op: str | ReduceOp | RedOpType) -> ReduceOp | RedOpType: # `ReduceOp` is an empty shell for `RedOpType`, the latter being the actually returned class. # For example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where # `ReduceOp` is still the desired class, but it's created via a special `_make_nccl_premul_sum` function diff --git a/src/lightning/fabric/plugins/environments/mpi.py b/src/lightning/fabric/plugins/environments/mpi.py index dd4897663d187..e4d092d724319 100644 --- a/src/lightning/fabric/plugins/environments/mpi.py +++ b/src/lightning/fabric/plugins/environments/mpi.py @@ -15,7 +15,6 @@ import logging import socket from functools import lru_cache -from typing import Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -42,10 +41,10 @@ def __init__(self) -> None: from mpi4py import MPI self._comm_world = MPI.COMM_WORLD - self._comm_local: Optional[MPI.Comm] = None - self._node_rank: Optional[int] = None - self._main_address: Optional[str] = None - self._main_port: Optional[int] = None + self._comm_local: MPI.Comm | None = None + self._node_rank: int | None = None + self._main_address: str | None = None + self._main_port: int | None = None @property @override diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py index 4d98b7ed6a8eb..78e61d7fa98e5 100644 --- a/src/lightning/fabric/plugins/environments/slurm.py +++ b/src/lightning/fabric/plugins/environments/slurm.py @@ -18,7 +18,6 @@ import shutil import signal import sys -from typing import Optional from typing_extensions import override @@ -44,7 +43,7 @@ class SLURMEnvironment(ClusterEnvironment): """ - def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None: + def __init__(self, auto_requeue: bool = True, requeue_signal: signal.Signals | None = None) -> None: super().__init__() self.auto_requeue = auto_requeue if requeue_signal is None and not _IS_WINDOWS: @@ -113,11 +112,11 @@ def detect() -> bool: return _is_srun_used() @staticmethod - def job_name() -> Optional[str]: + def job_name() -> str | None: return os.environ.get("SLURM_JOB_NAME") @staticmethod - def job_id() -> Optional[int]: + def job_id() -> int | None: # in interactive mode, don't make logs use the same job id if _is_slurm_interactive_mode(): return None diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index db7578d9ca8c0..eb1c148725cdf 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from lightning.fabric.utilities.types import _PATH @@ -36,7 +36,7 @@ class CheckpointIO(ABC): """ @abstractmethod - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -48,7 +48,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @abstractmethod def load_checkpoint( - self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None + self, path: _PATH, map_location: Any | None = None, weights_only: bool | None = None ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index c52ad6913e1e2..179154f951817 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -13,7 +13,8 @@ # limitations under the License. import logging import os -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from typing_extensions import override @@ -34,7 +35,7 @@ class TorchCheckpointIO(CheckpointIO): """ @override - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -61,8 +62,8 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio def load_checkpoint( self, path: _PATH, - map_location: Optional[Callable] = lambda storage, loc: storage, - weights_only: Optional[bool] = None, + map_location: Callable | None = lambda storage, loc: storage, + weights_only: bool | None = None, ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 146fa2f33b510..40a51faffef36 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Optional +from typing import Any import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -41,7 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @override - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index d5fc1f0c1cc2a..bdce089d3cfe3 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -72,7 +72,7 @@ def convert_output(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) @override - def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + def backward(self, tensor: Tensor, model: Module | None, *args: Any, **kwargs: Any) -> None: if self.scaler is not None: tensor = self.scaler.scale(tensor) super().backward(tensor, model, *args, **kwargs) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 4c648f2b97181..466f21d4946ef 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -17,10 +17,11 @@ import os import warnings from collections import OrderedDict +from collections.abc import Callable from contextlib import AbstractContextManager, ExitStack from functools import partial from types import ModuleType -from typing import Any, Callable, Literal, Optional, cast +from typing import Any, Literal, cast import torch from lightning_utilities import apply_to_collection @@ -70,8 +71,8 @@ class BitsandbytesPrecision(Precision): def __init__( self, mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"], - dtype: Optional[torch.dtype] = None, - ignore_modules: Optional[set[str]] = None, + dtype: torch.dtype | None = None, + ignore_modules: set[str] | None = None, ) -> None: _import_bitsandbytes() @@ -176,7 +177,7 @@ def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _In def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None + param: torch.nn.Parameter, data: torch.Tensor, quant_state: tuple | None = None ) -> torch.nn.Parameter: bnb = _import_bitsandbytes() @@ -223,10 +224,10 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt): """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantizaton when loading the state dict.""" - def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None: + def __init__(self, *args: Any, device: _DEVICE | None = None, threshold: float = 6.0, **kwargs: Any) -> None: super().__init__(*args, device=device, threshold=threshold, **kwargs) self.weight = cast(bnb.nn.Int8Params, self.weight) # type: ignore[has-type] - self.bias: Optional[torch.nn.Parameter] = self.bias + self.bias: torch.nn.Parameter | None = self.bias # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up # filling the device memory with float32 weights which could lead to OOM if torch.tensor(0, device=device).device.type == "cuda": @@ -234,7 +235,7 @@ def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: floa self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: + def quantize_(self, weight: torch.Tensor | None = None, device: torch.device | None = None) -> None: """Inplace quantize.""" if weight is None: weight = self.weight.data @@ -246,7 +247,7 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc @staticmethod def quantize( - int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device] + int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: torch.device | None ) -> bnb.nn.Int8Params: device = device or torch.device("cuda") if device.type != "cuda": @@ -310,10 +311,10 @@ class _Linear4bit(bnb.nn.Linear4bit): """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the state dict, meta-device initialization, and materialization.""" - def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None: + def __init__(self, *args: Any, device: _DEVICE | None = None, **kwargs: Any) -> None: super().__init__(*args, device=device, **kwargs) self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type] - self.bias: Optional[torch.nn.Parameter] = self.bias + self.bias: torch.nn.Parameter | None = self.bias # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up # filling the device memory with float32 weights which could lead to OOM if torch.tensor(0, device=device).device.type == "cuda": @@ -321,7 +322,7 @@ def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: + def quantize_(self, weight: torch.Tensor | None = None, device: torch.device | None = None) -> None: """Inplace quantize.""" if weight is None: weight = self.weight.data @@ -334,7 +335,7 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc @staticmethod def quantize( - params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device] + params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: torch.device | None ) -> bnb.nn.Params4bit: device = device or torch.device("cuda") if device.type != "cuda": diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 526095008f376..b258e882e3afc 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, get_args import torch from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module -from typing_extensions import get_args, override +from typing_extensions import override from lightning.fabric.plugins.precision.precision import Precision from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 270a67e3a2338..c3af1ba1e952c 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, get_args import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from typing_extensions import get_args, override +from typing_extensions import override from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling from lightning.fabric.plugins.precision.precision import Precision @@ -129,7 +129,7 @@ def convert_output(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) @override - def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + def backward(self, tensor: Tensor, model: Module | None, *args: Any, **kwargs: Any) -> None: if self.scaler is not None: tensor = self.scaler.scale(tensor) super().backward(tensor, model, *args, **kwargs) diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index 1dfab2a7bc649..6bb2b233d650a 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import AbstractContextManager, nullcontext -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from torch import Tensor from torch.nn import Module @@ -33,7 +33,7 @@ "32-true", "64-true", ] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS] +_PRECISION_INPUT = _PRECISION_INPUT_INT | _PRECISION_INPUT_STR | _PRECISION_INPUT_STR_ALIAS class Precision: @@ -87,7 +87,7 @@ def convert_output(self, data: Any) -> Any: """ return data - def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any: + def pre_backward(self, tensor: Tensor, module: Module | None) -> Any: """Runs before precision plugin executes backward. Args: @@ -96,7 +96,7 @@ def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any: """ - def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + def backward(self, tensor: Tensor, model: Module | None, *args: Any, **kwargs: Any) -> None: """Performs the actual backpropagation. Args: @@ -106,7 +106,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs """ tensor.backward(*args, **kwargs) - def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any: + def post_backward(self, tensor: Tensor, module: Module | None) -> Any: """Runs after precision plugin executes backward. Args: diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index bf1e51ea6b2b0..6acefe3eeea7d 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -14,7 +14,7 @@ import logging from collections.abc import Mapping from contextlib import AbstractContextManager, ExitStack -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union import torch from lightning_utilities import apply_to_collection @@ -68,9 +68,9 @@ def __init__( self, *, weights_dtype: torch.dtype, - recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None, - replace_layers: Optional[bool] = None, - fallback_compute_dtype: Optional[torch.dtype] = None, + recipe: Union[Mapping[str, Any], "DelayedScaling"] | None = None, + replace_layers: bool | None = None, + fallback_compute_dtype: torch.dtype | None = None, ) -> None: if not _TRANSFORMER_ENGINE_AVAILABLE: raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE)) diff --git a/src/lightning/fabric/plugins/precision/utils.py b/src/lightning/fabric/plugins/precision/utils.py index 8362384cb1042..a511d82a5c3a9 100644 --- a/src/lightning/fabric/plugins/precision/utils.py +++ b/src/lightning/fabric/plugins/precision/utils.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Mapping -from typing import Any, Union +from typing import Any import torch from torch import Tensor -def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor: +def _convert_fp_tensor(tensor: Tensor, dst_type: str | torch.dtype) -> Tensor: return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py index fdb30032b3cdd..0e1d7a82287e0 100644 --- a/src/lightning/fabric/plugins/precision/xla.py +++ b/src/lightning/fabric/plugins/precision/xla.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Literal +from typing import Any, Literal, get_args import torch -from typing_extensions import get_args, override +from typing_extensions import override from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.precision.precision import Precision diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index af182ad7f422f..ca8b7b8865de7 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import AbstractContextManager, nullcontext from datetime import timedelta -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import torch import torch.distributed @@ -55,13 +55,13 @@ class DDPStrategy(ParallelStrategy): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision: Precision | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen", **kwargs: Any, ) -> None: @@ -73,8 +73,8 @@ def __init__( precision=precision, ) self._num_nodes = 1 - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout self._start_method = start_method self._backward_sync_control = _DDPBackwardSyncControl() self._ddp_kwargs = kwargs @@ -104,7 +104,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @override @@ -134,9 +134,7 @@ def module_to_device(self, module: Module) -> None: module.to(self.root_device) @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def all_reduce(self, tensor: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -182,15 +180,13 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj[0] @override - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Any | Tensor]: if isinstance(module, DistributedDataParallel): module = module.module return super().get_module_state_dict(module) @override - def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True - ) -> None: + def load_module_state_dict(self, module: Module, state_dict: dict[str, Any | Tensor], strict: bool = True) -> None: if isinstance(module, DistributedDataParallel): module = module.module super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict) @@ -239,7 +235,7 @@ def _set_world_ranks(self) -> None: # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - def _determine_ddp_device_ids(self) -> Optional[list[int]]: + def _determine_ddp_device_ids(self) -> list[int] | None: return None if self.root_device.type == "cpu" else [self.root_device.index] diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 883546fea1f2d..35efa5ec6109d 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,12 +16,12 @@ import logging import os import platform -from collections.abc import Mapping +from collections.abc import Callable, Mapping from contextlib import AbstractContextManager, ExitStack from datetime import timedelta from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from lightning_utilities.core.imports import RequirementCache @@ -57,10 +57,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): def __init__( self, - accelerator: Optional[Accelerator] = None, + accelerator: Accelerator | None = None, zero_optimization: bool = True, stage: int = 2, - remote_device: Optional[str] = None, + remote_device: str | None = None, offload_optimizer: bool = False, offload_parameters: bool = False, offload_params_device: str = "cpu", @@ -84,11 +84,11 @@ def __init__( allgather_bucket_size: int = 200_000_000, reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, - logging_batch_size_per_gpu: Optional[int] = None, - config: Optional[Union[_PATH, dict[str, Any]]] = None, + logging_batch_size_per_gpu: int | None = None, + config: _PATH | dict[str, Any] | None = None, logging_level: int = logging.WARN, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -99,9 +99,9 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - precision: Optional[Precision] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + precision: Precision | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, exclude_frozen_parameters: bool = False, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large @@ -262,7 +262,7 @@ def __init__( process_group_backend=process_group_backend, ) self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally - self._timeout: Optional[timedelta] = timeout + self._timeout: timedelta | None = timeout self.config = self._load_config(config) if self.config is None: @@ -316,7 +316,7 @@ def __init__( self.hysteresis = hysteresis self.min_loss_scale = min_loss_scale - self._deepspeed_engine: Optional[DeepSpeedEngine] = None + self._deepspeed_engine: DeepSpeedEngine | None = None @property def zero_stage_3(self) -> bool: @@ -374,7 +374,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: raise NotImplementedError(self._err_msg_joint_setup_required()) @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: if self.zero_stage_3 and empty_init is False: raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." @@ -404,9 +404,9 @@ def module_sharded_context(self) -> AbstractContextManager: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state in a checkpoint directory. @@ -471,9 +471,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None, strict: bool = True, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. @@ -554,8 +554,8 @@ def clip_gradients_norm( self, module: "DeepSpeedEngine", optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = True, ) -> torch.Tensor: raise NotImplementedError( @@ -564,9 +564,7 @@ def clip_gradients_norm( ) @override - def clip_gradients_value( - self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int] - ) -> None: + def clip_gradients_value(self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: float | int) -> None: raise NotImplementedError( "DeepSpeed handles gradient clipping automatically within the optimizer. " "Make sure to set the `gradient_clipping` value in your Config." @@ -614,7 +612,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: ) def _initialize_engine( - self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None + self, model: Module, optimizer: Optimizer | None = None, scheduler: Optional["_LRScheduler"] = None ) -> tuple["DeepSpeedEngine", Optimizer, Any]: """Initialize one model and one optimizer with an optional learning rate scheduler. @@ -716,7 +714,7 @@ def _create_default_config( self, zero_optimization: bool, zero_allow_untested_optimizer: bool, - logging_batch_size_per_gpu: Optional[int], + logging_batch_size_per_gpu: int | None, partition_activations: bool, cpu_checkpointing: bool, contiguous_memory_optimization: bool, @@ -825,7 +823,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None: load(module, prefix="") - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + def _load_config(self, config: _PATH | dict[str, Any] | None) -> dict[str, Any] | None: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py index f407040649c54..e8f8d4e656781 100644 --- a/src/lightning/fabric/strategies/dp.py +++ b/src/lightning/fabric/strategies/dp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any import torch from torch import Tensor @@ -34,10 +34,10 @@ class DataParallelStrategy(ParallelStrategy): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + checkpoint_io: CheckpointIO | None = None, + precision: Precision | None = None, ): super().__init__( accelerator=accelerator, @@ -68,13 +68,13 @@ def module_to_device(self, module: Module) -> None: module.to(self.root_device) @override - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + def batch_to_device(self, batch: Any, device: torch.device | None = None) -> Any: # DataParallel handles the transfer of batch to the device return batch @override def all_reduce( - self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + self, collection: TReduce, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean" ) -> TReduce: def mean(t: Tensor) -> Tensor: original_dtype = t.dtype @@ -95,15 +95,13 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: return decision @override - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Any | Tensor]: if isinstance(module, DataParallel): module = module.module return super().get_module_state_dict(module) @override - def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True - ) -> None: + def load_module_state_dict(self, module: Module, state_dict: dict[str, Any | Tensor], strict: bool = True) -> None: if isinstance(module, DataParallel): module = module.module super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index f42ade7484395..f8c5d158c5d3b 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import shutil import warnings -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import AbstractContextManager, ExitStack, nullcontext from datetime import timedelta from functools import partial @@ -21,9 +21,9 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Literal, Optional, + TypeGuard, Union, ) @@ -33,7 +33,7 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from typing_extensions import TypeGuard, override +from typing_extensions import override from lightning.fabric.accelerators import Accelerator from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, Precision @@ -73,8 +73,8 @@ from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.optim.lr_scheduler import _LRScheduler - _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] - _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] + _POLICY = set[type[Module]] | Callable[[Module, bool, int], bool] | ModuleWrapPolicy + _SHARDING_STRATEGY = ShardingStrategy | Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"] _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload") @@ -137,20 +137,20 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - precision: Optional[Precision] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + precision: Precision | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, + activation_checkpointing: type[Module] | list[type[Module]] | None = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", - device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, + device_mesh: Union[tuple[int], "DeviceMesh"] | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -160,8 +160,8 @@ def __init__( precision=precision, ) self._num_nodes = 1 - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout self._backward_sync_control = _FSDPBackwardSyncControl() self._fsdp_kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs) @@ -215,7 +215,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @property @@ -238,7 +238,7 @@ def precision(self) -> FSDPPrecision: @precision.setter @override - def precision(self, precision: Optional[Precision]) -> None: + def precision(self, precision: Precision | None) -> None: if precision is not None and not isinstance(precision, FSDPPrecision): raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}") self._precision = precision @@ -335,7 +335,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -363,9 +363,7 @@ def module_sharded_context(self) -> AbstractContextManager: ) @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def all_reduce(self, tensor: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean") -> Tensor: if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor @@ -393,8 +391,8 @@ def clip_gradients_norm( self, module: Module, optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = True, ) -> Tensor: """Clip gradients by norm.""" @@ -414,9 +412,9 @@ def clip_gradients_norm( def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -514,9 +512,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None, strict: bool = True, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -682,7 +680,7 @@ def _set_world_ranks(self) -> None: def _activation_checkpointing_kwargs( - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]], + activation_checkpointing: type[Module] | list[type[Module]] | None, activation_checkpointing_policy: Optional["_POLICY"], ) -> dict: if activation_checkpointing is None and activation_checkpointing_policy is None: @@ -761,7 +759,7 @@ def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextMana return module.no_sync() -def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": +def _init_cpu_offload(cpu_offload: Union[bool, "CPUOffload"] | None) -> "CPUOffload": from torch.distributed.fsdp import CPUOffload return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) diff --git a/src/lightning/fabric/strategies/launchers/launcher.py b/src/lightning/fabric/strategies/launchers/launcher.py index c22a14633eb76..0b266a55cb9b8 100644 --- a/src/lightning/fabric/strategies/launchers/launcher.py +++ b/src/lightning/fabric/strategies/launchers/launcher.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable +from collections.abc import Callable +from typing import Any class _Launcher(ABC): diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index 3b3e180e63f41..a29e6bd3113e5 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -13,10 +13,11 @@ # limitations under the License. import itertools import os +from collections.abc import Callable from dataclasses import dataclass from multiprocessing.queues import SimpleQueue from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import torch import torch.backends.cudnn diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index 8a78eb3c7dfbf..31121fd471807 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -18,8 +18,8 @@ import sys import threading import time -from collections.abc import Sequence -from typing import Any, Callable, Optional +from collections.abc import Callable, Sequence +from typing import Any from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -131,7 +131,7 @@ def _call_children_scripts(self) -> None: # start process # if hydra is available and initialized, make sure to set the cwd correctly hydra_in_use = False - cwd: Optional[str] = None + cwd: str | None = None if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py index 639de55805646..e6d0c6994cc27 100644 --- a/src/lightning/fabric/strategies/launchers/xla.py +++ b/src/lightning/fabric/strategies/launchers/xla.py @@ -13,9 +13,11 @@ # limitations under the License. import queue import time -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Union import torch.multiprocessing as mp +from torch.multiprocessing.queue import SimpleQueue from typing_extensions import override from lightning.fabric.accelerators.xla import _XLA_AVAILABLE @@ -68,7 +70,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: **kwargs: Optional keyword arguments to be passed to the given function. """ - return_queue: Union[queue.Queue, mp.SimpleQueue] + return_queue: queue.Queue | SimpleQueue return_queue = mp.Manager().Queue() import torch_xla.distributed.xla_multiprocessing as xmp @@ -96,8 +98,8 @@ def _wrapping_function( function: Callable, args: Any, kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], - global_states: Optional[_GlobalStateSnapshot] = None, + return_queue: queue.Queue | SimpleQueue, + global_states: _GlobalStateSnapshot | None = None, ) -> None: import torch_xla.core.xla_model as xm diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 677584668975e..2c275843bae9d 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -13,18 +13,18 @@ # limitations under the License. import itertools import shutil -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import AbstractContextManager, ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeGuard, TypeVar import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from typing_extensions import TypeGuard, override +from typing_extensions import override from lightning.fabric.plugins import CheckpointIO from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout @@ -89,11 +89,11 @@ class ModelParallelStrategy(ParallelStrategy): def __init__( self, parallelize_fn: Callable[[TModel, "DeviceMesh"], TModel], - data_parallel_size: Union[Literal["auto"], int] = "auto", - tensor_parallel_size: Union[Literal["auto"], int] = "auto", + data_parallel_size: Literal["auto"] | int = "auto", + tensor_parallel_size: Literal["auto"] | int = "auto", save_distributed_checkpoint: bool = True, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, ) -> None: super().__init__() if not _TORCH_GREATER_EQUAL_2_4: @@ -103,11 +103,11 @@ def __init__( self._tensor_parallel_size = tensor_parallel_size self._num_nodes = 1 self._save_distributed_checkpoint = save_distributed_checkpoint - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout self._backward_sync_control = _ParallelBackwardSyncControl() - self._device_mesh: Optional[DeviceMesh] = None + self._device_mesh: DeviceMesh | None = None @property def device_mesh(self) -> "DeviceMesh": @@ -151,7 +151,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @override @@ -195,7 +195,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() stack = ExitStack() if empty_init: @@ -206,9 +206,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont return stack @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def all_reduce(self, tensor: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean") -> Tensor: if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor @@ -235,9 +233,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -273,9 +271,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None, strict: bool = True, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -348,10 +346,10 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def _save_checkpoint( path: Path, - state: dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Module | Optimizer | Any], full_state_dict: bool, rank: int, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") @@ -409,10 +407,10 @@ def _save_checkpoint( def _load_checkpoint( path: Path, - state: dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Module | Optimizer | Any], strict: bool = True, optimizer_states_from_list: bool = False, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, diff --git a/src/lightning/fabric/strategies/parallel.py b/src/lightning/fabric/strategies/parallel.py index 327cfc016d4ef..4ad0aecf56f79 100644 --- a/src/lightning/fabric/strategies/parallel.py +++ b/src/lightning/fabric/strategies/parallel.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Optional +from typing import Any import torch from torch import Tensor @@ -32,15 +32,15 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision: Precision | None = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment + self.cluster_environment: ClusterEnvironment | None = cluster_environment @property def global_rank(self) -> int: @@ -64,15 +64,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[list[torch.device]]: + def parallel_devices(self) -> list[torch.device] | None: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: list[torch.device] | None) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: + def distributed_sampler_kwargs(self) -> dict[str, Any] | None: """Arguments for the ``DistributedSampler``. If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used. @@ -81,7 +81,7 @@ def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform a all_gather on all processes.""" return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py index d2376463c3111..afdfe078528a7 100644 --- a/src/lightning/fabric/strategies/registry.py +++ b/src/lightning/fabric/strategies/registry.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from typing_extensions import override @@ -44,8 +45,8 @@ def __init__(self, a, b): def register( self, name: str, - strategy: Optional[Callable] = None, - description: Optional[str] = None, + strategy: Callable | None = None, + description: str | None = None, override: bool = False, **init_params: Any, ) -> Callable: @@ -82,7 +83,7 @@ def do_register(strategy: Callable) -> Callable: return do_register @override - def get(self, name: str, default: Optional[Any] = None) -> Any: + def get(self, name: str, default: Any | None = None) -> Any: """Calls the registered strategy with the required parameters and returns the strategy object. Args: diff --git a/src/lightning/fabric/strategies/single_xla.py b/src/lightning/fabric/strategies/single_xla.py index ba2fce91f1146..317de0068bd64 100644 --- a/src/lightning/fabric/strategies/single_xla.py +++ b/src/lightning/fabric/strategies/single_xla.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import torch from typing_extensions import override @@ -31,9 +30,9 @@ class SingleDeviceXLAStrategy(SingleDeviceStrategy): def __init__( self, device: _DEVICE, - accelerator: Optional[Accelerator] = None, - checkpoint_io: Optional[XLACheckpointIO] = None, - precision: Optional[XLAPrecision] = None, + accelerator: Accelerator | None = None, + checkpoint_io: XLACheckpointIO | None = None, + precision: XLAPrecision | None = None, ): if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) @@ -61,7 +60,7 @@ def checkpoint_io(self) -> XLACheckpointIO: @checkpoint_io.setter @override - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: if io is not None and not isinstance(io, XLACheckpointIO): raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}") self._checkpoint_io = io @@ -77,7 +76,7 @@ def precision(self) -> XLAPrecision: @precision.setter @override - def precision(self, precision: Optional[Precision]) -> None: + def precision(self, precision: Precision | None) -> None: if precision is not None and not isinstance(precision, XLAPrecision): raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}") self._precision = precision diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index b368f626c3b11..faf9bb6575d0b 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -13,9 +13,9 @@ # limitations under the License. import logging from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import AbstractContextManager, ExitStack -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar import torch from torch import Tensor @@ -47,17 +47,17 @@ class Strategy(ABC): def __init__( self, - accelerator: Optional[Accelerator] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, + accelerator: Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision: Precision | None = None, ) -> None: - self._accelerator: Optional[Accelerator] = accelerator - self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io - self._precision: Optional[Precision] = None + self._accelerator: Accelerator | None = accelerator + self._checkpoint_io: CheckpointIO | None = checkpoint_io + self._precision: Precision | None = None # Call the precision setter for input validation self.precision = precision - self._launcher: Optional[_Launcher] = None - self._backward_sync_control: Optional[_BackwardSyncControl] = None + self._launcher: _Launcher | None = None + self._backward_sync_control: _BackwardSyncControl | None = None @property @abstractmethod @@ -70,11 +70,11 @@ def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" @property - def launcher(self) -> Optional[_Launcher]: + def launcher(self) -> _Launcher | None: return self._launcher @property - def accelerator(self) -> Optional[Accelerator]: + def accelerator(self) -> Accelerator | None: return self._accelerator @accelerator.setter @@ -96,7 +96,7 @@ def precision(self) -> Precision: return self._precision if self._precision is not None else Precision() @precision.setter - def precision(self, precision: Optional[Precision]) -> None: + def precision(self, precision: Precision | None) -> None: self._precision = precision def _configure_launcher(self) -> None: @@ -129,7 +129,7 @@ def tensor_init_context(self) -> AbstractContextManager: stack.enter_context(precision_init_ctx) return stack - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: """A context manager wrapping the model instantiation. Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other @@ -172,7 +172,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: def module_to_device(self, module: Module) -> None: """Moves the model to the correct device.""" - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + def batch_to_device(self, batch: Any, device: torch.device | None = None) -> Any: """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just @@ -186,7 +186,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> device = device or self.root_device return move_data_to_device(batch, device) - def backward(self, tensor: Tensor, module: Optional[Module], *args: Any, **kwargs: Any) -> None: + def backward(self, tensor: Tensor, module: Module | None, *args: Any, **kwargs: Any) -> None: r"""Forwards backward-calls to the precision plugin.""" self.precision.pre_backward(tensor, module) self.precision.backward(tensor, module, *args, **kwargs) @@ -207,7 +207,7 @@ def optimizer_step( return self.precision.optimizer_step(optimizer, **kwargs) @abstractmethod - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform an all_gather on all processes. Args: @@ -220,10 +220,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo @abstractmethod def all_reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, Any]: + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", + ) -> Tensor | Any: """Reduces the given tensor (e.g. across GPUs/processes). Args: @@ -235,7 +235,7 @@ def all_reduce( """ @abstractmethod - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: """Synchronizes all processes which blocks processes until the whole group enters this function. Args: @@ -260,9 +260,9 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. @@ -280,13 +280,11 @@ def save_checkpoint( if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options) - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Any | Tensor]: """Returns model state.""" return module.state_dict() - def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True - ) -> None: + def load_module_state_dict(self, module: Module, state_dict: dict[str, Any | Tensor], strict: bool = True) -> None: """Loads the given state into the model.""" module.load_state_dict(state_dict, strict=strict) @@ -308,9 +306,9 @@ def get_optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None, strict: bool = True, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. @@ -371,8 +369,8 @@ def clip_gradients_norm( self, module: torch.nn.Module, optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = True, ) -> torch.Tensor: """Clip gradients by norm.""" @@ -382,7 +380,7 @@ def clip_gradients_norm( parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite ) - def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None: + def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, clip_val: float | int) -> None: """Clip gradients by value.""" self.precision.unscale_gradients(optimizer) parameters = self.precision.main_params(optimizer) @@ -399,7 +397,7 @@ def _err_msg_joint_setup_required(self) -> str: ) def _convert_stateful_objects_in_state( - self, state: dict[str, Union[Module, Optimizer, Any]], filter: dict[str, Callable[[str, Any], bool]] + self, state: dict[str, Module | Optimizer | Any], filter: dict[str, Callable[[str, Any], bool]] ) -> dict[str, Any]: converted_state: dict[str, Any] = {} for key, obj in state.items(): diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 3a571fef37f00..3737416ec878c 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import torch from torch import Tensor @@ -42,10 +43,10 @@ class XLAStrategy(ParallelStrategy): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - checkpoint_io: Optional[XLACheckpointIO] = None, - precision: Optional[XLAPrecision] = None, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + checkpoint_io: XLACheckpointIO | None = None, + precision: XLAPrecision | None = None, sync_module_states: bool = True, ) -> None: super().__init__( @@ -83,7 +84,7 @@ def checkpoint_io(self) -> XLACheckpointIO: @checkpoint_io.setter @override - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: if io is not None and not isinstance(io, XLACheckpointIO): raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}") self._checkpoint_io = io @@ -99,7 +100,7 @@ def precision(self) -> XLAPrecision: @precision.setter @override - def precision(self, precision: Optional[Precision]) -> None: + def precision(self, precision: Precision | None) -> None: if precision is not None and not isinstance(precision, XLAPrecision): raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}") self._precision = precision @@ -174,7 +175,7 @@ def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": return dataloader @override - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. Args: @@ -205,7 +206,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo @override def all_reduce( - self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None + self, output: Tensor | Any, group: Any | None = None, reduce_op: ReduceOp | str | None = None ) -> Tensor: if not isinstance(output, Tensor): output = torch.tensor(output, device=self.root_device) @@ -227,7 +228,7 @@ def all_reduce( return output @override - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + def barrier(self, name: str | None = None, *args: Any, **kwargs: Any) -> None: if not self._launched: return import torch_xla.core.xla_model as xm @@ -276,9 +277,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 51b528eff26ff..e6c57987a7542 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +from collections.abc import Callable from contextlib import AbstractContextManager, ExitStack, nullcontext from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional import torch from torch import Tensor @@ -48,7 +49,7 @@ from torch_xla.distributed.parallel_loader import MpDeviceLoader _POLICY_SET = set[type[Module]] -_POLICY = Union[_POLICY_SET, Callable[[Module, bool, int], bool]] +_POLICY = _POLICY_SET | Callable[[Module, bool, int], bool] class XLAFSDPStrategy(ParallelStrategy, _Sharded): @@ -83,12 +84,12 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, - checkpoint_io: Optional[XLACheckpointIO] = None, - precision: Optional[XLAPrecision] = None, - auto_wrap_policy: Optional[_POLICY] = None, - activation_checkpointing_policy: Optional[_POLICY_SET] = None, + accelerator: Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + checkpoint_io: XLACheckpointIO | None = None, + precision: XLAPrecision | None = None, + auto_wrap_policy: _POLICY | None = None, + activation_checkpointing_policy: _POLICY_SET | None = None, state_dict_type: Literal["full", "sharded"] = "sharded", sequential_save: bool = False, **kwargs: Any, @@ -135,7 +136,7 @@ def checkpoint_io(self) -> XLACheckpointIO: @checkpoint_io.setter @override - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: if io is not None and not isinstance(io, XLACheckpointIO): raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}") self._checkpoint_io = io @@ -151,7 +152,7 @@ def precision(self) -> XLAPrecision: @precision.setter @override - def precision(self, precision: Optional[Precision]) -> None: + def precision(self, precision: Precision | None) -> None: if precision is not None and not isinstance(precision, XLAPrecision): raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}") self._precision = precision @@ -226,7 +227,7 @@ def setup_module(self, module: Module) -> Module: def module_to_device(self, module: Module) -> None: pass - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -290,8 +291,8 @@ def clip_gradients_norm( self, module: Module, optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = True, ) -> Tensor: """Clip gradients by norm.""" @@ -300,7 +301,7 @@ def clip_gradients_norm( return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) @override - def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None: + def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: float | int) -> None: """Clip gradients by value.""" raise NotImplementedError( "XLA's FSDP strategy does not support to clip gradients by value." @@ -308,7 +309,7 @@ def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: U ) @override - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. Args: @@ -339,7 +340,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo @override def all_reduce( - self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None + self, output: Tensor | Any, group: Any | None = None, reduce_op: ReduceOp | str | None = None ) -> Tensor: if not isinstance(output, Tensor): output = torch.tensor(output, device=self.root_device) @@ -361,7 +362,7 @@ def all_reduce( return output @override - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + def barrier(self, name: str | None = None, *args: Any, **kwargs: Any) -> None: if not self._launched: return import torch_xla.core.xla_model as xm @@ -410,9 +411,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None = None, + filter: dict[str, Callable[[str, Any], bool]] | None = None, ) -> None: """Save model, optimizer, and other state in the provided checkpoint directory. @@ -485,9 +486,9 @@ def save_checkpoint( def _save_checkpoint_shard( self, path: Path, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any], - filter: Optional[dict[str, Callable[[str, Any], bool]]], + state: dict[str, Module | Optimizer | Any], + storage_options: Any | None, + filter: dict[str, Callable[[str, Any], bool]] | None, ) -> None: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP @@ -514,9 +515,9 @@ def _save_checkpoint_shard( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None, strict: bool = True, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. @@ -652,7 +653,7 @@ def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, return XLAFSDP(module, *args, **kwargs) -def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: +def _activation_checkpointing_kwargs(policy: _POLICY_SET | None, kwargs: dict) -> dict: if not policy: return kwargs if "auto_wrapper_callable" in kwargs: diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index 35693a5fcf1fb..04128c5fe1cda 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -14,8 +14,9 @@ """Utilities used for collections.""" from abc import ABC +from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -68,7 +69,7 @@ class _TransferableDataType(ABC): """ @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + def __subclasshook__(cls, subclass: Any) -> bool | Any: if cls is _TransferableDataType: to = getattr(subclass, "to", None) return callable(to) @@ -126,7 +127,7 @@ def convert_tensors_to_scalars(data: Any) -> Any: """ - def to_item(value: Tensor) -> Union[int, float, bool]: + def to_item(value: Tensor) -> int | float | bool: if value.numel() != 1: raise ValueError( f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar." diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 54b18fb6ce3b0..0ad0b7b78dcde 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -17,7 +17,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Optional, Union +from typing import IO, Any import fsspec import fsspec.utils @@ -32,9 +32,9 @@ def _load( - path_or_url: Union[IO, _PATH], + path_or_url: IO | _PATH, map_location: _MAP_LOCATION_TYPE = None, - weights_only: Optional[bool] = None, + weights_only: bool | None = None, ) -> Any: """Loads a checkpoint. @@ -131,7 +131,7 @@ def _is_object_storage(fs: AbstractFileSystem) -> bool: return False -def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False) -> bool: +def _is_dir(fs: AbstractFileSystem, path: str | Path, strict: bool = False) -> bool: """Check if a path is directory-like. This function determines if a given path is considered directory-like, taking into account the behavior diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index ea35d8c3da4a9..7ddd62dba49bf 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -16,14 +16,13 @@ import inspect import os from collections import OrderedDict -from collections.abc import Generator, Iterable, Sized +from collections.abc import Callable, Generator, Iterable, Sized from contextlib import contextmanager from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, TypeGuard from lightning_utilities.core.inheritance import get_all_subclasses from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler -from typing_extensions import TypeGuard from lightning.fabric.utilities.enums import LightningEnum from lightning.fabric.utilities.exceptions import MisconfigurationException @@ -36,7 +35,7 @@ class _WrapAttrTag(LightningEnum): DEL = "del" def __call__(self, *args: Any) -> None: - fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]] + fn: Callable[[object, str], None] | Callable[[object, str, Any], None] fn = setattr if self == self.SET else delattr return fn(*args) @@ -45,7 +44,7 @@ def has_iterable_dataset(dataloader: object) -> bool: return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) -def sized_len(dataloader: object) -> Optional[int]: +def sized_len(dataloader: object) -> int | None: """Try to get the length of an object, return ``None`` otherwise.""" try: # try getting the length @@ -72,14 +71,14 @@ def has_len(dataloader: object) -> TypeGuard[Sized]: return length is not None -def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader: +def _update_dataloader(dataloader: DataLoader, sampler: Sampler | Iterable) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler) return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], + sampler: Sampler | Iterable, ) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -172,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], + sampler: Sampler | Iterable, ) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation.""" @@ -250,7 +249,7 @@ def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[type] = None, **kwargs: Any) -> Any: +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: type | None = None, **kwargs: Any) -> Any: constructor = type(orig_object) if explicit_cls is None else explicit_cls try: @@ -281,7 +280,7 @@ def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optio return result -def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: +def _wrap_init_method(init: Callable, store_explicit_arg: str | None = None) -> Callable: """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @@ -356,7 +355,7 @@ def wrapper(obj: Any, *args: Any) -> None: @contextmanager -def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: type, store_explicit_arg: str | None = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index 527ed90203e46..cd110d25132a1 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any import torch from torch.nn import Module @@ -26,17 +26,17 @@ class _DeviceDtypeModuleMixin(Module): def __init__(self) -> None: super().__init__() - self._dtype: Union[str, torch.dtype] = torch.get_default_dtype() + self._dtype: str | torch.dtype = torch.get_default_dtype() # Workarounds from the original pytorch issue: # https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687 self._device = torch.get_default_device() if _TORCH_GREATER_EQUAL_2_3 else torch.empty(0).device @property - def dtype(self) -> Union[str, torch.dtype]: + def dtype(self) -> str | torch.dtype: return self._dtype @dtype.setter - def dtype(self, new_dtype: Union[str, torch.dtype]) -> None: + def dtype(self, new_dtype: str | torch.dtype) -> None: # necessary to avoid infinite recursion raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).") @@ -59,7 +59,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self: return super().to(*args, **kwargs) @override - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: + def cuda(self, device: torch.device | int | None = None) -> Self: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -86,7 +86,7 @@ def cpu(self) -> Self: return super().cpu() @override - def type(self, dst_type: Union[str, torch.dtype]) -> Self: + def type(self, dst_type: str | torch.dtype) -> Self: """See :meth:`torch.nn.Module.type`.""" _update_properties(self, dtype=dst_type) return super().type(dst_type=dst_type) @@ -111,7 +111,7 @@ def half(self) -> Self: def _update_properties( - root: torch.nn.Module, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None + root: torch.nn.Module, device: torch.device | None = None, dtype: str | torch.dtype | None = None ) -> None: for module in root.modules(): if not isinstance(module, _DeviceDtypeModuleMixin): diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 8bdacc0f523f5..86c820a3142fa 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import MutableSequence -from typing import Optional, Union import torch @@ -20,7 +19,7 @@ from lightning.fabric.utilities.types import _DEVICE -def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: +def _determine_root_gpu_device(gpus: list[_DEVICE]) -> _DEVICE | None: """ Args: gpus: Non-empty list of ints representing which GPUs to use @@ -47,10 +46,10 @@ def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: def _parse_gpu_ids( - gpus: Optional[Union[int, str, list[int]]], + gpus: int | str | list[int] | None, include_cuda: bool = False, include_mps: bool = False, -) -> Optional[list[int]]: +) -> list[int] | None: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: @@ -103,7 +102,7 @@ def _parse_gpu_ids( return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) -def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]: +def _normalize_parse_gpu_string_input(s: int | str | list[int]) -> int | list[int]: if not isinstance(s, str): return s if s == "-1": @@ -140,8 +139,8 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool -) -> Optional[list[int]]: + gpus: int | list[int] | tuple[int, ...], include_cuda: bool, include_mps: bool +) -> list[int] | None: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index a5f9f7457862e..8e9ec56d7753b 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -8,13 +8,13 @@ from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, TypeGuard import torch import torch.nn.functional as F from torch import Tensor from torch.utils.data import Dataset, DistributedSampler, Sampler -from typing_extensions import Self, TypeGuard, override +from typing_extensions import Self, override from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.data import _num_cpus_available @@ -40,7 +40,7 @@ class group: # type: ignore log = logging.getLogger(__name__) -def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, timeout: int = 3) -> bool: +def is_shared_filesystem(strategy: "Strategy", path: _PATH | None = None, timeout: int = 3) -> bool: """Checks whether the filesystem under the given path is shared across all processes. This function should only be used in a context where distributed is initialized. @@ -99,7 +99,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim return all_found -def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: +def _gather_all_tensors(result: Tensor, group: Any | None = None) -> list[Tensor]: """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case @@ -159,9 +159,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> l return gathered_result -def _sync_ddp_if_available( - result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None -) -> Tensor: +def _sync_ddp_if_available(result: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = None) -> Tensor: """Function to reduce a tensor across worker processes during distributed training. Args: @@ -179,7 +177,7 @@ def _sync_ddp_if_available( return result -def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor: +def _sync_ddp(result: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = None) -> Tensor: """Reduces a tensor across several distributed processes. This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes. @@ -197,7 +195,7 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U divide_by_world_size = False group = torch.distributed.group.WORLD if group is None else group - op: Optional[ReduceOp] + op: ReduceOp | None if isinstance(reduce_op, str): reduce_op = "avg" if reduce_op == "mean" else reduce_op if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo": @@ -251,8 +249,8 @@ def _all_gather_ddp_if_available( def _init_dist_connection( cluster_environment: "ClusterEnvironment", torch_distributed_backend: str, - global_rank: Optional[int] = None, - world_size: Optional[int] = None, + global_rank: int | None = None, + world_size: int | None = None, **kwargs: Any, ) -> None: """Utility function to initialize distributed connection by setting env variables and initializing the distributed @@ -314,7 +312,7 @@ def _get_default_process_group_backend_for_device(device: torch.device) -> str: class _DatasetSamplerWrapper(Dataset): """Dataset to create indexes from `Sampler` or `Iterable`""" - def __init__(self, sampler: Union[Sampler, Iterable]) -> None: + def __init__(self, sampler: Sampler | Iterable) -> None: if not isinstance(sampler, Sized): raise TypeError( "You seem to have configured a sampler in your DataLoader which" @@ -335,7 +333,7 @@ def __init__(self, sampler: Union[Sampler, Iterable]) -> None: ) self._sampler = sampler # defer materializing an iterator until it is necessary - self._sampler_list: Optional[list[Any]] = None + self._sampler_list: list[Any] | None = None @override def __getitem__(self, index: int) -> Any: @@ -364,7 +362,7 @@ class DistributedSamplerWrapper(DistributedSampler): """ - def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: + def __init__(self, sampler: Sampler | Iterable, *args: Any, **kwargs: Any) -> None: super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs) @override diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 4f8519eec9610..3d9c6651099b1 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any import torch from torch.nn import Module, Parameter @@ -47,7 +47,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, ) -> Any: kwargs = kwargs or {} if not self.enabled: @@ -105,7 +105,7 @@ def _materialize_distributed_module(module: Module, device: torch.device) -> Non ) -def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: +def _has_meta_device_parameters_or_buffers(obj: Module | Optimizer, recurse: bool = True) -> bool: if isinstance(obj, Optimizer): return any( t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 9e158c1677b6d..d7f901c5ea74a 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -14,11 +14,11 @@ import pickle import warnings from collections import OrderedDict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial from io import BytesIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -102,7 +102,7 @@ def rebuild_tensor_v2( stride: tuple, requires_grad: bool, backward_hooks: OrderedDict, - metadata: Optional[Any] = None, + metadata: Any | None = None, *, archiveinfo: "_LazyLoadingUnpickler", ) -> "_NotYetLoadedTensor": @@ -136,7 +136,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, ) -> Any: kwargs = kwargs or {} loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args] @@ -221,7 +221,7 @@ def _load_tensor(t: _NotYetLoadedTensor) -> Tensor: def _move_state_into( - source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None + source: dict[str, Any], destination: dict[str, Any | _Stateful], keys: set[str] | None = None ) -> None: """Takes the state from the source destination and moves it into the destination dictionary. diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index 04b9069dd0788..fb082bc59a5eb 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -17,14 +17,14 @@ from argparse import Namespace from collections.abc import Mapping, MutableMapping from dataclasses import asdict, is_dataclass -from typing import Any, Optional, Union +from typing import Any from torch import Tensor from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE -def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[str, Any]: +def _convert_params(params: dict[str, Any] | Namespace | None) -> dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. Args: @@ -164,9 +164,7 @@ def _is_json_serializable(value: Any) -> bool: return False -def _add_prefix( - metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str -) -> Mapping[str, Union[Tensor, float]]: +def _add_prefix(metrics: Mapping[str, Tensor | float], prefix: str, separator: str) -> Mapping[str, Tensor | float]: """Insert prefix before each key in a dict, separated by the separator. Args: diff --git a/src/lightning/fabric/utilities/rank_zero.py b/src/lightning/fabric/utilities/rank_zero.py index d34e19430b107..c600acd19bc1c 100644 --- a/src/lightning/fabric/utilities/rank_zero.py +++ b/src/lightning/fabric/utilities/rank_zero.py @@ -15,7 +15,6 @@ import logging import os -from typing import Optional import lightning_utilities.core.rank_zero as rank_zero_module @@ -32,7 +31,7 @@ rank_zero_module.log = logging.getLogger(__name__) -def _get_rank() -> Optional[int]: +def _get_rank() -> int | None: # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # therefore LOCAL_RANK needs to be checked first rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 7d8f6ca17712e..7d8896e1b24e5 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -15,7 +15,7 @@ from importlib.metadata import entry_points from inspect import getmembers, isclass from types import ModuleType -from typing import Any, Union +from typing import Any from lightning_utilities import is_overridden @@ -43,7 +43,7 @@ def _load_external_callbacks(group: str) -> list[Any]: external_callbacks: list[Any] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: Union[list[Any], Any] = callback_factory() + callbacks_list: list[Any] | Any = callback_factory() callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list if callbacks_list: _log.info( diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index 841fa195696a2..1f53b00720f4d 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,7 +3,7 @@ import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Optional +from typing import Any import torch @@ -17,7 +17,7 @@ min_seed_value = 0 -def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int: +def seed_everything(seed: int | None = None, workers: bool = False, verbose: bool = True) -> int: r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module. In addition, sets the following environment variables: @@ -82,7 +82,7 @@ def reset_seed() -> None: seed_everything(int(seed), workers=bool(int(workers)), verbose=False) -def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover +def pl_worker_init_function(worker_id: int, rank: int | None = None) -> None: # pragma: no cover r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with ``seed_everything(seed, workers=True)``. diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 9c1b0a2a00572..c3d4525f4a106 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -1,7 +1,7 @@ import json import os import warnings -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import torch @@ -40,9 +40,9 @@ def __init__( mode: Literal["min", "max"] = "min", window: int = 10, warmup: int = 1, - atol: Optional[float] = None, - rtol: Optional[float] = 2.0, - exclude_batches_path: Optional[_PATH] = None, + atol: float | None = None, + rtol: float | None = 2.0, + exclude_batches_path: _PATH | None = None, finite_only: bool = True, ): if _TORCHMETRICS_GREATER_EQUAL_1_0_0: @@ -52,7 +52,7 @@ def __init__( raise RuntimeError("SpikeDetection requires `torchmetrics>=1.0.0` Please upgrade your version.") super().__init__() - self.last_val: Union[torch.Tensor, float] = 0.0 + self.last_val: torch.Tensor | float = 0.0 # spike detection happens individually on each machine self.running_mean = Running(MeanMetric(dist_sync_on_step=False, sync_on_compute=False), window=window) # workaround for https://github.com/Lightning-AI/torchmetrics/issues/1899 @@ -125,10 +125,10 @@ def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None: raise TrainingSpikeException(batch_idx=batch_idx) - def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool: + def _check_atol(self, val_a: float | torch.Tensor, val_b: float | torch.Tensor) -> bool: return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol)) - def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool: + def _check_rtol(self, val_a: float | torch.Tensor, val_b: float | torch.Tensor) -> bool: return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b)) def _is_better(self, diff_val: torch.Tensor) -> bool: diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index d085e4138d742..ca0c3e367b983 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -14,7 +14,6 @@ import operator import os import sys -from typing import Optional import torch from lightning_utilities.core.imports import compare_version @@ -30,12 +29,12 @@ def _runif_reasons( *, min_cuda_gpus: int = 0, - min_torch: Optional[str] = None, - max_torch: Optional[str] = None, - min_python: Optional[str] = None, + min_torch: str | None = None, + max_torch: str | None = None, + min_python: str | None = None, bf16_cuda: bool = False, tpu: bool = False, - mps: Optional[bool] = None, + mps: bool | None = None, skip_windows: bool = False, standalone: bool = False, deepspeed: bool = False, diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 6bc329fa1c3be..6e2ffa96bf2d7 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -13,7 +13,8 @@ # limitations under the License. # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeVar import torch from typing_extensions import override @@ -24,7 +25,7 @@ from lightning.fabric import Fabric from lightning.fabric.plugins import Precision -_THROUGHPUT_METRICS = dict[str, Union[int, float]] +_THROUGHPUT_METRICS = dict[str, int | float] # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's @@ -92,7 +93,7 @@ class Throughput: """ def __init__( - self, available_flops: Optional[float] = None, world_size: int = 1, window_size: int = 100, separator: str = "/" + self, available_flops: float | None = None, world_size: int = 1, window_size: int = 100, separator: str = "/" ) -> None: self.available_flops = available_flops self.separator = separator @@ -116,8 +117,8 @@ def update( time: float, batches: int, samples: int, - lengths: Optional[int] = None, - flops: Optional[int] = None, + lengths: int | None = None, + flops: int | None = None, ) -> None: """Update throughput metrics. @@ -249,7 +250,7 @@ def __init__(self, fabric: "Fabric", **kwargs: Any) -> None: self.compute_and_log = rank_zero_only(self.compute_and_log, default={}) # type: ignore[method-assign] self.reset = rank_zero_only(self.reset) # type: ignore[method-assign] - def compute_and_log(self, step: Optional[int] = None, **kwargs: Any) -> _THROUGHPUT_METRICS: + def compute_and_log(self, step: int | None = None, **kwargs: Any) -> _THROUGHPUT_METRICS: r"""See :meth:`Throughput.compute` Args: @@ -266,7 +267,7 @@ def compute_and_log(self, step: Optional[int] = None, **kwargs: Any) -> _THROUGH def measure_flops( model: torch.nn.Module, forward_fn: Callable[[], torch.Tensor], - loss_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + loss_fn: Callable[[torch.Tensor], torch.Tensor] | None = None, ) -> int: """Utility to compute the total number of FLOPs used by a module during training or during inference. @@ -302,7 +303,7 @@ def measure_flops( return flop_counter.get_total_flops() -_CUDA_FLOPS: dict[str, dict[Union[str, torch.dtype], float]] = { +_CUDA_FLOPS: dict[str, dict[str | torch.dtype, float]] = { # Hopper # source: https://nvdam.widen.net/s/nb5zzzsjdf/hpc-datasheet-sc23-h200-datasheet-3002446 "h200 sxm1": { @@ -543,7 +544,7 @@ def measure_flops( } -def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) -> Optional[int]: +def get_available_flops(device: torch.device, dtype: torch.dtype | str) -> int | None: """Returns the available theoretical FLOPs. This is an optimistic upper limit that could only be achievable if only thick matmuls were run in a benchmark @@ -678,7 +679,7 @@ def __init__(self, maxlen: int) -> None: self.maxlen = maxlen @property - def last(self) -> Optional[T]: + def last(self) -> T | None: if len(self) > 0: return self[-1] return None diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index 365ecbcefcb42..a0eed9adb54b3 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -12,29 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Callable, Iterator from pathlib import Path from typing import ( Any, - Callable, - Optional, Protocol, + TypeAlias, TypeVar, - Union, runtime_checkable, ) import torch from torch import Tensor -from typing_extensions import TypeAlias, overload +from typing_extensions import overload UntypedStorage: TypeAlias = torch.UntypedStorage -_PATH = Union[str, Path] -_DEVICE = Union[torch.device, str, int] -_MAP_LOCATION_TYPE = Optional[ - Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[_DEVICE, _DEVICE]] -] +_PATH = str | Path +_DEVICE = torch.device | str | int +_MAP_LOCATION_TYPE = _DEVICE | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[_DEVICE, _DEVICE] | None _PARAMETERS = Iterator[torch.nn.Parameter] if torch.distributed.is_available(): @@ -74,7 +70,7 @@ def step(self, closure: None = ...) -> None: ... @overload def step(self, closure: Callable[[], float]) -> float: ... - def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: ... + def step(self, closure: Callable[[], float] | None = ...) -> float | None: ... @runtime_checkable diff --git a/src/lightning/fabric/utilities/warnings.py b/src/lightning/fabric/utilities/warnings.py index b62bece384e32..56595f786bc14 100644 --- a/src/lightning/fabric/utilities/warnings.py +++ b/src/lightning/fabric/utilities/warnings.py @@ -15,7 +15,6 @@ import warnings from pathlib import Path -from typing import Optional, Union from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning @@ -38,7 +37,7 @@ def disable_possible_user_warnings(module: str = "") -> None: def _custom_format_warning( - message: Union[Warning, str], category: type[Warning], filename: str, lineno: int, line: Optional[str] = None + message: Warning | str, category: type[Warning], filename: str, lineno: int, line: str | None = None ) -> str: """Custom formatting that avoids an extra line in case warnings are emitted from the `rank_zero`-functions.""" if _is_path_in_lightning(Path(filename)): diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index b593c9f22ed23..88d8f167a15e6 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from collections.abc import Generator, Iterator, Mapping +from collections.abc import Callable, Generator, Iterator, Mapping from copy import deepcopy from functools import partial, wraps from types import MethodType from typing import ( Any, - Callable, - Optional, TypeVar, - Union, overload, ) @@ -50,7 +47,7 @@ class _FabricOptimizer: - def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None: + def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: list[Callable] | None = None) -> None: """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer step calls to the strategy. @@ -77,7 +74,7 @@ def state_dict(self) -> dict[str, Tensor]: def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: self.optimizer.load_state_dict(state_dict) - def step(self, closure: Optional[Callable] = None) -> Any: + def step(self, closure: Callable | None = None) -> Any: kwargs = {"closure": closure} if closure is not None else {} if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable): # only DeepSpeed defines this @@ -99,9 +96,7 @@ def __getattr__(self, item: Any) -> Any: class _FabricModule(_DeviceDtypeModuleMixin): - def __init__( - self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None - ) -> None: + def __init__(self, forward_module: nn.Module, strategy: Strategy, original_module: nn.Module | None = None) -> None: """The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. @@ -148,8 +143,8 @@ def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, A @override def state_dict( - self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False - ) -> Optional[dict[str, Any]]: + self, destination: T_destination | None = None, prefix: str = "", keep_vars: bool = False + ) -> dict[str, Any] | None: return self._original_module.state_dict( destination=destination, # type: ignore[type-var] prefix=prefix, @@ -162,7 +157,7 @@ def load_state_dict( # type: ignore[override] ) -> _IncompatibleKeys: return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs) - def mark_forward_method(self, method: Union[MethodType, str]) -> None: + def mark_forward_method(self, method: MethodType | str) -> None: """Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP).""" if not isinstance(method, (MethodType, str)): raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}") @@ -291,7 +286,7 @@ def __setattr__(self, name: str, value: Any) -> None: class _FabricDataLoader: - def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: + def __init__(self, dataloader: DataLoader, device: torch.device | None = None) -> None: """The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the device automatically if the device is specified. @@ -307,13 +302,13 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None self._num_iter_calls = 0 @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return self._device def __len__(self) -> int: return len(self._dataloader) - def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + def __iter__(self) -> Iterator[Any] | Generator[Any, None, None]: # Without setting the epoch, the distributed sampler would return the same indices every time, even when # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler. # In Fabric, we take care of this boilerplate code. @@ -329,8 +324,8 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: def _unwrap_objects(collection: Any) -> Any: def _unwrap( - obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader], - ) -> Union[nn.Module, Optimizer, DataLoader]: + obj: _FabricModule | _FabricOptimizer | _FabricDataLoader, + ) -> nn.Module | Optimizer | DataLoader: if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule): return _unwrap_compiled(unwrapped._forward_module)[0] if isinstance(obj, _FabricOptimizer): @@ -345,7 +340,7 @@ def _unwrap( return apply_to_collection(collection, dtype=tuple(types), function=_unwrap) -def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]: +def _unwrap_compiled(obj: Any | OptimizedModule) -> tuple[Any | nn.Module, dict[str, Any] | None]: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. diff --git a/src/lightning/pytorch/_graveyard/_torchmetrics.py b/src/lightning/pytorch/_graveyard/_torchmetrics.py index 82e3ad2dcf549..b2cd99c614659 100644 --- a/src/lightning/pytorch/_graveyard/_torchmetrics.py +++ b/src/lightning/pytorch/_graveyard/_torchmetrics.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torchmetrics from lightning_utilities.core.imports import compare_version as _compare_version diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index e1338cdd73698..970bcdbf01955 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any import torch from lightning_utilities.core.imports import RequirementCache @@ -48,13 +48,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str]) -> int: + def parse_devices(devices: int | str) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: + def get_parallel_devices(devices: int | str) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 63a0d8adba8ea..44fdc0a49409c 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -15,7 +15,7 @@ import os import shutil import subprocess -from typing import Any, Optional, Union +from typing import Any import torch from typing_extensions import override @@ -83,7 +83,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_cuda=True) diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index a999411f5ed8b..e29e65bc129cd 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any import torch from typing_extensions import override @@ -53,13 +53,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_mps=True) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None diff --git a/src/lightning/pytorch/callbacks/batch_size_finder.py b/src/lightning/pytorch/callbacks/batch_size_finder.py index 7b3c6c65f674e..5d64e0265915e 100644 --- a/src/lightning/pytorch/callbacks/batch_size_finder.py +++ b/src/lightning/pytorch/callbacks/batch_size_finder.py @@ -18,8 +18,6 @@ Finds optimal batch size """ -from typing import Optional - from typing_extensions import override import lightning.pytorch as pl @@ -133,7 +131,7 @@ def __init__( assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}" - self.optimal_batch_size: Optional[int] = init_val + self.optimal_batch_size: int | None = init_val self._mode = mode self._steps_per_trial = steps_per_trial self._init_val = init_val @@ -144,7 +142,7 @@ def __init__( self._early_exit = False @override - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str | None = None) -> None: if trainer._accelerator_connector.is_distributed: raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.") # TODO: check if this can be enabled (#4040) diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 873c4c05f5aed..02cc15574ffc2 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -19,7 +19,7 @@ """ -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -115,7 +115,7 @@ class DeviceStatsMonitor(Callback): """ - def __init__(self, cpu_stats: Optional[bool] = None) -> None: + def __init__(self, cpu_stats: bool | None = None) -> None: self._cpu_stats = cpu_stats @override diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 7569705b9d4ea..221b36809f828 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -20,8 +20,9 @@ """ import logging +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Optional +from typing import Any import torch from torch import Tensor @@ -120,9 +121,9 @@ def __init__( mode: str = "min", strict: bool = True, check_finite: bool = True, - stopping_threshold: Optional[float] = None, - divergence_threshold: Optional[float] = None, - check_on_train_epoch_end: Optional[bool] = None, + stopping_threshold: float | None = None, + divergence_threshold: float | None = None, + check_on_train_epoch_end: bool | None = None, log_rank_zero_only: bool = False, ): super().__init__() @@ -138,7 +139,7 @@ def __init__( self.wait_count = 0 self.stopped_epoch = 0 self.stopping_reason = EarlyStoppingReason.NOT_STOPPED - self.stopping_reason_message: Optional[str] = None + self.stopping_reason_message: str | None = None self._check_on_train_epoch_end = check_on_train_epoch_end self.log_rank_zero_only = log_rank_zero_only @@ -243,7 +244,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: if reason and self.verbose: self._log_info(trainer, reason, self.log_rank_zero_only) - def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]: + def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, str | None]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index bb5145ce25b85..82e3efe0de17b 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -19,8 +19,8 @@ """ import logging -from collections.abc import Generator, Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Generator, Iterable +from typing import Any import torch from torch.nn import Module, ModuleDict @@ -124,7 +124,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self._restarting = False @staticmethod - def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> list[Module]: + def flatten_modules(modules: Module | Iterable[Module | Iterable]) -> list[Module]: """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves. @@ -152,7 +152,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) - @staticmethod def filter_params( - modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True + modules: Module | Iterable[Module | Iterable], train_bn: bool = True, requires_grad: bool = True ) -> Generator: """Yields the `requires_grad` parameters of a given module or list of modules. @@ -174,7 +174,7 @@ def filter_params( yield param @staticmethod - def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None: + def make_trainable(modules: Module | Iterable[Module | Iterable]) -> None: """Unfreezes the parameters of the provided modules. Args: @@ -204,7 +204,7 @@ def freeze_module(module: Module) -> None: param.requires_grad = False @staticmethod - def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None: + def freeze(modules: Module | Iterable[Module | Iterable], train_bn: bool = True) -> None: """Freezes the parameters of the provided modules. Args: @@ -253,9 +253,9 @@ def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> list: @staticmethod def unfreeze_and_add_param_group( - modules: Union[Module, Iterable[Union[Module, Iterable]]], + modules: Module | Iterable[Module | Iterable], optimizer: Optimizer, - lr: Optional[float] = None, + lr: float | None = None, initial_denom_lr: float = 10.0, train_bn: bool = True, ) -> None: @@ -410,7 +410,7 @@ def __init__( unfreeze_backbone_at_epoch: int = 10, lambda_func: Callable = multiplicative, backbone_initial_ratio_lr: float = 10e-2, - backbone_initial_lr: Optional[float] = None, + backbone_initial_lr: float | None = None, should_align: bool = True, initial_denom_lr: float = 10.0, train_bn: bool = True, @@ -422,13 +422,13 @@ def __init__( self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch self.lambda_func: Callable = lambda_func self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr - self.backbone_initial_lr: Optional[float] = backbone_initial_lr + self.backbone_initial_lr: float | None = backbone_initial_lr self.should_align: bool = should_align self.initial_denom_lr: float = initial_denom_lr self.train_bn: bool = train_bn self.verbose: bool = verbose self.rounding: int = rounding - self.previous_backbone_lr: Optional[float] = None + self.previous_backbone_lr: float | None = None @override def state_dict(self) -> dict[str, Any]: diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py index f04b2d777deb3..2e50f9a18d48d 100644 --- a/src/lightning/pytorch/callbacks/lambda_function.py +++ b/src/lightning/pytorch/callbacks/lambda_function.py @@ -19,7 +19,7 @@ """ -from typing import Callable, Optional +from collections.abc import Callable from lightning.pytorch.callbacks.callback import Callback @@ -40,43 +40,43 @@ class LambdaCallback(Callback): def __init__( self, - setup: Optional[Callable] = None, - teardown: Optional[Callable] = None, - on_fit_start: Optional[Callable] = None, - on_fit_end: Optional[Callable] = None, - on_sanity_check_start: Optional[Callable] = None, - on_sanity_check_end: Optional[Callable] = None, - on_train_batch_start: Optional[Callable] = None, - on_train_batch_end: Optional[Callable] = None, - on_train_epoch_start: Optional[Callable] = None, - on_train_epoch_end: Optional[Callable] = None, - on_validation_epoch_start: Optional[Callable] = None, - on_validation_epoch_end: Optional[Callable] = None, - on_test_epoch_start: Optional[Callable] = None, - on_test_epoch_end: Optional[Callable] = None, - on_validation_batch_start: Optional[Callable] = None, - on_validation_batch_end: Optional[Callable] = None, - on_test_batch_start: Optional[Callable] = None, - on_test_batch_end: Optional[Callable] = None, - on_train_start: Optional[Callable] = None, - on_train_end: Optional[Callable] = None, - on_validation_start: Optional[Callable] = None, - on_validation_end: Optional[Callable] = None, - on_test_start: Optional[Callable] = None, - on_test_end: Optional[Callable] = None, - on_exception: Optional[Callable] = None, - on_save_checkpoint: Optional[Callable] = None, - on_load_checkpoint: Optional[Callable] = None, - on_before_backward: Optional[Callable] = None, - on_after_backward: Optional[Callable] = None, - on_before_optimizer_step: Optional[Callable] = None, - on_before_zero_grad: Optional[Callable] = None, - on_predict_start: Optional[Callable] = None, - on_predict_end: Optional[Callable] = None, - on_predict_batch_start: Optional[Callable] = None, - on_predict_batch_end: Optional[Callable] = None, - on_predict_epoch_start: Optional[Callable] = None, - on_predict_epoch_end: Optional[Callable] = None, + setup: Callable | None = None, + teardown: Callable | None = None, + on_fit_start: Callable | None = None, + on_fit_end: Callable | None = None, + on_sanity_check_start: Callable | None = None, + on_sanity_check_end: Callable | None = None, + on_train_batch_start: Callable | None = None, + on_train_batch_end: Callable | None = None, + on_train_epoch_start: Callable | None = None, + on_train_epoch_end: Callable | None = None, + on_validation_epoch_start: Callable | None = None, + on_validation_epoch_end: Callable | None = None, + on_test_epoch_start: Callable | None = None, + on_test_epoch_end: Callable | None = None, + on_validation_batch_start: Callable | None = None, + on_validation_batch_end: Callable | None = None, + on_test_batch_start: Callable | None = None, + on_test_batch_end: Callable | None = None, + on_train_start: Callable | None = None, + on_train_end: Callable | None = None, + on_validation_start: Callable | None = None, + on_validation_end: Callable | None = None, + on_test_start: Callable | None = None, + on_test_end: Callable | None = None, + on_exception: Callable | None = None, + on_save_checkpoint: Callable | None = None, + on_load_checkpoint: Callable | None = None, + on_before_backward: Callable | None = None, + on_after_backward: Callable | None = None, + on_before_optimizer_step: Callable | None = None, + on_before_zero_grad: Callable | None = None, + on_predict_start: Callable | None = None, + on_predict_end: Callable | None = None, + on_predict_batch_start: Callable | None = None, + on_predict_batch_end: Callable | None = None, + on_predict_epoch_start: Callable | None = None, + on_predict_epoch_end: Callable | None = None, ): for k, v in locals().items(): if k == "self": diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py index aaadc3c38ed5e..c2a4d6255821e 100644 --- a/src/lightning/pytorch/callbacks/lr_finder.py +++ b/src/lightning/pytorch/callbacks/lr_finder.py @@ -18,8 +18,6 @@ Finds optimal learning rate """ -from typing import Optional - from typing_extensions import override import lightning.pytorch as pl @@ -89,7 +87,7 @@ def __init__( max_lr: float = 1, num_training_steps: int = 100, mode: str = "exponential", - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, update_attr: bool = True, attr_name: str = "", ) -> None: @@ -106,7 +104,7 @@ def __init__( self._attr_name = attr_name self._early_exit = False - self.optimal_lr: Optional[_LRFinder] = None + self.optimal_lr: _LRFinder | None = None def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: with isolate_rng(): diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index 37ecb620bf6d4..002209a75634f 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -22,7 +22,7 @@ import itertools from collections import defaultdict -from typing import Any, Literal, Optional +from typing import Any, Literal import torch from torch.optim.optimizer import Optimizer @@ -93,7 +93,7 @@ def configure_optimizer(self): def __init__( self, - logging_interval: Optional[Literal["step", "epoch"]] = None, + logging_interval: Literal["step", "epoch"] | None = None, log_momentum: bool = False, log_weight_decay: bool = False, ) -> None: @@ -105,8 +105,8 @@ def __init__( self.log_weight_decay = log_weight_decay self.lrs: dict[str, list[float]] = {} - self.last_momentum_values: dict[str, Optional[list[float]]] = {} - self.last_weight_decay_values: dict[str, Optional[list[float]]] = {} + self.last_momentum_values: dict[str, list[float] | None] = {} + self.last_weight_decay_values: dict[str, list[float] | None] = {} @override def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @@ -344,7 +344,7 @@ def _check_duplicates_and_update_name( name: str, seen_optimizers: list[Optimizer], seen_optimizer_types: defaultdict[type[Optimizer], int], - lr_scheduler_config: Optional[LRSchedulerConfig], + lr_scheduler_config: LRSchedulerConfig | None, ) -> list[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 8a5d9dcdf786f..0326cc933e433 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,7 @@ from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from weakref import proxy import torch @@ -227,20 +227,20 @@ class ModelCheckpoint(Checkpoint): def __init__( self, - dirpath: Optional[_PATH] = None, - filename: Optional[str] = None, - monitor: Optional[str] = None, + dirpath: _PATH | None = None, + filename: str | None = None, + monitor: str | None = None, verbose: bool = False, - save_last: Optional[Union[bool, Literal["link"]]] = None, + save_last: bool | Literal["link"] | None = None, save_top_k: int = 1, save_on_exception: bool = False, save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True, - every_n_train_steps: Optional[int] = None, - train_time_interval: Optional[timedelta] = None, - every_n_epochs: Optional[int] = None, - save_on_train_epoch_end: Optional[bool] = None, + every_n_train_steps: int | None = None, + train_time_interval: timedelta | None = None, + every_n_epochs: int | None = None, + save_on_train_epoch_end: bool | None = None, enable_version_counter: bool = True, ): super().__init__() @@ -254,11 +254,11 @@ def __init__( self._save_on_train_epoch_end = save_on_train_epoch_end self._enable_version_counter = enable_version_counter self._last_global_step_saved = 0 # no need to save when no steps were taken - self._last_time_checked: Optional[float] = None - self.current_score: Optional[Tensor] = None + self._last_time_checked: float | None = None + self.current_score: Tensor | None = None self.best_k_models: dict[str, Tensor] = {} self.kth_best_model_path = "" - self.best_model_score: Optional[Tensor] = None + self.best_model_score: Tensor | None = None self.best_model_path = "" self.last_model_path = "" self._last_checkpoint_saved = "" @@ -267,7 +267,7 @@ def __init__( self._defer_save_until_validation: bool = False self.kth_value: Tensor - self.dirpath: Optional[_PATH] + self.dirpath: _PATH | None self.__init_monitor_mode(mode) self.__init_ckpt_dir(dirpath, filename) self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) @@ -574,7 +574,7 @@ def __validate_init_configuration(self) -> None: " configuration. No quantity for top_k to track." ) - def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + def __init_ckpt_dir(self, dirpath: _PATH | None, filename: str | None) -> None: self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): @@ -594,9 +594,9 @@ def __init_monitor_mode(self, mode: str) -> None: def __init_triggers( self, - every_n_train_steps: Optional[int], - every_n_epochs: Optional[int], - train_time_interval: Optional[timedelta], + every_n_train_steps: int | None, + every_n_epochs: int | None, + train_time_interval: timedelta | None, ) -> None: # Default to running once after each validation epoch if neither # every_n_train_steps nor every_n_epochs is set @@ -608,15 +608,15 @@ def __init_triggers( every_n_epochs = every_n_epochs or 0 every_n_train_steps = every_n_train_steps or 0 - self._train_time_interval: Optional[timedelta] = train_time_interval + self._train_time_interval: timedelta | None = train_time_interval self._every_n_epochs: int = every_n_epochs self._every_n_train_steps: int = every_n_train_steps @property - def every_n_epochs(self) -> Optional[int]: + def every_n_epochs(self) -> int | None: return self._every_n_epochs - def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool: + def check_monitor_top_k(self, trainer: "pl.Trainer", current: Tensor | None = None) -> bool: if current is None: return False @@ -637,9 +637,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = def _format_checkpoint_name( self, - filename: Optional[str], + filename: str | None, metrics: dict[str, Tensor], - prefix: Optional[str] = None, + prefix: str | None = None, auto_insert_metric_name: bool = True, ) -> str: if not filename: @@ -674,9 +674,9 @@ def _format_checkpoint_name( def format_checkpoint_name( self, metrics: dict[str, Tensor], - filename: Optional[str] = None, - ver: Optional[int] = None, - prefix: Optional[str] = None, + filename: str | None = None, + ver: int | None = None, + prefix: str | None = None, ) -> str: """Generate a filename according to the defined template. @@ -765,7 +765,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def _get_metric_interpolated_filepath_name( - self, monitor_candidates: dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None + self, monitor_candidates: dict[str, Tensor], trainer: "pl.Trainer", del_filepath: str | None = None ) -> str: filepath = self.format_checkpoint_name(monitor_candidates) @@ -870,7 +870,7 @@ def _update_best_and_save( if del_filepath and self._should_remove_checkpoint(trainer, del_filepath, filepath): self._remove_checkpoint(trainer, del_filepath) - def to_yaml(self, filepath: Optional[_PATH] = None) -> None: + def to_yaml(self, filepath: _PATH | None = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML file.""" best_k = {k: v.item() for k, v in self.best_k_models.items()} diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index ee9ff2f3bd902..74f5fd506a06b 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -23,7 +23,7 @@ """ import logging -from typing import Any, Union +from typing import Any from typing_extensions import override @@ -82,7 +82,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - **self._summarize_kwargs, ) - def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]: + def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> DeepSpeedSummary | Summary: from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy if isinstance(trainer.strategy, DeepSpeedStrategy) and trainer.strategy.zero_stage_3: diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index ce6342c7aa88d..98882767bbb44 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -19,7 +19,7 @@ """ from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import override @@ -122,7 +122,7 @@ def write_on_batch_end( trainer: "pl.Trainer", pl_module: "pl.LightningModule", prediction: Any, - batch_indices: Optional[Sequence[int]], + batch_indices: Sequence[int] | None, batch: Any, batch_idx: int, dataloader_idx: int, diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 4c965038cb294..4b0e9035fb1ea 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any from typing_extensions import override @@ -48,8 +48,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """ def __init__(self) -> None: - self._trainer: Optional[pl.Trainer] = None - self._current_eval_dataloader_idx: Optional[int] = None + self._trainer: pl.Trainer | None = None + self._current_eval_dataloader_idx: int | None = None @property def trainer(self) -> "pl.Trainer": @@ -78,7 +78,7 @@ def predict_description(self) -> str: return "Predicting" @property - def total_train_batches(self) -> Union[int, float]: + def total_train_batches(self) -> int | float: """The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training @@ -91,7 +91,7 @@ def total_train_batches(self) -> Union[int, float]: return self.trainer.num_training_batches @property - def total_val_batches_current_dataloader(self) -> Union[int, float]: + def total_val_batches_current_dataloader(self) -> int | float: """The total number of validation batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation @@ -105,7 +105,7 @@ def total_val_batches_current_dataloader(self) -> Union[int, float]: return batches @property - def total_test_batches_current_dataloader(self) -> Union[int, float]: + def total_test_batches_current_dataloader(self) -> int | float: """The total number of testing batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is @@ -119,7 +119,7 @@ def total_test_batches_current_dataloader(self) -> Union[int, float]: return batches @property - def total_predict_batches_current_dataloader(self) -> Union[int, float]: + def total_predict_batches_current_dataloader(self) -> int | float: """The total number of prediction batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader @@ -130,7 +130,7 @@ def total_predict_batches_current_dataloader(self) -> Union[int, float]: return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] @property - def total_val_batches(self) -> Union[int, float]: + def total_val_batches(self) -> int | float: """The total number of validation batches, which may change from epoch to epoch for all val dataloaders. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader @@ -179,7 +179,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> dict[str, Union[int, str, float, dict[str, float]]]: + ) -> dict[str, int | str | float | dict[str, float]]: r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar. @@ -210,7 +210,7 @@ def get_metrics(self, trainer, model): return {**standard_metrics, **pbar_metrics} -def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: +def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, int | str]: r"""Returns the standard metrics displayed in the progress bar. Currently, it only includes the version of the experiment when using a logger. @@ -222,7 +222,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: Dictionary with the standard metrics to be displayed in the progress bar. """ - items_dict: dict[str, Union[int, str]] = {} + items_dict: dict[str, int | str] = {} if trainer.loggers: from lightning.pytorch.loggers.utilities import _version diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index d4c3c916c7ed0..6e652b5b5f177 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -63,7 +63,7 @@ class CustomInfiniteTask(Task): """ @property - def time_remaining(self) -> Optional[float]: + def time_remaining(self) -> float | None: return None class CustomProgress(Progress): @@ -73,7 +73,7 @@ def add_task( self, description: str, start: bool = True, - total: Optional[float] = 100.0, + total: float | None = 100.0, completed: int = 0, visible: bool = True, **fields: Any, @@ -107,7 +107,7 @@ class CustomTimeColumn(ProgressColumn): # Only refresh twice a second to prevent jitter max_refresh = 0.5 - def __init__(self, style: Union[str, Style]) -> None: + def __init__(self, style: str | Style) -> None: self.style = style super().__init__() @@ -119,7 +119,7 @@ def render(self, task: "Task") -> Text: return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style) class BatchesProcessedColumn(ProgressColumn): - def __init__(self, style: Union[str, Style]): + def __init__(self, style: str | Style): self.style = style super().__init__() @@ -128,7 +128,7 @@ def render(self, task: "Task") -> RenderableType: return Text(f"{int(task.completed)}/{total}", style=self.style) class ProcessingSpeedColumn(ProgressColumn): - def __init__(self, style: Union[str, Style]): + def __init__(self, style: str | Style): self.style = style super().__init__() @@ -147,9 +147,9 @@ def __init__( metrics_format: str, ): self._trainer = trainer - self._tasks: dict[Union[int, TaskID], Any] = {} + self._tasks: dict[int | TaskID, Any] = {} self._current_task_id = 0 - self._metrics: dict[Union[str, Style], Any] = {} + self._metrics: dict[str | Style, Any] = {} self._style = style self._text_delimiter = text_delimiter self._metrics_format = metrics_format @@ -261,7 +261,7 @@ def __init__( refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), - console_kwargs: Optional[dict[str, Any]] = None, + console_kwargs: dict[str, Any] | None = None, ) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( @@ -271,17 +271,17 @@ def __init__( super().__init__() self._refresh_rate: int = refresh_rate self._leave: bool = leave - self._console: Optional[Console] = None + self._console: Console | None = None self._console_kwargs = console_kwargs or {} self._enabled: bool = True - self.progress: Optional[CustomProgress] = None - self.train_progress_bar_id: Optional[TaskID] - self.val_sanity_progress_bar_id: Optional[TaskID] = None - self.val_progress_bar_id: Optional[TaskID] - self.test_progress_bar_id: Optional[TaskID] - self.predict_progress_bar_id: Optional[TaskID] + self.progress: CustomProgress | None = None + self.train_progress_bar_id: TaskID | None + self.val_sanity_progress_bar_id: TaskID | None = None + self.val_progress_bar_id: TaskID | None + self.test_progress_bar_id: TaskID | None + self.predict_progress_bar_id: TaskID | None self._reset_progress_bar_ids() - self._metric_component: Optional[MetricsTextColumn] = None + self._metric_component: MetricsTextColumn | None = None self._progress_stopped: bool = False self.theme = theme @@ -453,7 +453,7 @@ def on_validation_batch_start( self.refresh() - def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": + def _add_task(self, total_batches: int | float, description: str, visible: bool = True) -> "TaskID": assert self.progress is not None return self.progress.add_task( f"[{self.theme.description}]{description}" if self.theme.description else description, @@ -474,7 +474,7 @@ def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bo return self.progress.update(progress_bar_id, completed=current, visible=visible) - def _should_update(self, current: int, total: Union[int, float]) -> bool: + def _should_update(self, current: int, total: int | float) -> bool: return current % self.refresh_rate == 0 or current == total @override @@ -634,7 +634,7 @@ def _reset_progress_bar_ids(self) -> None: @override def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> dict[str, Union[int, str, float, dict[str, float]]]: + ) -> dict[str, int | str | float | dict[str, float]]: items = super().get_metrics(trainer, pl_module) # convert all metrics to float before sending to rich return apply_to_collection(items, torch.Tensor, lambda x: x.item()) @@ -643,7 +643,7 @@ def _update_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - current: Optional[int] = None, + current: int | None = None, total_batches: bool = False, ) -> None: if not self.is_enabled or self._metric_component is None: diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 74abb8ecd850c..d8d6f8f0e3112 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -15,7 +15,7 @@ import math import os import sys -from typing import Any, Optional, Union +from typing import Any from typing_extensions import override @@ -44,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @staticmethod - def format_num(n: Union[int, float, str]) -> str: + def format_num(n: int | float | str) -> str: """Add additional padding to the formatted numbers.""" should_be_padded = isinstance(n, (float, str)) if not isinstance(n, str): @@ -109,10 +109,10 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool self._refresh_rate = self._resolve_refresh_rate(refresh_rate) self._process_position = process_position self._enabled = True - self._train_progress_bar: Optional[_tqdm] = None - self._val_progress_bar: Optional[_tqdm] = None - self._test_progress_bar: Optional[_tqdm] = None - self._predict_progress_bar: Optional[_tqdm] = None + self._train_progress_bar: _tqdm | None = None + self._val_progress_bar: _tqdm | None = None + self._test_progress_bar: _tqdm | None = None + self._predict_progress_bar: _tqdm | None = None self._leave = leave def __getstate__(self) -> dict: @@ -450,7 +450,7 @@ def _resolve_refresh_rate(refresh_rate: int) -> int: return refresh_rate -def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: +def convert_inf(x: int | float | None) -> int | float | None: """The tqdm doesn't support inf/nan values. We have to convert it to None. diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index f0e1bcbe49f99..4f48fc9820207 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -18,10 +18,10 @@ import inspect import logging -from collections.abc import Sequence +from collections.abc import Callable, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection @@ -65,17 +65,17 @@ class ModelPruning(Callback): def __init__( self, - pruning_fn: Union[Callable, str], + pruning_fn: Callable | str, parameters_to_prune: _PARAM_LIST = (), - parameter_names: Optional[list[str]] = None, + parameter_names: list[str] | None = None, use_global_unstructured: bool = True, - amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, - apply_pruning: Union[bool, Callable[[int], bool]] = True, + amount: int | float | Callable[[int], int | float] = 0.5, + apply_pruning: bool | Callable[[int], bool] = True, make_pruning_permanent: bool = True, - use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, + use_lottery_ticket_hypothesis: bool | Callable[[int], bool] = True, resample_parameters: bool = False, - pruning_dim: Optional[int] = None, - pruning_norm: Optional[int] = None, + pruning_dim: int | None = None, + pruning_norm: int | None = None, verbose: int = 0, prune_on_train_epoch_end: bool = True, ) -> None: @@ -167,8 +167,8 @@ def __init__( self._prune_on_train_epoch_end = prune_on_train_epoch_end self._parameter_names = parameter_names or self.PARAMETER_NAMES self._global_kwargs: dict[str, Any] = {} - self._original_layers: Optional[dict[int, _LayerRef]] = None - self._pruning_method_name: Optional[str] = None + self._original_layers: dict[int, _LayerRef] | None = None + self._pruning_method_name: str | None = None for name in self._parameter_names: if name not in self.PARAMETER_NAMES: @@ -236,7 +236,7 @@ def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _ """This function can be overridden to control which module to prune.""" return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: + def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Callable | pytorch_prune.BasePruningMethod: """This function takes `pruning_fn`, a function name. IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, @@ -331,7 +331,7 @@ def _get_pruned_stats(module: nn.Module, name: str) -> tuple[int, int]: mask = getattr(module, attr) return (mask == 0).sum().item(), mask.numel() - def apply_pruning(self, amount: Union[int, float]) -> None: + def apply_pruning(self, amount: int | float) -> None: """Applies pruning to ``parameters_to_prune``.""" if self._verbose: prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] @@ -347,7 +347,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None: @rank_zero_only def _log_sparsity_stats( - self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0 + self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: int | float = 0 ) -> None: total_params = sum(total for _, total in curr) prev_total_zeros = sum(zeros for zeros, _ in prev) diff --git a/src/lightning/pytorch/callbacks/spike.py b/src/lightning/pytorch/callbacks/spike.py index b006acd44dcdb..5c3f8e9f5897a 100644 --- a/src/lightning/pytorch/callbacks/spike.py +++ b/src/lightning/pytorch/callbacks/spike.py @@ -1,6 +1,6 @@ import os from collections.abc import Mapping -from typing import Any, Union +from typing import Any import torch @@ -15,7 +15,7 @@ def on_train_batch_end( # type: ignore self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]], + outputs: torch.Tensor | Mapping[str, torch.Tensor], batch: Any, batch_idx: int, ) -> None: diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 79c5423c54084..f13862fe32a4f 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -16,8 +16,9 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Literal, cast import torch from torch import Tensor, nn @@ -39,12 +40,12 @@ class StochasticWeightAveraging(Callback): def __init__( self, - swa_lrs: Union[float, list[float]], - swa_epoch_start: Union[int, float] = 0.8, + swa_lrs: float | list[float], + swa_epoch_start: int | float = 0.8, annealing_epochs: int = 10, annealing_strategy: Literal["cos", "linear"] = "cos", - avg_fn: Optional[_AVG_FN] = None, - device: Optional[Union[torch.device, str]] = torch.device("cpu"), + avg_fn: _AVG_FN | None = None, + device: torch.device | str | None = torch.device("cpu"), ): r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model. @@ -115,21 +116,21 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") - self.n_averaged: Optional[Tensor] = None + self.n_averaged: Tensor | None = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._model_contains_batch_norm: Optional[bool] = None - self._average_model: Optional[pl.LightningModule] = None + self._model_contains_batch_norm: bool | None = None + self._average_model: pl.LightningModule | None = None self._initialized = False - self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[dict] = None + self._swa_scheduler: LRScheduler | None = None + self._scheduler_state: dict | None = None self._init_n_averaged = 0 self._latest_update_epoch = -1 - self.momenta: dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self.momenta: dict[nn.modules.batchnorm._BatchNorm, float | None] = {} self._max_epochs: int @property diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index d38928d33de75..db1b9ebe52f95 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import torch from typing_extensions import override @@ -77,13 +78,13 @@ def sample_forward(): """ def __init__( - self, batch_size_fn: Callable[[Any], int], length_fn: Optional[Callable[[Any], int]] = None, **kwargs: Any + self, batch_size_fn: Callable[[Any], int], length_fn: Callable[[Any], int] | None = None, **kwargs: Any ) -> None: super().__init__() self.kwargs = kwargs self.batch_size_fn = batch_size_fn self.length_fn = length_fn - self.available_flops: Optional[int] = None + self.available_flops: int | None = None self._throughputs: dict[RunningStage, Throughput] = {} self._t0s: dict[RunningStage, float] = {} self._lengths: dict[RunningStage, int] = {} @@ -154,7 +155,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, flops=flops_per_batch, # type: ignore[arg-type] ) - def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None: + def _compute(self, trainer: "Trainer", iter_num: int | None = None) -> None: if not trainer._logger_connector.should_update_logs: return stage = trainer.state.stage @@ -246,7 +247,7 @@ def on_predict_batch_end( self._compute(trainer, iter_num) -def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype: +def _plugin_to_compute_dtype(plugin: FabricPrecision | Precision) -> torch.dtype: # TODO: integrate this into the precision plugins if not isinstance(plugin, Precision): return fabric_plugin_to_compute_dtype(plugin) diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index 91f5fd0e75d9b..9766d251bbd05 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -20,7 +20,7 @@ import re import time from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any from typing_extensions import override @@ -83,7 +83,7 @@ class Timer(Callback): def __init__( self, - duration: Optional[Union[str, timedelta, dict[str, int]]] = None, + duration: str | timedelta | dict[str, int] | None = None, interval: str = Interval.step, verbose: bool = True, ) -> None: @@ -111,16 +111,16 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: dict[RunningStage, Optional[float]] = dict.fromkeys(RunningStage) - self._end_time: dict[RunningStage, Optional[float]] = dict.fromkeys(RunningStage) + self._start_time: dict[RunningStage, float | None] = dict.fromkeys(RunningStage) + self._end_time: dict[RunningStage, float | None] = dict.fromkeys(RunningStage) self._offset = 0 - def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def start_time(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the start time of a particular stage (in seconds)""" stage = RunningStage(stage) return self._start_time[stage] - def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def end_time(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the end time of a particular stage (in seconds)""" stage = RunningStage(stage) return self._end_time[stage] @@ -136,7 +136,7 @@ def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float: return time.monotonic() - start + offset return end - start + offset - def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def time_remaining(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the time remaining for a particular stage (in seconds)""" if self._duration is not None: return self._duration - self.time_elapsed(stage) diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index f9b8d64eae6a5..0eadcd6cfe424 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -18,7 +18,7 @@ import itertools from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any import torch from torch.optim.swa_utils import AveragedModel @@ -89,19 +89,19 @@ def should_update(self, step_idx=None, epoch_idx=None): def __init__( self, - device: Optional[Union[torch.device, str, int]] = None, + device: torch.device | str | int | None = None, use_buffers: bool = True, **kwargs: Any, ) -> None: # The default value is a string so that jsonargparse knows how to serialize it. if isinstance(device, str): - self._device: Optional[Union[torch.device, int]] = torch.device(device) + self._device: torch.device | int | None = torch.device(device) else: self._device = device self._use_buffers = use_buffers self._kwargs = kwargs - self._average_model: Optional[AveragedModel] = None + self._average_model: AveragedModel | None = None # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures # that self.should_update() will be first called after the first optimizer step, which takes place after N @@ -112,7 +112,7 @@ def __init__( # epoch. self._latest_update_epoch = -1 - def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool: + def should_update(self, step_idx: int | None = None, epoch_idx: int | None = None) -> bool: """Called after every optimizer step and after every training epoch to check whether the average model should be updated. diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2bcb1d8f4b1fd..43bf0f869d638 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -14,11 +14,11 @@ import inspect import os import sys -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import partial, update_wrapper from pathlib import Path from types import MethodType -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar import torch import yaml @@ -80,14 +80,14 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] +LRSchedulerTypeUnion = LRScheduler | ReduceLROnPlateau +LRSchedulerType = type[LRScheduler] | type[ReduceLROnPlateau] # Type aliases intended for convenience of CLI developers -ArgsType = Optional[Union[list[str], dict[str, Any], Namespace]] +ArgsType = list[str] | dict[str, Any] | Namespace | None OptimizerCallable = Callable[[Iterable], Optimizer] -LRSchedulerCallable = Callable[[Optimizer], Union[LRScheduler, ReduceLROnPlateau]] +LRSchedulerCallable = Callable[[Optimizer], LRScheduler | ReduceLROnPlateau] class LightningArgumentParser(ArgumentParser): @@ -117,18 +117,16 @@ def __init__( super().__init__(*args, description=description, env_prefix=env_prefix, default_env=default_env, **kwargs) self.callback_keys: list[str] = [] # separate optimizers and lr schedulers to know which were added - self._optimizers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} - self._lr_schedulers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} + self._optimizers: dict[str, tuple[type | tuple[type, ...], str]] = {} + self._lr_schedulers: dict[str, tuple[type | tuple[type, ...], str]] = {} def add_lightning_class_args( self, - lightning_class: Union[ - Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - type[Trainer], - type[LightningModule], - type[LightningDataModule], - type[Callback], - ], + lightning_class: Callable[..., Trainer | LightningModule | LightningDataModule | Callback] + | type[Trainer] + | type[LightningModule] + | type[LightningDataModule] + | type[Callback], nested_key: str, subclass_mode: bool = False, required: bool = True, @@ -169,7 +167,7 @@ def add_lightning_class_args( def add_optimizer_args( self, - optimizer_class: Union[type[Optimizer], tuple[type[Optimizer], ...]] = (Optimizer,), + optimizer_class: type[Optimizer] | tuple[type[Optimizer], ...] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: @@ -194,7 +192,7 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + lr_scheduler_class: LRSchedulerType | tuple[LRSchedulerType, ...] = LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: @@ -321,14 +319,14 @@ class LightningCLI: def __init__( self, - model_class: Optional[Union[type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: Optional[type[SaveConfigCallback]] = SaveConfigCallback, - save_config_kwargs: Optional[dict[str, Any]] = None, - trainer_class: Union[type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[dict[str, Any]] = None, - seed_everything_default: Union[bool, int] = True, - parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, + model_class: type[LightningModule] | Callable[..., LightningModule] | None = None, + datamodule_class: type[LightningDataModule] | Callable[..., LightningDataModule] | None = None, + save_config_callback: type[SaveConfigCallback] | None = SaveConfigCallback, + save_config_kwargs: dict[str, Any] | None = None, + trainer_class: type[Trainer] | Callable[..., Trainer] = Trainer, + trainer_defaults: dict[str, Any] | None = None, + seed_everything_default: bool | int = True, + parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, parser_class: type[LightningArgumentParser] = LightningArgumentParser, subclass_mode_model: bool = False, subclass_mode_data: bool = False, @@ -450,7 +448,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No """Adds default arguments to the parser.""" parser.add_argument( "--seed_everything", - type=Union[bool, int], + type=bool | int, default=self.seed_everything_default, help=( "Set to an int to run seed_everything with this value before classes instantiation." @@ -528,7 +526,7 @@ def _prepare_subcommand_parser(self, klass: type, subcommand: str, **kwargs: Any parser = self.init_parser(**kwargs) self._add_arguments(parser) # subcommand arguments - skip: set[Union[str, int]] = set(self.subcommands()[subcommand]) + skip: set[str | int] = set(self.subcommands()[subcommand]) added = parser.add_method_arguments(klass, subcommand, skip=skip) # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added @@ -654,7 +652,7 @@ def _instantiate_trainer(self, config: dict[str, Any], callbacks: list[Callback] ) return self.trainer_class(**config) - def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: + def _parser(self, subcommand: str | None) -> LightningArgumentParser: if subcommand is None: return self.parser # return the subcommand parser for the subcommand passed @@ -662,7 +660,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: @staticmethod def configure_optimizers( - lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None + lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: LRSchedulerTypeUnion | None = None ) -> Any: """Override to customize the :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` method. @@ -681,7 +679,7 @@ def configure_optimizers( } return [optimizer], [lr_scheduler] - def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: + def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None: """Overrides the model's :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` method if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" if not self.auto_configure_optimizers: @@ -690,7 +688,7 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - parser = self._parser(subcommand) def get_automatic( - class_type: Union[type, tuple[type, ...]], register: dict[str, tuple[Union[type, tuple[type, ...]], str]] + class_type: type | tuple[type, ...], register: dict[str, tuple[type | tuple[type, ...], str]] ) -> list[str]: automatic = [] for key, (base_class, link_to) in register.items(): @@ -743,7 +741,7 @@ def get_automatic( # override the existing method self.model.configure_optimizers = MethodType(fn, self.model) - def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any: + def _get(self, config: Namespace, key: str, default: Any | None = None) -> Any: """Utility to get a config value which might be inside a subcommand.""" return config.get(str(self.subcommand), config).get(key, default) @@ -792,9 +790,7 @@ def _class_path_from_class(class_type: type) -> str: return class_type.__module__ + "." + class_type.__name__ -def _global_add_class_path( - class_type: type, init_args: Optional[Union[Namespace, dict[str, Any]]] = None -) -> dict[str, Any]: +def _global_add_class_path(class_type: type, init_args: Namespace | dict[str, Any] | None = None) -> dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} @@ -807,7 +803,7 @@ def add_class_path(init_args: Namespace) -> dict[str, Any]: return add_class_path -def instantiate_class(args: Union[Any, tuple[Any, ...]], init: dict[str, Any]) -> Any: +def instantiate_class(args: Any | tuple[Any, ...], init: dict[str, Any]) -> Any: """Instantiates a class with the given args and init. Args: @@ -827,7 +823,7 @@ def instantiate_class(args: Union[Any, tuple[Any, ...]], init: dict[str, Any]) - return args_class(*args, **kwargs) -def _get_short_description(component: object) -> Optional[str]: +def _get_short_description(component: object) -> str | None: if component.__doc__ is None: return None try: @@ -837,7 +833,7 @@ def _get_short_description(component: object) -> Optional[str]: rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") -def _get_module_type(value: Union[Callable, type]) -> type: +def _get_module_type(value: Callable | type) -> type: if callable(value) and not isinstance(value, type): return inspect.signature(value).return_annotation return value diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 07ec02ef87bd8..e137766743fd2 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -16,7 +16,7 @@ import inspect import os from collections.abc import Iterable, Sized -from typing import IO, Any, Optional, Union, cast +from typing import IO, Any, cast from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -75,7 +75,7 @@ def teardown(self): """ - name: Optional[str] = None + name: str | None = None CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" @@ -83,15 +83,15 @@ def teardown(self): def __init__(self) -> None: super().__init__() # Pointer to the trainer object - self.trainer: Optional[pl.Trainer] = None + self.trainer: pl.Trainer | None = None @classmethod def from_datasets( cls, - train_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - val_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - test_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - predict_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, + train_dataset: Dataset | Iterable[Dataset] | None = None, + val_dataset: Dataset | Iterable[Dataset] | None = None, + test_dataset: Dataset | Iterable[Dataset] | None = None, + predict_dataset: Dataset | Iterable[Dataset] | None = None, batch_size: int = 1, num_workers: int = 0, **datamodule_kwargs: Any, @@ -174,10 +174,10 @@ def on_exception(self, exception: BaseException) -> None: @_restricted_classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[_PATH, IO], + checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, - weights_only: Optional[bool] = None, + hparams_file: _PATH | None = None, + weights_only: bool | None = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the @@ -274,14 +274,14 @@ def retrieve_dataset_info(loader: DataLoader) -> dataset_info: return dataset_info(True, size) def loader_info( - loader: Union[DataLoader, Iterable[DataLoader]], - ) -> Union[dataset_info, Iterable[dataset_info]]: + loader: DataLoader | Iterable[DataLoader], + ) -> dataset_info | Iterable[dataset_info]: """Helper function to compute dataset information.""" return apply_to_collection(loader, DataLoader, retrieve_dataset_info) def extract_loader_info(methods: list[tuple[str, str]]) -> dict: """Helper function to extract information for each dataloader method.""" - info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} + info: dict[str, dataset_info | Iterable[dataset_info]] = {} for loader_name, func_name in methods: loader_method = getattr(self, func_name, None) @@ -293,7 +293,7 @@ def extract_loader_info(methods: list[tuple[str, str]]) -> dict: return info - def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str: + def format_loader_info(info: dict[str, dataset_info | Iterable[dataset_info]]) -> str: """Helper function to format loader information.""" output = [] for loader_name, loader_info in info.items(): diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 0b0ab14244e38..362d34fd86f64 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Optional +from typing import Any import torch from torch import Tensor @@ -65,7 +65,7 @@ def on_predict_start(self) -> None: def on_predict_end(self) -> None: """Called at the end of predicting.""" - def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]: + def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None: """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. @@ -144,7 +144,7 @@ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int """ - def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def on_predict_batch_end(self, outputs: Any | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Called in the predict loop after the batch. Args: diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 3a01cd2fe9a7c..3e895624cd0c1 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -18,7 +18,7 @@ from collections.abc import Iterator, MutableMapping, Sequence from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Optional, Union +from typing import Any from lightning.fabric.utilities.data import AttributeDict from lightning.pytorch.utilities.parsing import save_hyperparameters @@ -51,8 +51,8 @@ def __init__(self) -> None: def save_hyperparameters( self, *args: Any, - ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None, + ignore: Sequence[str] | str | None = None, + frame: types.FrameType | None = None, logger: bool = True, ) -> None: """Save arguments to ``hparams`` attribute. @@ -130,7 +130,7 @@ class ``__init__`` to be ignored frame = current_frame.f_back save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams) - def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: + def _set_hparams(self, hp: MutableMapping | Namespace | str) -> None: hp = self._to_hparams_dict(hp) if isinstance(hp, dict) and isinstance(self.hparams, dict): @@ -139,7 +139,7 @@ def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: self._hparams = hp @staticmethod - def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]: + def _to_hparams_dict(hp: MutableMapping | Namespace | str) -> MutableMapping | AttributeDict: if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -151,7 +151,7 @@ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[Mutable return hp @property - def hparams(self) -> Union[AttributeDict, MutableMapping]: + def hparams(self) -> AttributeDict | MutableMapping: """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 37b07f025f8e9..4dcfa03c2789e 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -17,7 +17,7 @@ import logging import numbers import weakref -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import contextmanager, nullcontext from io import BytesIO from pathlib import Path @@ -25,10 +25,8 @@ IO, TYPE_CHECKING, Any, - Callable, Literal, Optional, - Union, cast, overload, ) @@ -90,9 +88,14 @@ warning_cache = WarningCache() log = logging.getLogger(__name__) -MODULE_OPTIMIZERS = Union[ - Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer] -] +MODULE_OPTIMIZERS = ( + Optimizer + | LightningOptimizer + | _FabricOptimizer + | list[Optimizer] + | list[LightningOptimizer] + | list[_FabricOptimizer] +) class LightningModule( @@ -134,33 +137,31 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) # pointer to the trainer object - self._trainer: Optional[pl.Trainer] = None + self._trainer: pl.Trainer | None = None # attributes that can be set by user - self._example_input_array: Optional[Union[Tensor, tuple, dict]] = None + self._example_input_array: Tensor | tuple | dict | None = None self._automatic_optimization: bool = True - self._strict_loading: Optional[bool] = None + self._strict_loading: bool | None = None # attributes used internally - self._current_fx_name: Optional[str] = None + self._current_fx_name: str | None = None self._param_requires_grad_state: dict[str, bool] = {} - self._metric_attributes: Optional[dict[int, str]] = None - self._compiler_ctx: Optional[dict[str, Any]] = None + self._metric_attributes: dict[int, str] | None = None + self._compiler_ctx: dict[str, Any] | None = None # attributes only used when using fabric - self._fabric: Optional[lf.Fabric] = None + self._fabric: lf.Fabric | None = None self._fabric_optimizers: list[_FabricOptimizer] = [] # access to device mesh in `conigure_model()` hook - self._device_mesh: Optional[DeviceMesh] = None + self._device_mesh: DeviceMesh | None = None @overload - def optimizers( - self, use_pl_optimizer: Literal[True] = True - ) -> Union[LightningOptimizer, list[LightningOptimizer]]: ... + def optimizers(self, use_pl_optimizer: Literal[True] = True) -> LightningOptimizer | list[LightningOptimizer]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, list[Optimizer]]: ... + def optimizers(self, use_pl_optimizer: Literal[False]) -> Optimizer | list[Optimizer]: ... @overload def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ... @@ -195,7 +196,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: # multiple opts return opts - def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLType]: + def lr_schedulers(self) -> None | list[LRSchedulerPLType] | LRSchedulerPLType: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. Returns: @@ -245,7 +246,7 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None: self._fabric = fabric @property - def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: + def example_input_array(self) -> Tensor | tuple | dict | None: """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: @@ -260,7 +261,7 @@ def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: return self._example_input_array @example_input_array.setter - def example_input_array(self, example: Optional[Union[Tensor, tuple, dict]]) -> None: + def example_input_array(self, example: Tensor | tuple | dict | None) -> None: self._example_input_array = example @property @@ -316,14 +317,14 @@ def strict_loading(self, strict_loading: bool) -> None: self._strict_loading = strict_loading @property - def logger(self) -> Optional[Union[Logger, FabricLogger]]: + def logger(self) -> Logger | FabricLogger | None: """Reference to the logger object in the Trainer.""" if self._fabric is not None: return self._fabric.logger return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> Union[list[Logger], list[FabricLogger]]: + def loggers(self) -> list[Logger] | list[FabricLogger]: """Reference to the list of loggers in the Trainer.""" if self._fabric is not None: return self._fabric.loggers @@ -356,7 +357,7 @@ def _on_before_batch_transfer(self, batch: Any, dataloader_idx: int = 0) -> Any: return self._call_batch_hook("on_before_batch_transfer", batch, dataloader_idx) def _apply_batch_transfer_handler( - self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 + self, batch: Any, device: torch.device | None = None, dataloader_idx: int = 0 ) -> Any: device = device or self.device batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx) @@ -388,16 +389,16 @@ def log( name: str, value: _METRIC, prog_bar: bool = False, - logger: Optional[bool] = None, - on_step: Optional[bool] = None, - on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable[[Any], Any]] = "mean", + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable[[Any], Any] = "mean", enable_graph: bool = False, sync_dist: bool = False, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, + batch_size: int | None = None, + metric_attribute: str | None = None, rank_zero_only: bool = False, ) -> None: """Log a key, value pair. @@ -551,17 +552,17 @@ def log( def log_dict( self, - dictionary: Union[Mapping[str, _METRIC], MetricCollection], + dictionary: Mapping[str, _METRIC] | MetricCollection, prog_bar: bool = False, - logger: Optional[bool] = None, - on_step: Optional[bool] = None, - on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable[[Any], Any]] = "mean", + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable[[Any], Any] = "mean", enable_graph: bool = False, sync_dist: bool = False, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, + batch_size: int | None = None, rank_zero_only: bool = False, ) -> None: """Log a dictionary of values at once. @@ -630,7 +631,7 @@ def log_dict( return None def _log_dict_through_fabric( - self, dictionary: Union[Mapping[str, _METRIC], MetricCollection], logger: Optional[bool] = None + self, dictionary: Mapping[str, _METRIC] | MetricCollection, logger: bool | None = None ) -> None: if logger is False: # Passing `logger=False` with Fabric does not make much sense because there is no other destination to @@ -655,7 +656,7 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor: + def __to_tensor(self, value: Tensor | numbers.Number, name: str) -> Tensor: value = ( value.clone().detach() if isinstance(value, Tensor) @@ -670,8 +671,8 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor return value def all_gather( - self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, dict, list, tuple]: + self, data: Tensor | dict | list | tuple, group: Any | None = None, sync_grads: bool = False + ) -> Tensor | dict | list | tuple: r"""Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -967,7 +968,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): batch = kwargs.get("batch", args[0]) return self(batch) - def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + def configure_callbacks(self) -> Sequence[Callback] | Callback: """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already @@ -1136,7 +1137,7 @@ def backward(self, loss): else: loss.backward(*args, **kwargs) - def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None: + def toggle_optimizer(self, optimizer: Optimizer | LightningOptimizer) -> None: """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup. @@ -1165,7 +1166,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> N param.requires_grad = param_requires_grad_state[param] self._param_requires_grad_state = param_requires_grad_state - def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None: + def untoggle_optimizer(self, optimizer: Optimizer | LightningOptimizer) -> None: """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Args: @@ -1182,7 +1183,7 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> self._param_requires_grad_state = {} @contextmanager - def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator: + def toggled_optimizer(self, optimizer: Optimizer | LightningOptimizer) -> Generator: """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup. Combines :meth:`toggle_optimizer` and :meth:`untoggle_optimizer` into context manager. @@ -1210,8 +1211,8 @@ def training_step(...): def clip_gradients( self, optimizer: Optimizer, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, ) -> None: """Handles gradient clipping internally. @@ -1278,8 +1279,8 @@ def clip_gradients( def configure_gradient_clipping( self, optimizer: Optimizer, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, ) -> None: """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`. @@ -1306,7 +1307,7 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) - def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None: + def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None: r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls each scheduler. By default, Lightning calls ``step()`` and as shown in the example for each scheduler based on its ``interval``. @@ -1338,8 +1339,8 @@ def optimizer_step( self, epoch: int, batch_idx: int, - optimizer: Union[Optimizer, LightningOptimizer], - optimizer_closure: Optional[Callable[[], Any]] = None, + optimizer: Optimizer | LightningOptimizer, + optimizer_closure: Callable[[], Any] | None = None, ) -> None: r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls the optimizer. @@ -1428,8 +1429,8 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: @torch.no_grad() def to_onnx( self, - file_path: Union[str, Path, BytesIO, None] = None, - input_sample: Optional[Any] = None, + file_path: str | Path | BytesIO | None = None, + input_sample: Any | None = None, **kwargs: Any, ) -> Optional["ONNXProgram"]: """Saves the model in ONNX format. @@ -1487,11 +1488,11 @@ def forward(self, x): @torch.no_grad() def to_torchscript( self, - file_path: Optional[Union[str, Path]] = None, - method: Optional[str] = "script", - example_inputs: Optional[Any] = None, + file_path: str | Path | None = None, + method: str | None = "script", + example_inputs: Any | None = None, **kwargs: Any, - ) -> Union[ScriptModule, dict[str, ScriptModule]]: + ) -> ScriptModule | dict[str, ScriptModule]: """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are @@ -1575,14 +1576,14 @@ def forward(self, x): @torch.no_grad() def to_tensorrt( self, - file_path: Optional[Union[str, Path, BytesIO]] = None, - input_sample: Optional[Any] = None, + file_path: str | Path | BytesIO | None = None, + input_sample: Any | None = None, ir: Literal["default", "dynamo", "ts"] = "default", output_format: Literal["exported_program", "torchscript"] = "exported_program", retrace: bool = False, - default_device: Union[str, torch.device] = "cuda", + default_device: str | torch.device = "cuda", **compile_kwargs: Any, - ) -> Union[ScriptModule, torch.fx.GraphModule]: + ) -> ScriptModule | torch.fx.GraphModule: """Export the model to ScriptModule or GraphModule using TensorRT compile backend. Args: @@ -1686,11 +1687,11 @@ def forward(self, x): @_restricted_classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[_PATH, IO], + checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, - strict: Optional[bool] = None, - weights_only: Optional[bool] = None, + hparams_file: _PATH | None = None, + strict: bool | None = None, + weights_only: bool | None = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index b85e9b2c10e5a..dc18c8d62a8de 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import fields -from typing import Any, Callable, Optional, Union, overload +from typing import Any, Union, overload from weakref import proxy import torch @@ -48,7 +48,7 @@ class LightningOptimizer: def __init__(self, optimizer: Optimizer): self._optimizer = optimizer - self._strategy: Optional[pl.strategies.Strategy] = None + self._strategy: pl.strategies.Strategy | None = None # to inject logic around the optimizer step, particularly useful with manual optimization self._on_before_step = do_nothing_closure self._on_after_step = do_nothing_closure @@ -82,7 +82,7 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: yield lightning_module.untoggle_optimizer(self) - def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any: + def step(self, closure: Callable[[], Any] | None = None, **kwargs: Any) -> Any: """Performs a single optimization step (parameter update). Args: @@ -198,8 +198,8 @@ def _init_optimizers_and_lr_schedulers( def _configure_optimizers( - optim_conf: Union[dict[str, Any], list, Optimizer, tuple], -) -> tuple[list, list, Optional[str]]: + optim_conf: dict[str, Any] | list | Optimizer | tuple, +) -> tuple[list, list, str | None]: optimizers, lr_schedulers = [], [] monitor = None @@ -247,7 +247,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> list[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: str | None) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] for scheduler in schedulers: @@ -406,12 +406,12 @@ def step(self, closure: None = ...) -> None: ... def step(self, closure: Callable[[], float]) -> float: ... @override - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + def step(self, closure: Callable[[], float] | None = None) -> float | None: if closure is not None: return closure() @override - def zero_grad(self, set_to_none: Optional[bool] = True) -> None: + def zero_grad(self, set_to_none: bool | None = True) -> None: pass # Do Nothing @override diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 391e9dd5d0f25..54c0681d0a3be 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -19,10 +19,11 @@ import logging import os from argparse import Namespace +from collections.abc import Callable from copy import deepcopy from enum import Enum from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union from warnings import warn import torch @@ -51,12 +52,12 @@ def _load_from_checkpoint( - cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], - checkpoint_path: Union[_PATH, IO], + cls: type["pl.LightningModule"] | type["pl.LightningDataModule"], + checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, - strict: Optional[bool] = None, - weights_only: Optional[bool] = None, + hparams_file: _PATH | None = None, + strict: bool | None = None, + weights_only: bool | None = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location @@ -117,9 +118,9 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ def _load_state( - cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], + cls: type["pl.LightningModule"] | type["pl.LightningDataModule"], checkpoint: dict[str, Any], - strict: Optional[bool] = None, + strict: bool | None = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) @@ -201,9 +202,7 @@ def _load_state( return obj -def _convert_loaded_hparams( - model_args: dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None -) -> dict[str, Any]: +def _convert_loaded_hparams(model_args: dict[str, Any], hparams_type: Callable | str | None = None) -> dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -267,7 +266,7 @@ def load_hparams_from_tags_csv(tags_csv: _PATH) -> dict[str, Any]: return {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} -def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None: +def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: dict | Namespace) -> None: fs = get_filesystem(tags_csv) if not _is_dir(fs, os.path.dirname(tags_csv)): raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") @@ -317,7 +316,7 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> di return hparams -def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None: +def save_hparams_to_yaml(config_yaml: _PATH, hparams: dict | Namespace, use_omegaconf: bool = True) -> None: """ Args: config_yaml: path to new YAML file @@ -372,7 +371,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us yaml.dump(hparams_allowed, fp) -def convert(val: str) -> Union[int, float, bool, str]: +def convert(val: str) -> int | float | bool | str: try: return ast.literal_eval(val) except (ValueError, SyntaxError) as err: diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 3855f31898b81..6762c4b11efc8 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable, Iterator -from typing import Any, Optional +from typing import Any import torch import torch.nn as nn @@ -117,7 +117,7 @@ def __init__(self) -> None: def forward(self, x: Tensor) -> Tensor: return self.layer(x) - def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor: + def loss(self, preds: Tensor, labels: Tensor | None = None) -> Tensor: if labels is None: labels = torch.ones_like(preds) # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py index 9432dd9acd1f8..f774e2f880775 100644 --- a/src/lightning/pytorch/demos/lstm.py +++ b/src/lightning/pytorch/demos/lstm.py @@ -6,7 +6,6 @@ """ from collections.abc import Iterator, Sized -from typing import Optional import torch import torch.nn as nn @@ -73,7 +72,7 @@ class LightningLSTM(LightningModule): def __init__(self, vocab_size: int = 33278): super().__init__() self.model = SimpleLSTM(vocab_size=vocab_size) - self.hidden: Optional[tuple[Tensor, Tensor]] = None + self.hidden: tuple[Tensor, Tensor] | None = None def on_train_epoch_end(self) -> None: self.hidden = None diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 9ecc5411ae974..1203dc8297667 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -16,8 +16,8 @@ import random import time import urllib -from collections.abc import Sized -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sized +from typing import Any from urllib.error import HTTPError from warnings import warn @@ -121,7 +121,7 @@ def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> tuple[Ten return res @staticmethod - def normalize_tensor(tensor: Tensor, mean: Union[int, float] = 0.0, std: Union[int, float] = 1.0) -> Tensor: + def normalize_tensor(tensor: Tensor, mean: int | float = 0.0, std: int | float = 1.0) -> Tensor: mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) return tensor.sub(mean).div(std) @@ -242,7 +242,7 @@ def test_dataloader(self) -> DataLoader: ) @property - def default_transforms(self) -> Optional[Callable]: + def default_transforms(self) -> Callable | None: if not _TORCHVISION_AVAILABLE: return None diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 13b5e05adc680..c7baa1c3e3447 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -8,7 +8,6 @@ import math import os from pathlib import Path -from typing import Optional import torch import torch.nn as nn @@ -54,7 +53,7 @@ def __init__( self.ninp = ninp self.vocab_size = vocab_size - self.src_mask: Optional[Tensor] = None + self.src_mask: Tensor | None = None def generate_square_subsequent_mask(self, size: int) -> Tensor: """Generate a square mask for the sequence to prevent future tokens from being seen.""" @@ -62,7 +61,7 @@ def generate_square_subsequent_mask(self, size: int) -> Tensor: mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) return mask - def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, inputs: Tensor, target: Tensor, mask: Tensor | None = None) -> Tensor: _, t = inputs.shape # Generate source mask to prevent future token leakage @@ -88,7 +87,7 @@ def __init__(self, dim: int, dropout: float = 0.1, max_len: int = 5000) -> None: self.dropout = nn.Dropout(p=dropout) self.dim = dim self.max_len = max_len - self.pe: Optional[Tensor] = None + self.pe: Tensor | None = None def forward(self, x: Tensor) -> Tensor: if self.pe is None: diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index b544212e755e2..a289cc327498f 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -20,7 +20,7 @@ import os from argparse import Namespace from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -199,13 +199,13 @@ def __init__(self, *args, **kwarg): def __init__( self, *, - api_key: Optional[str] = None, - workspace: Optional[str] = None, - project: Optional[str] = None, - experiment_key: Optional[str] = None, - mode: Optional[Literal["get_or_create", "get", "create"]] = None, - online: Optional[bool] = None, - prefix: Optional[str] = None, + api_key: str | None = None, + workspace: str | None = None, + project: str | None = None, + experiment_key: str | None = None, + mode: Literal["get_or_create", "get", "create"] | None = None, + online: bool | None = None, + prefix: str | None = None, **kwargs: Any, ): if not _COMET_AVAILABLE: @@ -254,14 +254,14 @@ def __init__( ) ################################################## - self._api_key: Optional[str] = api_key - self._experiment: Optional[comet_experiment] = None - self._workspace: Optional[str] = workspace - self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode - self._online: Optional[bool] = online - self._project_name: Optional[str] = project - self._experiment_key: Optional[str] = experiment_key - self._prefix: Optional[str] = prefix + self._api_key: str | None = api_key + self._experiment: comet_experiment | None = None + self._workspace: str | None = workspace + self._mode: Literal["get_or_create", "get", "create"] | None = mode + self._online: bool | None = online + self._project_name: str | None = project + self._experiment_key: str | None = experiment_key + self._prefix: str | None = prefix self._kwargs: dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import @@ -322,7 +322,7 @@ def experiment(self) -> comet_experiment: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) self.experiment.__internal_api__log_parameters__( parameters=params, @@ -333,7 +333,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: @override @rank_zero_only - def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, Tensor | float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() @@ -365,7 +365,7 @@ def finalize(self, status: str) -> None: @property @override - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory. Returns: @@ -376,7 +376,7 @@ def save_dir(self) -> Optional[str]: @property @override - def name(self) -> Optional[str]: + def name(self) -> str | None: """Gets the project name. Returns: @@ -387,7 +387,7 @@ def name(self) -> Optional[str]: @property @override - def version(self) -> Optional[str]: + def version(self) -> str | None: """Gets the version. Returns: @@ -413,7 +413,7 @@ def __getstate__(self) -> dict[str, Any]: return state @override - def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: Module, input_array: Tensor | None = None) -> None: if self._experiment is not None: self._experiment.__internal_api__set_model_graph__( graph=model, diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 5ad7353310af4..adaacbe1c0b19 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -21,7 +21,7 @@ import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any from typing_extensions import override @@ -88,8 +88,8 @@ class CSVLogger(Logger, FabricCSVLogger): def __init__( self, save_dir: _PATH, - name: Optional[str] = "lightning_logs", - version: Optional[Union[int, str]] = None, + name: str | None = "lightning_logs", + version: int | str | None = None, prefix: str = "", flush_logs_every_n_steps: int = 100, ): @@ -139,7 +139,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace | None = None) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 668fe39cb67d2..79bf84e9fb425 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -18,8 +18,8 @@ import statistics from abc import ABC from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Any, Callable, Optional +from collections.abc import Callable, Mapping, Sequence +from typing import Any from typing_extensions import override @@ -42,7 +42,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: pass @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Return the root directory where experiment logs get saved, or `None` if the logger does not save data locally.""" return None @@ -100,7 +100,7 @@ def method(*args: Any, **kwargs: Any) -> None: # TODO: this should have been deprecated def merge_dicts( # pragma: no cover dicts: Sequence[Mapping], - agg_key_funcs: Optional[Mapping] = None, + agg_key_funcs: Mapping | None = None, default_func: Callable[[Sequence[float]], float] = statistics.mean, ) -> dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..80cc41ce2226b 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -21,10 +21,10 @@ import re import tempfile from argparse import Namespace -from collections.abc import Mapping +from collections.abc import Callable, Mapping from pathlib import Path from time import time -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import yaml from lightning_utilities.core.imports import RequirementCache @@ -116,15 +116,15 @@ def any_lightning_module_function_or_hook(self): def __init__( self, experiment_name: str = "lightning_logs", - run_name: Optional[str] = None, - tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), - tags: Optional[dict[str, Any]] = None, - save_dir: Optional[str] = "./mlruns", + run_name: str | None = None, + tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"), + tags: dict[str, Any] | None = None, + save_dir: str | None = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", - artifact_location: Optional[str] = None, - run_id: Optional[str] = None, - synchronous: Optional[bool] = None, + artifact_location: str | None = None, + run_id: str | None = None, + synchronous: bool | None = None, ): if not _MLFLOW_AVAILABLE: raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) @@ -135,14 +135,14 @@ def __init__( tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" self._experiment_name = experiment_name - self._experiment_id: Optional[str] = None + self._experiment_id: str | None = None self._tracking_uri = tracking_uri self._run_name = run_name self._run_id = run_id self.tags = tags self._log_model = log_model self._logged_model_time: dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._checkpoint_callback: ModelCheckpoint | None = None self._prefix = prefix self._artifact_location = artifact_location self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous} @@ -205,7 +205,7 @@ def experiment(self) -> "MlflowClient": return self._mlflow_client @property - def run_id(self) -> Optional[str]: + def run_id(self) -> str | None: """Create the experiment if it does not exist to get the run id. Returns: @@ -216,7 +216,7 @@ def run_id(self) -> Optional[str]: return self._run_id @property - def experiment_id(self) -> Optional[str]: + def experiment_id(self) -> str | None: """Create the experiment if it does not exist to get the experiment id. Returns: @@ -228,7 +228,7 @@ def experiment_id(self) -> Optional[str]: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) params = _flatten_dict(params) @@ -244,7 +244,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: @override @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" from mlflow.entities import Metric @@ -291,7 +291,7 @@ def finalize(self, status: str = "success") -> None: @property @override - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """The root file directory in which MLflow experiments are saved. Return: @@ -305,7 +305,7 @@ def save_dir(self) -> Optional[str]: @property @override - def name(self) -> Optional[str]: + def name(self) -> str | None: """Get the experiment id. Returns: @@ -316,7 +316,7 @@ def name(self) -> Optional[str]: @property @override - def version(self) -> Optional[str]: + def version(self) -> str | None: """Get the run id. Returns: diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index bf9669c824784..46870598a6225 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -20,9 +20,9 @@ import logging import os from argparse import Namespace -from collections.abc import Generator +from collections.abc import Callable, Generator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -234,11 +234,11 @@ def any_lightning_module_function_or_hook(self): def __init__( self, *, # force users to call `NeptuneLogger` initializer with `kwargs` - api_key: Optional[str] = None, - project: Optional[str] = None, - name: Optional[str] = None, - run: Optional[Union["Run", "Handler"]] = None, - log_model_checkpoints: Optional[bool] = True, + api_key: str | None = None, + project: str | None = None, + name: str | None = None, + run: Union["Run", "Handler"] | None = None, + log_model_checkpoints: bool | None = True, prefix: str = "training", **neptune_run_kwargs: Any, ): @@ -255,7 +255,7 @@ def __init__( self._api_key = api_key self._run_instance = run self._neptune_run_kwargs = neptune_run_kwargs - self._run_short_id: Optional[str] = None + self._run_short_id: str | None = None if self._run_instance is not None: self._retrieve_run_data() @@ -317,10 +317,10 @@ def _construct_path_with_prefix(self, *keys: str) -> str: @staticmethod def _verify_input_arguments( - api_key: Optional[str], - project: Optional[str], - name: Optional[str], - run: Optional[Union["Run", "Handler"]], + api_key: str | None, + project: str | None, + name: str | None, + run: Union["Run", "Handler"] | None, neptune_run_kwargs: dict, ) -> None: from neptune import Run @@ -396,7 +396,7 @@ def run(self) -> "Run": @override @rank_zero_only @_catch_inactive - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -444,7 +444,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: @override @rank_zero_only @_catch_inactive - def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, Tensor | float], step: int | None = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -475,7 +475,7 @@ def finalize(self, status: str) -> None: @property @override - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save locally. @@ -569,7 +569,7 @@ def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], return set(cls._dict_paths(uploaded_models_dict)) @classmethod - def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator: + def _dict_paths(cls, d: dict[str, Any], path_in_build: str | None = None) -> Generator: for k, v in d.items(): path = f"{path_in_build}/{k}" if path_in_build is not None else k if not isinstance(v, dict): @@ -579,13 +579,13 @@ def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> @property @override - def name(self) -> Optional[str]: + def name(self) -> str | None: """Return the experiment name or 'offline-name' when exp is run in offline mode.""" return self._run_name @property @override - def version(self) -> Optional[str]: + def version(self) -> str | None: """Return the experiment version. It's Neptune Run's short_id diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index f9cc41c67045c..0110f71c37f6d 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -18,7 +18,7 @@ import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any from torch import Tensor from typing_extensions import override @@ -85,12 +85,12 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): def __init__( self, save_dir: _PATH, - name: Optional[str] = "lightning_logs", - version: Optional[Union[int, str]] = None, + name: str | None = "lightning_logs", + version: int | str | None = None, log_graph: bool = False, default_hp_metric: bool = True, prefix: str = "", - sub_dir: Optional[_PATH] = None, + sub_dir: _PATH | None = None, **kwargs: Any, ): super().__init__( @@ -108,7 +108,7 @@ def __init__( f"{str(_TENSORBOARD_AVAILABLE)}" ) self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - self.hparams: Union[dict[str, Any], Namespace] = {} + self.hparams: dict[str, Any] | Namespace = {} @property @override @@ -154,9 +154,9 @@ def save_dir(self) -> str: @rank_zero_only def log_hyperparams( self, - params: Union[dict[str, Any], Namespace], - metrics: Optional[dict[str, Any]] = None, - step: Optional[int] = None, + params: dict[str, Any] | Namespace, + metrics: dict[str, Any] | None = None, + step: int | None = None, ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to @@ -184,7 +184,7 @@ def log_hyperparams( @override @rank_zero_only def log_graph( # type: ignore[override] - self, model: "pl.LightningModule", input_array: Optional[Tensor] = None + self, model: "pl.LightningModule", input_array: Tensor | None = None ) -> None: if not self._log_graph: return diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index ced8a6f1f2bd3..e89aa0b502137 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -14,7 +14,7 @@ """Utilities for loggers.""" from pathlib import Path -from typing import Any, Union +from typing import Any from torch import Tensor @@ -22,7 +22,7 @@ from lightning.pytorch.callbacks import Checkpoint -def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: +def _version(loggers: list[Any], separator: str = "_") -> int | str: if len(loggers) == 1: return loggers[0].version # Concatenate versions together, removing duplicates and preserving order diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 37ca362fa40c1..cb225e2cec9fc 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -20,7 +20,7 @@ from argparse import Namespace from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union import torch.nn as nn from lightning_utilities.core.imports import RequirementCache @@ -293,18 +293,18 @@ def any_lightning_module_function_or_hook(self): def __init__( self, - name: Optional[str] = None, + name: str | None = None, save_dir: _PATH = ".", - version: Optional[str] = None, + version: str | None = None, offline: bool = False, - dir: Optional[_PATH] = None, - id: Optional[str] = None, - anonymous: Optional[bool] = None, - project: Optional[str] = None, - log_model: Union[Literal["all"], bool] = False, + dir: _PATH | None = None, + id: str | None = None, + anonymous: bool | None = None, + project: str | None = None, + log_model: Literal["all"] | bool = False, experiment: Union["Run", "RunDisabled", None] = None, prefix: str = "", - checkpoint_name: Optional[str] = None, + checkpoint_name: str | None = None, add_file_policy: Literal["mutable", "immutable"] = "mutable", **kwargs: Any, ) -> None: @@ -422,13 +422,13 @@ def experiment(self) -> Union["Run", "RunDisabled"]: return self._experiment def watch( - self, model: nn.Module, log: Optional[str] = "gradients", log_freq: int = 100, log_graph: bool = True + self, model: nn.Module, log: str | None = "gradients", log_freq: int = 100, log_graph: bool = True ) -> None: self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) params = _convert_json_serializable(params) @@ -436,7 +436,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: @override @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) @@ -449,10 +449,10 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) def log_table( self, key: str, - columns: Optional[list[str]] = None, - data: Optional[list[list[Any]]] = None, + columns: list[str] | None = None, + data: list[list[Any]] | None = None, dataframe: Any = None, - step: Optional[int] = None, + step: int | None = None, ) -> None: """Log a Table containing any object type (text, image, audio, video, molecule, html, etc). @@ -468,10 +468,10 @@ def log_table( def log_text( self, key: str, - columns: Optional[list[str]] = None, - data: Optional[list[list[str]]] = None, + columns: list[str] | None = None, + data: list[list[str]] | None = None, dataframe: Any = None, - step: Optional[int] = None, + step: int | None = None, ) -> None: """Log text as a Table. @@ -482,7 +482,7 @@ def log_text( self.log_table(key, columns, data, dataframe, step) @rank_zero_only - def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_image(self, key: str, images: list[Any], step: int | None = None, **kwargs: Any) -> None: """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). @@ -502,7 +502,7 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_audio(self, key: str, audios: list[Any], step: int | None = None, **kwargs: Any) -> None: r"""Log audios (numpy arrays, or file paths). Args: @@ -528,7 +528,7 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_video(self, key: str, videos: list[Any], step: int | None = None, **kwargs: Any) -> None: """Log videos (numpy arrays, or file paths). Args: @@ -555,7 +555,7 @@ def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **k @property @override - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory. Returns: @@ -566,7 +566,7 @@ def save_dir(self) -> Optional[str]: @property @override - def name(self) -> Optional[str]: + def name(self) -> str | None: """The project name of this experiment. Returns: @@ -578,7 +578,7 @@ def name(self) -> Optional[str]: @property @override - def version(self) -> Optional[str]: + def version(self) -> str | None: """Gets the id of the experiment. Returns: @@ -600,9 +600,9 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: @rank_zero_only def download_artifact( artifact: str, - save_dir: Optional[_PATH] = None, - artifact_type: Optional[str] = None, - use_artifact: Optional[bool] = True, + save_dir: _PATH | None = None, + artifact_type: str | None = None, + use_artifact: bool | None = True, ) -> str: """Downloads an artifact from the wandb server. @@ -627,7 +627,7 @@ def download_artifact( save_dir = None if save_dir is None else os.fspath(save_dir) return artifact.download(root=save_dir) - def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact": + def use_artifact(self, artifact: str, artifact_type: str | None = None) -> "Artifact": """Logs to the wandb dashboard that the mentioned artifact is used by the run. Args: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 6036e57cf59ae..e01173a529bcf 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -19,7 +19,7 @@ from collections import ChainMap, OrderedDict, defaultdict from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -70,7 +70,7 @@ def __init__( self.verbose = verbose self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders - self._max_batches: list[Union[int, float]] = [] + self._max_batches: list[int | float] = [] self._results = _ResultCollection(training=False) self._logged_outputs: list[_OUT_DICT] = [] @@ -78,8 +78,8 @@ def __init__( self._trainer_fn = trainer_fn self._stage = stage self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader") - self._combined_loader: Optional[CombinedLoader] = None - self._data_fetcher: Optional[_DataFetcher] = None + self._combined_loader: CombinedLoader | None = None + self._data_fetcher: _DataFetcher | None = None self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() @@ -93,7 +93,7 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> list[Union[int, float]]: + def max_batches(self) -> list[int | float]: """The max number of batches to run per dataloader.""" max_batches = self._max_batches if not self.trainer.sanity_checking: @@ -394,7 +394,7 @@ def _on_after_fetch(self) -> None: self.trainer.profiler.stop(f"[{type(self).__name__}].{self._stage.dataloader_prefix}_next") def _evaluation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int, dataloader_iter: Optional[Iterator] + self, batch: Any, batch_idx: int, dataloader_idx: int, dataloader_iter: Iterator | None ) -> None: """Runs the actual evaluation step together with all the necessary bookkeeping and the hooks tied to it. @@ -470,7 +470,7 @@ def _evaluation_step( if not self.batch_progress.is_last_batch and trainer.received_sigterm: raise SIGTERMException - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict: + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int | None) -> OrderedDict: """Helper method to build the arguments for the current step. Args: @@ -526,7 +526,7 @@ def _get_keys(data: dict) -> Iterable[tuple[str, ...]]: yield (k,) @staticmethod - def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: + def _find_value(data: dict, target: Iterable[str]) -> Any | None: target_start, *rest = target if target_start not in data: return None diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 92ec95a9e2f58..ef2c42555f083 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Iterator -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -28,11 +28,11 @@ def _profile_nothing() -> None: class _DataFetcher(Iterator): def __init__(self) -> None: - self._combined_loader: Optional[CombinedLoader] = None - self.iterator: Optional[Iterator] = None + self._combined_loader: CombinedLoader | None = None + self.iterator: Iterator | None = None self.fetched: int = 0 self.done: bool = False - self.length: Optional[int] = None + self.length: int | None = None self._start_profiler = _profile_nothing self._stop_profiler = _profile_nothing @@ -197,7 +197,7 @@ def fetched(self) -> int: return self.data_fetcher.fetched @property - def length(self) -> Optional[int]: + def length(self) -> int | None: return self.data_fetcher.length @override diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 8bb123939dc20..ff9f4f6f41d4c 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging import time from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import torch from typing_extensions import override @@ -87,8 +87,8 @@ class _FitLoop(_Loop): def __init__( self, trainer: "pl.Trainer", - min_epochs: Optional[int] = 0, - max_epochs: Optional[int] = None, + min_epochs: int | None = 0, + max_epochs: int | None = None, ) -> None: super().__init__(trainer) if isinstance(max_epochs, int) and max_epochs < -1: @@ -101,12 +101,12 @@ def __init__( self.min_epochs = min_epochs self.epoch_loop = _TrainingEpochLoop(trainer) self.epoch_progress = _Progress() - self.max_batches: Union[int, float] = float("inf") + self.max_batches: int | float = float("inf") self._data_source = _DataLoaderSource(None, "train_dataloader") - self._combined_loader: Optional[CombinedLoader] = None + self._combined_loader: CombinedLoader | None = None self._combined_loader_states_to_load: list[dict[str, Any]] = [] - self._data_fetcher: Optional[_DataFetcher] = None + self._data_fetcher: _DataFetcher | None = None self._last_train_dl_reload_epoch = float("-inf") self._restart_stage = RestartStage.NONE @@ -121,7 +121,7 @@ def batch_idx(self) -> int: return self.epoch_loop.batch_idx @property - def min_steps(self) -> Optional[int]: + def min_steps(self) -> int | None: """Returns the minimum number of steps to run.""" return self.epoch_loop.min_steps diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index f4324c003f7a9..cbc6e3d046e3f 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import lightning.pytorch as pl from lightning.pytorch.loops.progress import _BaseProgress @@ -59,7 +58,7 @@ def on_save_checkpoint(self) -> dict: def on_load_checkpoint(self, state_dict: dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict: + def state_dict(self, destination: dict | None = None, prefix: str = "") -> dict: """The state dict is determined by the state and progress of this loop and all its children. Args: diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index e19b5761c4d4b..9782e25570e40 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from collections.abc import Mapping +from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Optional +from typing import Any import torch from torch import Tensor @@ -46,8 +46,8 @@ class ClosureResult(OutputResult): """ - closure_loss: Optional[Tensor] - loss: Optional[Tensor] = field(init=False, default=None) + closure_loss: Tensor | None + loss: Tensor | None = field(init=False, default=None) extra: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -117,8 +117,8 @@ class Closure(AbstractClosure[ClosureResult]): def __init__( self, step_fn: Callable[[], ClosureResult], - backward_fn: Optional[Callable[[Tensor], None]] = None, - zero_grad_fn: Optional[Callable[[], None]] = None, + backward_fn: Callable[[Tensor], None] | None = None, + zero_grad_fn: Callable[[], None] | None = None, ): super().__init__() self._step_fn = step_fn @@ -142,7 +142,7 @@ def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: return step_output @override - def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: + def __call__(self, *args: Any, **kwargs: Any) -> Tensor | None: self._result = self.closure(*args, **kwargs) return self._result.loss @@ -208,7 +208,7 @@ def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]: """Build the step function that runs the `training_step` and processes its output.""" return partial(self._training_step, kwargs) - def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: + def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Callable[[], None] | None: """Build a `zero_grad` function that zeroes the gradients before back-propagation. Returns ``None`` in the case backward needs to be skipped. @@ -227,7 +227,7 @@ def zero_grad_fn() -> None: return zero_grad_fn - def _make_backward_fn(self, optimizer: Optimizer) -> Optional[Callable[[Tensor], None]]: + def _make_backward_fn(self, optimizer: Optimizer) -> Callable[[Tensor], None] | None: """Build a `backward` function that handles back-propagation through the output produced by the `training_step` function. @@ -245,7 +245,7 @@ def backward_fn(loss: Tensor) -> None: def _optimizer_step( self, batch_idx: int, - train_step_and_backward_closure: Callable[[], Optional[Tensor]], + train_step_and_backward_closure: Callable[[], Tensor | None], ) -> None: """Performs the optimizer step and some sanity checking. diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index e45262a067f52..265b3220f477e 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -40,7 +40,7 @@ class AbstractClosure(ABC, Generic[T]): def __init__(self) -> None: super().__init__() - self._result: Optional[T] = None + self._result: T | None = None def consume_result(self) -> T: """The cached result from the last time the closure was called. diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index dcfd873a28b4b..baf85c0f68f84 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict from collections.abc import Iterator -from typing import Any, Optional, Union +from typing import Any import torch from lightning_utilities import WarningCache @@ -54,12 +54,12 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: self.epoch_batch_indices: list[list[list[int]]] = [] self.current_batch_indices: list[int] = [] # used by PredictionWriter self.batch_progress = _Progress() # across dataloaders - self.max_batches: list[Union[int, float]] = [] + self.max_batches: list[int | float] = [] self._warning_cache = WarningCache() self._data_source = _DataLoaderSource(None, "predict_dataloader") - self._combined_loader: Optional[CombinedLoader] = None - self._data_fetcher: Optional[_DataFetcher] = None + self._combined_loader: CombinedLoader | None = None + self._data_fetcher: _DataFetcher | None = None self._results = None # for `trainer._results` access self._predictions: list[list[Any]] = [] # dataloaders x batches self._return_predictions = False @@ -71,7 +71,7 @@ def return_predictions(self) -> bool: return self._return_predictions @return_predictions.setter - def return_predictions(self, return_predictions: Optional[bool] = None) -> None: + def return_predictions(self, return_predictions: bool | None = None) -> None: # Strategies that spawn or fork don't support returning predictions return_supported = not isinstance(self.trainer.strategy.launcher, _MultiProcessingLauncher) if return_predictions and not return_supported: @@ -101,7 +101,7 @@ def skip(self) -> bool: return sum(self.max_batches) == 0 @_no_grad_context - def run(self) -> Optional[_PREDICT_OUTPUT]: + def run(self) -> _PREDICT_OUTPUT | None: self.setup_data() if self.skip: return None @@ -198,7 +198,7 @@ def on_run_start(self) -> None: self._on_predict_start() self._on_predict_epoch_start() - def on_run_end(self) -> Optional[_PREDICT_OUTPUT]: + def on_run_end(self) -> _PREDICT_OUTPUT | None: """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" results = self._on_predict_epoch_end() self._on_predict_end() @@ -210,9 +210,7 @@ def teardown(self) -> None: self._data_fetcher.teardown() self._data_fetcher = None - def _predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int, dataloader_iter: Optional[Iterator] - ) -> None: + def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int, dataloader_iter: Iterator | None) -> None: """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to it. Args: @@ -273,7 +271,7 @@ def _predict_step( if self._return_predictions or any_on_epoch: self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu"))) - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict: + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int | None) -> OrderedDict: """Assembles the keyword arguments for the ``predict_step`` Args: @@ -358,7 +356,7 @@ def _on_predict_epoch_start(self) -> None: call._call_callback_hooks(trainer, "on_predict_epoch_start") call._call_lightning_module_hook(trainer, "on_predict_epoch_start") - def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: + def _on_predict_epoch_end(self) -> _PREDICT_OUTPUT | None: """Calls ``on_predict_epoch_end`` hook. Returns: diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 3d01780b705fe..c9d72324422ad 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -16,7 +16,7 @@ import time from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import torch from typing_extensions import override @@ -38,8 +38,6 @@ from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] - @dataclass class RestartStage: @@ -70,7 +68,7 @@ class _TrainingEpochLoop(loops._Loop): """ - def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_steps: int = -1) -> None: + def __init__(self, trainer: "pl.Trainer", min_steps: int | None = None, max_steps: int = -1) -> None: super().__init__(trainer) if max_steps < -1: raise MisconfigurationException( @@ -324,7 +322,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: self.batch_progress.increment_ready() trainer._logger_connector.on_batch_start(batch) - batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy + batch_output: _OPTIMIZER_LOOP_OUTPUTS_TYPE | _MANUAL_LOOP_OUTPUTS_TYPE | None = None # for mypy if batch is None and not using_dataloader_iter: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") else: @@ -509,7 +507,7 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool) ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Optional[Any]: + def _get_monitor_value(self, key: str) -> Any | None: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 8e20f485828f0..04da7dc8da800 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Callable, Optional +from typing import Any import torch import torch.distributed as dist @@ -36,7 +36,7 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -def check_finite_loss(loss: Optional[Tensor]) -> None: +def check_finite_loss(loss: Tensor | None) -> None: """Checks for finite loss value. Args: @@ -48,10 +48,10 @@ def check_finite_loss(loss: Optional[Tensor]) -> None: def _parse_loop_limits( - min_steps: Optional[int], + min_steps: int | None, max_steps: int, - min_epochs: Optional[int], - max_epochs: Optional[int], + min_epochs: int | None, + max_epochs: int | None, trainer: "pl.Trainer", ) -> tuple[int, int]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 92d444338ff0f..0e6155e9cc6f6 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from collections.abc import Iterable, Iterator, Sized -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Iterator, Sized +from typing import Any, cast import torch from torch import Tensor @@ -27,8 +27,8 @@ def _find_tensors( - obj: Union[Tensor, list, tuple, dict, Any], -) -> Union[list[Tensor], itertools.chain]: # pragma: no-cover + obj: Tensor | list | tuple | dict | Any, +) -> list[Tensor] | itertools.chain: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): return [obj] @@ -61,9 +61,9 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None: def _register_ddp_comm_hook( model: DistributedDataParallel, - ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[Callable] = None, - ddp_comm_wrapper: Optional[Callable] = None, + ddp_comm_state: object | None = None, + ddp_comm_hook: Callable | None = None, + ddp_comm_wrapper: Callable | None = None, ) -> None: """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html. @@ -223,7 +223,7 @@ def __iter__(self) -> Iterator[list[int]]: class UnrepeatedDistributedSamplerWrapper(UnrepeatedDistributedSampler): """Equivalent class to ``DistributedSamplerWrapper`` but for the ``UnrepeatedDistributedSampler``.""" - def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: + def __init__(self, sampler: Sampler | Iterable, *args: Any, **kwargs: Any) -> None: super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs) @override @@ -245,7 +245,7 @@ def __init__(self, batch_sampler: _SizedIterable) -> None: if k not in ("__next__", "__iter__", "__len__", "__getstate__") } self._batch_sampler = batch_sampler - self._iterator: Optional[Iterator[list[int]]] = None + self._iterator: Iterator[list[int]] | None = None def __next__(self) -> list[int]: assert self._iterator is not None diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py index d4fd63807c78d..b1280d073ca97 100644 --- a/src/lightning/pytorch/plugins/__init__.py +++ b/src/lightning/pytorch/plugins/__init__.py @@ -1,5 +1,3 @@ -from typing import Union - from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm @@ -13,7 +11,7 @@ from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision from lightning.pytorch.plugins.precision.xla import XLAPrecision -_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync] +_PLUGIN_INPUT = Precision | ClusterEnvironment | CheckpointIO | LayerSync __all__ = [ "AsyncCheckpointIO", diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py index 5cff35074992e..baa854d9427f7 100644 --- a/src/lightning/pytorch/plugins/io/async_plugin.py +++ b/src/lightning/pytorch/plugins/io/async_plugin.py @@ -35,8 +35,8 @@ class AsyncCheckpointIO(_WrappingCheckpointIO): """ - _executor: Optional[ThreadPoolExecutor] - _error: Optional[BaseException] + _executor: ThreadPoolExecutor | None + _error: BaseException | None def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: super().__init__(checkpoint_io) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 5ea62233e1f69..a922fe5dee8e8 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -9,9 +9,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Literal, Optional import torch from torch import Tensor @@ -101,7 +101,7 @@ def optimizer_step( # type: ignore[override] def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: if clip_val > 0 and _optimizer_handles_unscaling(optimizer): diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 9225e3bb9e7be..08b7ea0e162be 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, get_args import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module from torch.optim import LBFGS, Optimizer -from typing_extensions import get_args, override +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT @@ -94,7 +95,7 @@ def backward( # type: ignore[override] self, tensor: Tensor, model: "pl.LightningModule", - optimizer: Optional[Steppable], + optimizer: Steppable | None, *args: Any, **kwargs: Any, ) -> None: @@ -142,7 +143,7 @@ def optimizer_step( # type: ignore[override] def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """DeepSpeed handles gradient clipping internally.""" diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index f3bab3e915e91..e4b859b85656b 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional, get_args import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module -from typing_extensions import get_args, override +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 327fb2d4f5a27..3c1ec9577ccdf 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from collections.abc import Generator +from collections.abc import Callable, Generator from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Union import torch from torch import Tensor @@ -55,7 +55,7 @@ def backward( # type: ignore[override] self, tensor: Tensor, model: "pl.LightningModule", - optimizer: Optional[Steppable], + optimizer: Steppable | None, *args: Any, **kwargs: Any, ) -> None: @@ -126,8 +126,8 @@ def _clip_gradients( self, model: Union["pl.LightningModule", Module], optimizer: Steppable, - clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None, + clip_val: int | float | None = None, + gradient_clip_algorithm: GradClipAlgorithmType | None = None, ) -> None: if not isinstance(model, pl.LightningModule) or not model.automatic_optimization: # the configuration validator disallows clipping on manual @@ -144,7 +144,7 @@ def _clip_gradients( def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """Clips the gradients.""" @@ -155,12 +155,12 @@ def clip_gradients( elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: self.clip_grad_by_norm(optimizer, clip_val) - def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_value(self, optimizer: Optimizer, clip_val: int | float) -> None: """Clip gradients by value.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: int | float) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index 6890cc4c1d825..09622d7c805fb 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Callable from functools import partial -from typing import Any, Callable +from typing import Any, get_args import torch -from typing_extensions import get_args, override +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.accelerators.xla import _XLA_AVAILABLE diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index c0b4b9953cc33..43a1740a84c4a 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -21,7 +21,6 @@ import tempfile from collections import defaultdict from pathlib import Path -from typing import Optional, Union from typing_extensions import override @@ -42,8 +41,8 @@ class AdvancedProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, line_count_restriction: float = 1.0, dump_stats: bool = False, ) -> None: @@ -114,7 +113,7 @@ def summary(self) -> str: return self._stats_to_str(recorded_stats) @override - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: super().teardown(stage=stage) self.profiled_actions.clear() diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index e8a8c60881062..6fb9c8637f626 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -16,10 +16,10 @@ import logging import os from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Optional, TextIO, Union +from typing import Any, TextIO from lightning.fabric.utilities.cloud_io import get_filesystem @@ -31,16 +31,16 @@ class Profiler(ABC): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, ) -> None: self.dirpath = dirpath self.filename = filename - self._output_file: Optional[TextIO] = None - self._write_stream: Optional[Callable] = None - self._local_rank: Optional[int] = None - self._stage: Optional[str] = None + self._output_file: TextIO | None = None + self._write_stream: Callable | None = None + self._local_rank: int | None = None + self._stage: str | None = None @abstractmethod def start(self, action_name: str) -> None: @@ -78,7 +78,7 @@ def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None: def _prepare_filename( self, - action_name: Optional[str] = None, + action_name: str | None = None, extension: str = ".txt", split_token: str = "-", # noqa: S107 ) -> str: @@ -130,13 +130,13 @@ def _stats_to_str(self, stats: dict[str, str]) -> str: output.append(value) return os.linesep.join(output) - def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None: + def setup(self, stage: str, local_rank: int | None = None, log_dir: str | None = None) -> None: """Execute arbitrary pre-profiling set-up steps.""" self._stage = stage self._local_rank = local_rank self.dirpath = self.dirpath or log_dir - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: """Execute arbitrary post-profiling tear-down steps. Closes the currently open file and stream. diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index 7e17dcd40398a..8216a84e8bc0a 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -16,10 +16,11 @@ import inspect import logging import os +from collections.abc import Callable from contextlib import AbstractContextManager from functools import lru_cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any import torch from torch import Tensor, nn @@ -41,7 +42,7 @@ log = logging.getLogger(__name__) warning_cache = WarningCache() -_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx] +_PROFILER = torch.profiler.profile | torch.autograd.profiler.profile | torch.autograd.profiler.emit_nvtx _KINETO_AVAILABLE = torch.profiler.kineto_available() @@ -122,9 +123,9 @@ def reset(self) -> None: self._test_step_reached_end = False self._predict_step_reached_end = False # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. - self._current_action: Optional[str] = None - self._prev_schedule_action: Optional[ProfilerAction] = None - self._start_action_name: Optional[str] = None + self._current_action: str | None = None + self._prev_schedule_action: ProfilerAction | None = None + self._start_action_name: str | None = None def setup(self, start_action_name: str) -> None: self._start_action_name = start_action_name @@ -232,15 +233,15 @@ class PyTorchProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, group_by_input_shapes: bool = False, emit_nvtx: bool = False, export_to_chrome: bool = True, row_limit: int = 20, - sort_by_key: Optional[str] = None, + sort_by_key: str | None = None, record_module_names: bool = True, - table_kwargs: Optional[dict[str, Any]] = None, + table_kwargs: dict[str, Any] | None = None, **profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of @@ -303,14 +304,14 @@ def __init__( self._profiler_kwargs = profiler_kwargs self._table_kwargs = table_kwargs if table_kwargs is not None else {} - self.profiler: Optional[_PROFILER] = None - self.function_events: Optional[EventList] = None - self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector - self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[AbstractContextManager] = None + self.profiler: _PROFILER | None = None + self.function_events: EventList | None = None + self._lightning_module: LightningModule | None = None # set by ProfilerConnector + self._register: RegisterRecordFunction | None = None + self._parent_profiler: AbstractContextManager | None = None self._recording_map: dict[str, record_function] = {} - self._start_action_name: Optional[str] = None - self._schedule: Optional[ScheduleWrapper] = None + self._start_action_name: str | None = None + self._schedule: ScheduleWrapper | None = None if _KINETO_AVAILABLE: self._init_kineto(profiler_kwargs) @@ -359,7 +360,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None: self._profiler_kwargs["with_stack"] = with_stack @property - def _total_steps(self) -> Union[int, float]: + def _total_steps(self) -> int | float: assert self._schedule is not None assert self._lightning_module is not None trainer = self._lightning_module.trainer @@ -396,7 +397,7 @@ def _should_override_schedule(self) -> bool: @staticmethod @lru_cache(1) - def _default_schedule() -> Optional[Callable]: + def _default_schedule() -> Callable | None: if _KINETO_AVAILABLE: # Those schedule defaults allow the profiling overhead to be negligible over training time. return torch.profiler.schedule(wait=1, warmup=1, active=3) @@ -566,7 +567,7 @@ def _delete_profilers(self) -> None: self._register = None @override - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: self._delete_profilers() for k in list(self._recording_map): diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index 8a53965e3f487..7697bbb515e53 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -18,7 +18,6 @@ import time from collections import defaultdict from pathlib import Path -from typing import Optional, Union import torch from typing_extensions import override @@ -39,8 +38,8 @@ class SimpleProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, extended: bool = True, ) -> None: """ diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index ed7a8a987898b..f08f980253218 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch import Tensor diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index 4c0e6192abdba..6885f5999df4c 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -2,7 +2,7 @@ import logging import time from multiprocessing import Process -from typing import Any, Literal, Optional +from typing import Any, Literal import requests import torch @@ -42,7 +42,7 @@ class ServableModuleValidator(Callback): def __init__( self, - optimization: Optional[Literal["trace", "script", "onnx", "tensorrt"]] = None, + optimization: Literal["trace", "script", "onnx", "tensorrt"] | None = None, server: Literal["fastapi", "ml_server", "torchserve", "sagemaker"] = "fastapi", host: str = "127.0.0.1", port: int = 8080, @@ -70,7 +70,7 @@ def __init__( self.server = server self.timeout = timeout self.exit_on_failure = exit_on_failure - self.resp: Optional[requests.Response] = None + self.resp: requests.Response | None = None @override @rank_zero_only @@ -131,7 +131,7 @@ def on_train_start(self, trainer: "pl.Trainer", servable_module: "pl.LightningMo _logger.info(f"Your model is servable and the received payload was {self.resp.json()}.") @property - def successful(self) -> Optional[bool]: + def successful(self) -> bool | None: """Returns whether the model was successfully served.""" return self.resp.status_code == 200 if self.resp else None diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4eca6159ddced..1dfcb6a5d8b61 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections.abc import Callable from contextlib import nullcontext from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch import torch.distributed @@ -71,16 +72,16 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[Precision] = None, - ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[Callable] = None, - ddp_comm_wrapper: Optional[Callable] = None, - model_averaging_period: Optional[int] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: Precision | None = None, + ddp_comm_state: object | None = None, + ddp_comm_hook: Callable | None = None, + ddp_comm_wrapper: Callable | None = None, + model_averaging_period: int | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen", **kwargs: Any, ) -> None: @@ -99,9 +100,9 @@ def __init__( self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._model_averaging_period = model_averaging_period - self._model_averager: Optional[ModelAverager] = None - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._model_averager: ModelAverager | None = None + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout self._start_method = start_method self._pl_static_graph_delay_done = False @@ -138,7 +139,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @override @@ -259,7 +260,7 @@ def optimizer_step( self, optimizer: Optimizer, closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, + model: Union["pl.LightningModule", Module] | None = None, **kwargs: Any, ) -> Any: """Performs the actual optimizer step. @@ -287,7 +288,7 @@ def configure_ddp(self) -> None: self.model = self._setup_model(self.model) self._register_ddp_hooks() - def determine_ddp_device_ids(self) -> Optional[list[int]]: + def determine_ddp_device_ids(self) -> list[int] | None: if self.root_device.type == "cpu": return None return [self.root_device.index] @@ -348,9 +349,7 @@ def model_to_device(self) -> None: self.model.to(self.root_device) @override - def reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def reduce(self, tensor: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 1fce7b06887cd..a1ea12b19abdf 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -21,7 +21,7 @@ from contextlib import contextmanager from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module @@ -82,7 +82,7 @@ def __init__( accelerator: Optional["pl.accelerators.Accelerator"] = None, zero_optimization: bool = True, stage: int = 2, - remote_device: Optional[str] = None, + remote_device: str | None = None, offload_optimizer: bool = False, offload_parameters: bool = False, offload_params_device: str = "cpu", @@ -106,11 +106,11 @@ def __init__( allgather_bucket_size: int = 200_000_000, reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, - logging_batch_size_per_gpu: Union[str, int] = "auto", - config: Optional[Union[_PATH, dict[str, Any]]] = None, + logging_batch_size_per_gpu: str | int = "auto", + config: _PATH | dict[str, Any] | None = None, logging_level: int = logging.WARN, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -121,9 +121,9 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - precision_plugin: Optional[Precision] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + precision_plugin: Precision | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, exclude_frozen_parameters: bool = False, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large @@ -284,7 +284,7 @@ def __init__( precision_plugin=precision_plugin, process_group_backend=process_group_backend, ) - self._timeout: Optional[timedelta] = timeout + self._timeout: timedelta | None = timeout self.config = self._load_config(config) if self.config is None: @@ -433,8 +433,8 @@ def _setup_model_and_optimizers( def _setup_model_and_optimizer( self, model: Module, - optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, + optimizer: Optimizer | None, + lr_scheduler: LRScheduler | ReduceLROnPlateau | None = None, ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. @@ -476,7 +476,7 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(self.model) - def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: + def _init_optimizers(self) -> tuple[Optimizer, LRSchedulerConfig | None]: assert self.lightning_module is not None optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: @@ -527,7 +527,7 @@ def _initialize_deepspeed_train(self, model: Module) -> None: @contextmanager @override - def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + def tensor_init_context(self, empty_init: bool | None = None) -> Generator[None, None, None]: if self.zero_stage_3: if empty_init is False: raise NotImplementedError( @@ -632,7 +632,7 @@ def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 @override - def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -674,7 +674,7 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op ) @override - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool | None = None) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing @@ -809,7 +809,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + def _load_config(self, config: _PATH | dict[str, Any] | None) -> dict[str, Any] | None: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -849,7 +849,7 @@ def _create_default_config( self, zero_optimization: bool, zero_allow_untested_optimizer: bool, - logging_batch_size_per_gpu: Union[str, int], + logging_batch_size_per_gpu: str | int, partition_activations: bool, cpu_checkpointing: bool, contiguous_memory_optimization: bool, diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 9706c8a64e61b..b00e62116aa0e 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,18 +13,11 @@ # limitations under the License. import logging import shutil -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -82,8 +75,8 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] - _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] + _POLICY = set[type[Module]] | Callable[[Module, bool, int], bool] | ModuleWrapPolicy + _SHARDING_STRATEGY = ShardingStrategy | Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"] log = logging.getLogger(__name__) @@ -147,20 +140,20 @@ class FSDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[Precision] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Union[bool, "CPUOffload", None] = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: Precision | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, + cpu_offload: Union[bool, "CPUOffload"] | None = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, + activation_checkpointing: type[Module] | list[type[Module]] | None = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "full", - device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, + device_mesh: Union[tuple[int], "DeviceMesh"] | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -172,7 +165,7 @@ def __init__( ) self.num_nodes = 1 self._process_group_backend = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._timeout: timedelta | None = timeout self.cpu_offload = _init_cpu_offload(cpu_offload) self.mixed_precision = mixed_precision self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs) @@ -204,7 +197,7 @@ def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @property @@ -227,7 +220,7 @@ def precision_plugin(self) -> FSDPPrecision: @precision_plugin.setter @override - def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + def precision_plugin(self, precision_plugin: Precision | None) -> None: if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision): raise TypeError( f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}" @@ -386,7 +379,7 @@ def model_to_device(self) -> None: @contextmanager @override - def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + def tensor_init_context(self, empty_init: bool | None = None) -> Generator[None, None, None]: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: # 1) materialize module 2) call `reset_parameters()` 3) shard the module. # These operations are applied to each submodule 'bottom up' in the module hierarchy. @@ -412,7 +405,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: yield @override - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: if not _distributed_is_initialized(): return if torch.distributed.get_backend() == "nccl": @@ -432,9 +425,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @override def reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. @@ -547,9 +540,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: pass @override - def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Any | None = None) -> None: if storage_options is not None: raise TypeError( "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" @@ -583,7 +574,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool | None = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 3589460574c39..c3a8fae01b025 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -16,15 +16,17 @@ import os import queue import tempfile +from collections.abc import Callable from contextlib import suppress from dataclasses import dataclass -from typing import Any, Callable, Literal, NamedTuple, Optional, Union +from typing import Any, Literal, NamedTuple, Optional import torch import torch.backends.cudnn import torch.multiprocessing as mp from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor +from torch.multiprocessing.queue import SimpleQueue from typing_extensions import override import lightning.pytorch as pl @@ -159,7 +161,7 @@ def _wrapping_function( function: Callable, args: Any, kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], + return_queue: SimpleQueue | queue.Queue, global_states: Optional["_GlobalStateSnapshot"] = None, ) -> None: if global_states: @@ -272,8 +274,8 @@ def __getstate__(self) -> dict: class _WorkerOutput(NamedTuple): - best_model_path: Optional[_PATH] - weights_path: Optional[_PATH] + best_model_path: _PATH | None + weights_path: _PATH | None trainer_state: TrainerState trainer_results: Any extra: dict[str, Any] diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index b7ec294c148d5..5b42a03b2c7f4 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -14,7 +14,8 @@ import logging import os import subprocess -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -134,7 +135,7 @@ def _call_children_scripts(self) -> None: del env_copy["PL_GLOBAL_SEED"] hydra_in_use = False - cwd: Optional[str] = None + cwd: str | None = None if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 066fecc79f208..6642feb5b2375 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -13,9 +13,11 @@ # limitations under the License. import os import queue -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional import torch.multiprocessing as mp +from torch.multiprocessing.queue import SimpleQueue from typing_extensions import override from lightning.fabric.accelerators.xla import _XLA_AVAILABLE @@ -126,8 +128,8 @@ def _wrapping_function( function: Callable, args: Any, kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], - global_states: Optional[_GlobalStateSnapshot] = None, + return_queue: SimpleQueue | queue.Queue, + global_states: _GlobalStateSnapshot | None = None, ) -> None: import torch_xla.core.xla_model as xm diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index f3165a08e6bdd..adf709751376b 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -16,7 +16,7 @@ from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -80,11 +80,11 @@ class ModelParallelStrategy(ParallelStrategy): def __init__( self, - data_parallel_size: Union[Literal["auto"], int] = "auto", - tensor_parallel_size: Union[Literal["auto"], int] = "auto", + data_parallel_size: Literal["auto"] | int = "auto", + tensor_parallel_size: Literal["auto"] | int = "auto", save_distributed_checkpoint: bool = True, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, ) -> None: super().__init__() if not _TORCH_GREATER_EQUAL_2_4: @@ -92,9 +92,9 @@ def __init__( self._data_parallel_size = data_parallel_size self._tensor_parallel_size = tensor_parallel_size self._save_distributed_checkpoint = save_distributed_checkpoint - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout - self._device_mesh: Optional[DeviceMesh] = None + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout + self._device_mesh: DeviceMesh | None = None self.num_nodes = 1 @property @@ -121,7 +121,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @property @@ -203,14 +203,14 @@ def model_to_device(self) -> None: @contextmanager @override - def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + def tensor_init_context(self, empty_init: bool | None = None) -> Generator[None, None, None]: # Materializaton happens in `setup()` empty_init_context = torch.device("meta") if empty_init else nullcontext() with empty_init_context, self.precision_plugin.tensor_init_context(): yield @override - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: if not _distributed_is_initialized(): return if torch.distributed.get_backend() == "nccl": @@ -230,9 +230,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @override def reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", ) -> Tensor: if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) @@ -296,9 +296,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: pass @override - def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Any | None = None) -> None: if storage_options is not None: raise TypeError( f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because" @@ -329,7 +327,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool | None = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index dbd8e2962b230..6e75c3e2a8484 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -34,15 +34,15 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[Precision] = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: Precision | None = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - self._layer_sync: Optional[LayerSync] = None + self.cluster_environment: ClusterEnvironment | None = cluster_environment + self._layer_sync: LayerSync | None = None @property @abstractmethod @@ -72,11 +72,11 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[list[torch.device]]: + def parallel_devices(self) -> list[torch.device] | None: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: list[torch.device] | None) -> None: self._parallel_devices = parallel_devices @property @@ -87,7 +87,7 @@ def distributed_sampler_kwargs(self) -> dict[str, Any]: } @override - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform a all_gather on all processes.""" return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py index 2a5e2f3a85b96..ca9b3b6919529 100644 --- a/src/lightning/pytorch/strategies/single_xla.py +++ b/src/lightning/pytorch/strategies/single_xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Optional, Union +from typing import Optional import torch from typing_extensions import override @@ -37,8 +37,8 @@ def __init__( self, device: _DEVICE, accelerator: Optional["pl.accelerators.Accelerator"] = None, - checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, - precision_plugin: Optional[XLAPrecision] = None, + checkpoint_io: XLACheckpointIO | _WrappingCheckpointIO | None = None, + precision_plugin: XLAPrecision | None = None, debug: bool = False, ): if not _XLA_AVAILABLE: @@ -58,7 +58,7 @@ def __init__( @property @override - def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]: + def checkpoint_io(self) -> XLACheckpointIO | _WrappingCheckpointIO: plugin = self._checkpoint_io if plugin is not None: assert isinstance(plugin, (XLACheckpointIO, _WrappingCheckpointIO)) @@ -67,7 +67,7 @@ def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]: @checkpoint_io.setter @override - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: if io is not None and not isinstance(io, (XLACheckpointIO, _WrappingCheckpointIO)): raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}") self._checkpoint_io = io @@ -83,7 +83,7 @@ def precision_plugin(self) -> XLAPrecision: @precision_plugin.setter @override - def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + def precision_plugin(self, precision_plugin: Precision | None) -> None: if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision): raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}") self._precision_plugin = precision_plugin diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 0a00cb28af15e..983d2126c3527 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -13,9 +13,9 @@ # limitations under the License. import logging from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import torch from torch import Tensor @@ -50,24 +50,24 @@ class Strategy(ABC): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[Precision] = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: Precision | None = None, ) -> None: - self._accelerator: Optional[pl.accelerators.Accelerator] = accelerator - self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io - self._precision_plugin: Optional[Precision] = None + self._accelerator: pl.accelerators.Accelerator | None = accelerator + self._checkpoint_io: CheckpointIO | None = checkpoint_io + self._precision_plugin: Precision | None = None # Call the precision setter for input validation self.precision_plugin = precision_plugin - self._lightning_module: Optional[pl.LightningModule] = None - self._model: Optional[Module] = None - self._launcher: Optional[_Launcher] = None + self._lightning_module: pl.LightningModule | None = None + self._model: Module | None = None + self._launcher: _Launcher | None = None self._forward_redirection: _ForwardRedirection = _ForwardRedirection() self._optimizers: list[Optimizer] = [] self._lightning_optimizers: list[LightningOptimizer] = [] self.lr_scheduler_configs: list[LRSchedulerConfig] = [] @property - def launcher(self) -> Optional[_Launcher]: + def launcher(self) -> _Launcher | None: return self._launcher @property @@ -96,7 +96,7 @@ def precision_plugin(self) -> Precision: return self._precision_plugin if self._precision_plugin is not None else Precision() @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + def precision_plugin(self, precision_plugin: Precision | None) -> None: self._precision_plugin = precision_plugin @property @@ -192,7 +192,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: def backward( self, closure_loss: Tensor, - optimizer: Optional[Optimizer], + optimizer: Optimizer | None, *args: Any, **kwargs: Any, ) -> Tensor: @@ -221,7 +221,7 @@ def optimizer_step( self, optimizer: Optimizer, closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, + model: Union["pl.LightningModule", Module] | None = None, **kwargs: Any, ) -> Any: r"""Performs the actual optimizer step. @@ -260,7 +260,7 @@ def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324 return optimizer - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + def batch_to_device(self, batch: Any, device: torch.device | None = None, dataloader_idx: int = 0) -> Any: """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just @@ -295,10 +295,10 @@ def is_global_zero(self) -> bool: @abstractmethod def reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, Any]: + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", + ) -> Tensor | Any: """Reduces the given tensor (e.g. across GPUs/processes). Args: @@ -310,7 +310,7 @@ def reduce( """ @abstractmethod - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: """Synchronizes all processes which blocks processes until the whole group enters this function. Args: @@ -329,7 +329,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: """ @abstractmethod - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform an all_gather on all processes. Args: @@ -350,12 +350,12 @@ def post_backward(self, closure_loss: Tensor) -> None: """Run after precision plugin executes backward.""" @property - def model(self) -> Optional[Module]: + def model(self) -> Module | None: """Returns the potentially wrapped LightningModule.""" return self._model if self._model is not None else self._lightning_module @model.setter - def model(self, new_model: Optional[Module]) -> None: + def model(self, new_model: Module | None) -> None: self._model = new_model @property @@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool | None = None) -> dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only) @@ -476,9 +476,7 @@ def lightning_module_state_dict(self) -> dict[str, Any]: assert self.lightning_module is not None return self.lightning_module.state_dict() - def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -501,7 +499,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: self.checkpoint_io.remove_checkpoint(filepath) @contextmanager - def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + def tensor_init_context(self, empty_init: bool | None = None) -> Generator[None, None, None]: """Controls how tensors get created (device, dtype). Args: diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index cbdc890a1ca32..66f1f008a40ba 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from torch import Tensor @@ -49,9 +49,9 @@ class XLAStrategy(DDPStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, - checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, - precision_plugin: Optional[XLAPrecision] = None, + parallel_devices: list[torch.device] | None = None, + checkpoint_io: XLACheckpointIO | _WrappingCheckpointIO | None = None, + precision_plugin: XLAPrecision | None = None, debug: bool = False, sync_module_states: bool = True, **_: Any, @@ -72,7 +72,7 @@ def __init__( @property @override - def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]: + def checkpoint_io(self) -> XLACheckpointIO | _WrappingCheckpointIO: plugin = self._checkpoint_io if plugin is not None: assert isinstance(plugin, (XLACheckpointIO, _WrappingCheckpointIO)) @@ -81,7 +81,7 @@ def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]: @checkpoint_io.setter @override - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: if io is not None and not isinstance(io, (XLACheckpointIO, _WrappingCheckpointIO)): raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}") self._checkpoint_io = io @@ -97,7 +97,7 @@ def precision_plugin(self) -> XLAPrecision: @precision_plugin.setter @override - def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + def precision_plugin(self, precision_plugin: Precision | None) -> None: if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision): raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}") self._precision_plugin = precision_plugin @@ -199,7 +199,7 @@ def model_to_device(self) -> None: self.model = self.model.to(self.root_device) @override - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + def barrier(self, name: str | None = None, *args: Any, **kwargs: Any) -> None: if not self._launched: return @@ -248,9 +248,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @override def reduce( self, - output: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", + output: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", ) -> Tensor: if not isinstance(output, Tensor): output = torch.tensor(output, device=self.root_device) @@ -297,9 +297,7 @@ def set_world_ranks(self) -> None: pass @override - def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Any | None = None) -> None: import torch_xla.core.xla_model as xm # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs @@ -319,7 +317,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: self.checkpoint_io.remove_checkpoint(filepath) @override - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. Args: diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 77536cdc16b33..88f77e7ebe672 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -14,8 +14,9 @@ import logging import signal import sys +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any, Optional from packaging.version import Version @@ -204,7 +205,7 @@ def _call_callback_hooks( trainer: "pl.Trainer", hook_name: str, *args: Any, - monitoring_callbacks: Optional[bool] = None, + monitoring_callbacks: bool | None = None, **kwargs: Any, ) -> None: log.debug(f"{trainer.__class__.__name__}: calling callback hook: {hook_name}") @@ -271,7 +272,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") + callback_states: dict[type | str, dict] | None = checkpoint.get("callbacks") if callback_states is None: return @@ -297,7 +298,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") + callback_states: dict[type | str, dict] | None = checkpoint.get("callbacks") if callback_states is None: return diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index d51b11ea6fb12..43ed1397f0a56 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -16,7 +16,7 @@ import os from collections import Counter from collections.abc import Iterable -from typing import Literal, Optional, Union +from typing import Literal import torch @@ -74,16 +74,16 @@ class _AcceleratorConnector: def __init__( self, - devices: Union[list[int], str, int] = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, - precision: Optional[_PRECISION_INPUT] = None, + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + plugins: _PLUGIN_INPUT | Iterable[_PLUGIN_INPUT] | None = None, + precision: _PRECISION_INPUT | None = None, sync_batchnorm: bool = False, - benchmark: Optional[bool] = None, + benchmark: bool | None = None, use_distributed_sampler: bool = True, - deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, + deterministic: bool | _LITERAL_WARN | None = None, ) -> None: """The AcceleratorConnector parses several Trainer arguments and instantiates the Strategy including other components such as the Accelerator and Precision plugins. @@ -117,14 +117,14 @@ def __init__( # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation - self._strategy_flag: Union[Strategy, str] = "auto" - self._accelerator_flag: Union[Accelerator, str] = "auto" + self._strategy_flag: Strategy | str = "auto" + self._accelerator_flag: Accelerator | str = "auto" self._precision_flag: _PRECISION_INPUT_STR = "32-true" - self._precision_plugin_flag: Optional[Precision] = None - self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: list[Union[int, torch.device, str]] = [] - self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None - self.checkpoint_io: Optional[CheckpointIO] = None + self._precision_plugin_flag: Precision | None = None + self._cluster_environment_flag: ClusterEnvironment | str | None = None + self._parallel_devices: list[int | torch.device | str] = [] + self._layer_sync: LayerSync | None = TorchSyncBatchNorm() if sync_batchnorm else None + self.checkpoint_io: CheckpointIO | None = None self._check_config_and_set_final_flags( strategy=strategy, @@ -162,10 +162,10 @@ def __init__( def _check_config_and_set_final_flags( self, - strategy: Union[str, Strategy], - accelerator: Union[str, Accelerator], - precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], + strategy: str | Strategy, + accelerator: str | Accelerator, + precision: _PRECISION_INPUT | None, + plugins: _PLUGIN_INPUT | Iterable[_PLUGIN_INPUT] | None, sync_batchnorm: bool, ) -> None: """This method checks: @@ -309,7 +309,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: list[int] | str | int, num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") @@ -397,7 +397,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: return env_type() return LightningEnvironment() - def _choose_strategy(self) -> Union[Strategy, str]: + def _choose_strategy(self) -> Strategy | str: if self._accelerator_flag == "hpu": raise MisconfigurationException("HPU is currently not supported. Please contact developer@lightning.ai") @@ -568,9 +568,7 @@ def is_distributed(self) -> bool: return False -def _set_torch_flags( - *, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None -) -> None: +def _set_torch_flags(*, deterministic: bool | _LITERAL_WARN | None = None, benchmark: bool | None = None) -> None: if deterministic: if benchmark is None: # Set benchmark to False to ensure determinism diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 62dd49c26cc71..526daf1ae2c43 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -16,7 +16,6 @@ import os from collections.abc import Sequence from datetime import timedelta -from typing import Optional, Union from lightning_utilities.core.imports import RequirementCache @@ -50,12 +49,12 @@ def __init__(self, trainer: "pl.Trainer"): def on_trainer_init( self, - callbacks: Optional[Union[list[Callback], Callback]], + callbacks: list[Callback] | Callback | None, enable_checkpointing: bool, enable_progress_bar: bool, - default_root_dir: Optional[str], + default_root_dir: str | None, enable_model_summary: bool, - max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, + max_time: str | timedelta | dict[str, int] | None = None, ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() @@ -155,7 +154,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) - def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None: + def _configure_timer_callback(self, max_time: str | timedelta | dict[str, int] | None = None) -> None: if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index ae5038b2022d2..23011fc79417c 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, Optional +from typing import Any import torch from fsspec.core import url_to_fs @@ -46,13 +46,13 @@ class _CheckpointConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.trainer = trainer - self._ckpt_path: Optional[_PATH] = None + self._ckpt_path: _PATH | None = None # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` self._user_managed: bool = False self._loaded_checkpoint: dict[str, Any] = {} @property - def _hpc_resume_path(self) -> Optional[str]: + def _hpc_resume_path(self) -> str | None: dir_path_hpc = str(self.trainer.default_root_dir) fs, path = url_to_fs(dir_path_hpc) if not _is_dir(fs, path): @@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]: return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt" return None - def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: + def resume_start(self, checkpoint_path: _PATH | None = None, weights_only: bool | None = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. @@ -84,8 +84,8 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Op self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( - self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool - ) -> Optional[_PATH]: + self, state_fn: TrainerFn, ckpt_path: _PATH | None, model_provided: bool, model_connected: bool + ) -> _PATH | None: """Called by the ``Trainer`` to select the checkpoint path source.""" if self._user_managed: if ckpt_path: @@ -114,8 +114,8 @@ def _select_ckpt_path( return ckpt_path def _parse_ckpt_path( - self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool - ) -> Optional[_PATH]: + self, state_fn: TrainerFn, ckpt_path: _PATH | None, model_provided: bool, model_connected: bool + ) -> _PATH | None: """Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer configuration.""" if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None: @@ -230,7 +230,7 @@ def resume_end(self) -> None: # wait for all to catch up self.trainer.strategy.barrier("_CheckpointConnector.resume_end") - def restore(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: + def restore(self, checkpoint_path: _PATH | None = None, weights_only: bool | None = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore, in this priority: @@ -404,7 +404,7 @@ def restore_lr_schedulers(self) -> None: config.scheduler.load_state_dict(lrs_state) def _restore_modules_and_callbacks( - self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None + self, checkpoint_path: _PATH | None = None, weights_only: bool | None = None ) -> None: # restore modules after setup self.resume_start(checkpoint_path, weights_only=weights_only) @@ -412,7 +412,7 @@ def _restore_modules_and_callbacks( self.restore_datamodule() self.restore_callbacks() - def dump_checkpoint(self, weights_only: Optional[bool] = None) -> dict: + def dump_checkpoint(self, weights_only: bool | None = None) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: @@ -522,7 +522,7 @@ def _get_loops_state_dict(self) -> dict[str, Any]: } @staticmethod - def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]: + def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> int | None: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. Args: diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 240dae6296c1f..9a2056b6cf2ce 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -47,13 +47,13 @@ class _DataConnector: def __init__(self, trainer: "pl.Trainer"): self.trainer = trainer - self._datahook_selector: Optional[_DataHookSelector] = None + self._datahook_selector: _DataHookSelector | None = None def on_trainer_init( self, - val_check_interval: Optional[Union[int, float, str, timedelta, dict]], + val_check_interval: int | float | str | timedelta | dict | None, reload_dataloaders_every_n_epochs: int, - check_val_every_n_epoch: Optional[int], + check_val_every_n_epoch: int | None, ) -> None: self.trainer.datamodule = None @@ -104,10 +104,10 @@ def prepare_data(self) -> None: def attach_data( self, model: "pl.LightningModule", - train_dataloaders: Optional[TRAIN_DATALOADERS] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - test_dataloaders: Optional[EVAL_DATALOADERS] = None, - predict_dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + predict_dataloaders: EVAL_DATALOADERS | None = None, datamodule: Optional["pl.LightningDataModule"] = None, ) -> None: # set up the passed in dataloaders (if needed) @@ -126,10 +126,10 @@ def attach_data( def attach_dataloaders( self, model: "pl.LightningModule", - train_dataloaders: Optional[TRAIN_DATALOADERS] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - test_dataloaders: Optional[EVAL_DATALOADERS] = None, - predict_dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + predict_dataloaders: EVAL_DATALOADERS | None = None, ) -> None: trainer = self.trainer @@ -194,8 +194,8 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt return dataloader def _resolve_sampler( - self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None - ) -> Union[Sampler, Iterable]: + self, dataloader: DataLoader, shuffle: bool, mode: RunningStage | None = None + ) -> Sampler | Iterable: if self._requires_distributed_sampler(dataloader): distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs assert distributed_sampler_kwargs is not None @@ -230,8 +230,8 @@ def _resolve_sampler( def _get_distributed_sampler( dataloader: DataLoader, shuffle: bool, - overfit_batches: Union[int, float], - mode: Optional[RunningStage] = None, + overfit_batches: int | float, + mode: RunningStage | None = None, **kwargs: Any, ) -> DistributedSampler: """This function is used to created the distributed sampler injected within the user DataLoader.""" @@ -286,10 +286,10 @@ class _DataLoaderSource: """ - instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] + instance: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"] | None name: str - def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> TRAIN_DATALOADERS | EVAL_DATALOADERS: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. @@ -320,7 +320,7 @@ def is_module(self) -> bool: return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule)) -def _request_dataloader(data_source: _DataLoaderSource) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: +def _request_dataloader(data_source: _DataLoaderSource) -> TRAIN_DATALOADERS | EVAL_DATALOADERS: """Requests a dataloader by calling dataloader hooks corresponding to the given stage. Returns: @@ -446,9 +446,7 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: ) -def _parse_num_batches( - stage: RunningStage, length: Union[int, float], limit_batches: Union[int, float] -) -> Union[int, float]: +def _parse_num_batches(stage: RunningStage, length: int | float, limit_batches: int | float) -> int | float: if length == 0: return int(length) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index c1ee0013bfa19..faf71038c94cf 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union from typing_extensions import TypedDict @@ -20,8 +19,8 @@ class _FxValidator: class _LogOptions(TypedDict): - allowed_on_step: Union[tuple[bool], tuple[bool, bool]] - allowed_on_epoch: Union[tuple[bool], tuple[bool, bool]] + allowed_on_step: tuple[bool] | tuple[bool, bool] + allowed_on_epoch: tuple[bool] | tuple[bool, bool] default_on_step: bool default_on_epoch: bool @@ -164,9 +163,7 @@ def check_logging(cls, fx_name: str) -> None: ) @classmethod - def get_default_logging_levels( - cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> tuple[bool, bool]: + def get_default_logging_levels(cls, fx_name: str, on_step: bool | None, on_epoch: bool | None) -> tuple[bool, bool]: """Return default logging levels for given hook.""" fx_config = cls.functions[fx_name] assert fx_config is not None @@ -190,7 +187,7 @@ def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> No @classmethod def check_logging_and_get_default_levels( - cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] + cls, fx_name: str, on_step: bool | None, on_epoch: bool | None ) -> tuple[bool, bool]: """Check if the given hook name is allowed to log and return logging levels.""" cls.check_logging(fx_name) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index f6e8885ee050a..98a90c08f7d83 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -35,13 +35,13 @@ def __init__(self, trainer: "pl.Trainer") -> None: self._progress_bar_metrics: _PBAR_DICT = {} self._logged_metrics: _OUT_DICT = {} self._callback_metrics: _OUT_DICT = {} - self._current_fx: Optional[str] = None + self._current_fx: str | None = None # None: hasn't started, True: first loop iteration, False: subsequent iterations - self._first_loop_iter: Optional[bool] = None + self._first_loop_iter: bool | None = None def on_trainer_init( self, - logger: Union[bool, Logger, Iterable[Logger]], + logger: bool | Logger | Iterable[Logger], log_every_n_steps: int, ) -> None: self.configure_logger(logger) @@ -64,7 +64,7 @@ def should_update_logs(self) -> bool: should_log = step % trainer.log_every_n_steps == 0 return should_log or trainer.should_stop - def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None: + def configure_logger(self, logger: bool | Logger | Iterable[Logger]) -> None: if not logger: # logger is None or logger is False self.trainer.loggers = [] @@ -87,7 +87,7 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non else: self.trainer.loggers = [logger] - def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: + def log_metrics(self, metrics: _OUT_DICT, step: int | None = None) -> None: """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step. @@ -177,7 +177,7 @@ def update_train_epoch_metrics(self) -> None: Utilities and properties """ - def on_batch_start(self, batch: Any, dataloader_idx: Optional[int] = None) -> None: + def on_batch_start(self, batch: Any, dataloader_idx: int | None = None) -> None: if self._first_loop_iter is None: self._first_loop_iter = True elif self._first_loop_iter is True: diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 7c9364b5ddfe1..0ea5953ed40f1 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -32,7 +32,7 @@ from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.warnings import PossibleUserWarning -_VALUE = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors +_VALUE = Metric | Tensor # Do not include scalars as they were converted to tensors _OUT_DICT = dict[str, Tensor] _PBAR_DICT = dict[str, float] @@ -48,11 +48,11 @@ class _METRICS(TypedDict): @dataclass class _Sync: - fn: Optional[Callable] = None + fn: Callable | None = None _should: bool = False rank_zero_only: bool = False - _op: Optional[str] = None - _group: Optional[Any] = None + _op: str | None = None + _group: Any | None = None def __post_init__(self) -> None: self._generate_sync_fn() @@ -68,21 +68,21 @@ def should(self, should: bool) -> None: self._generate_sync_fn() @property - def op(self) -> Optional[str]: + def op(self) -> str | None: return self._op @op.setter - def op(self, op: Optional[str]) -> None: + def op(self, op: str | None) -> None: self._op = op # `self._fn` needs to be re-generated. self._generate_sync_fn() @property - def group(self) -> Optional[Any]: + def group(self) -> Any | None: return self._group @group.setter - def group(self, group: Optional[Any]) -> None: + def group(self, group: Any | None) -> None: self._group = group # `self._fn` needs to be re-generated. self._generate_sync_fn() @@ -114,9 +114,9 @@ class _Metadata: reduce_fx: Callable = torch.mean enable_graph: bool = False add_dataloader_idx: bool = True - dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None - _sync: Optional[_Sync] = None + dataloader_idx: int | None = None + metric_attribute: str | None = None + _sync: _Sync | None = None def __post_init__(self) -> None: if not self.on_step and not self.on_epoch: @@ -201,7 +201,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.cumulated_batch_size: Tensor self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) # this is defined here only because upstream is missing the type annotation - self._forward_cache: Optional[Any] = None + self._forward_cache: Any | None = None @override def update(self, value: _VALUE, batch_size: int) -> None: @@ -273,7 +273,7 @@ def forward(self, value: _VALUE, batch_size: int) -> None: def _wrap_compute(self, compute: Any) -> Any: # Override to avoid syncing - we handle it ourselves. @wraps(compute) - def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: + def wrapped_func(*args: Any, **kwargs: Any) -> Any | None: update_called = self.update_called if _TORCHMETRICS_GREATER_EQUAL_1_0_0 else self._update_called if not update_called: rank_zero_warn( @@ -328,15 +328,15 @@ class _ResultCollection(dict): def __init__(self, training: bool) -> None: super().__init__() self.training = training - self.batch: Optional[Any] = None - self.batch_size: Optional[int] = None - self.dataloader_idx: Optional[int] = None + self.batch: Any | None = None + self.batch_size: int | None = None + self.dataloader_idx: int | None = None @property def result_metrics(self) -> list[_ResultMetric]: return list(self.values()) - def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int: + def _extract_batch_size(self, value: _ResultMetric, batch_size: int | None, meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: batch_size = self.batch_size @@ -366,10 +366,10 @@ def log( enable_graph: bool = False, sync_dist: bool = False, sync_dist_fn: Callable = _Sync.no_op, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, + batch_size: int | None = None, + metric_attribute: str | None = None, rank_zero_only: bool = False, ) -> None: """See :meth:`~lightning.pytorch.core.LightningModule.log`""" @@ -422,7 +422,7 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None: result_metric.has_reset = False @staticmethod - def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]: + def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Tensor | None: cache = None if on_step and result_metric.meta.on_step: cache = result_metric._forward_cache @@ -493,7 +493,7 @@ def metrics(self, on_step: bool) -> _METRICS: return metrics - def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: + def reset(self, metrics: bool | None = None, fx: str | None = None) -> None: """Reset the result collection. Args: diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index ece7e902c5f5f..8d7dc88cd7cce 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -3,9 +3,10 @@ import re import signal import threading +from collections.abc import Callable from subprocess import call from types import FrameType -from typing import Any, Callable, Union +from typing import Any import torch import torch.distributed as dist @@ -16,14 +17,14 @@ from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info # copied from signal.pyi -_SIGNUM = Union[int, signal.Signals] -_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None] +_SIGNUM = int | signal.Signals +_HANDLER = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None log = logging.getLogger(__name__) class _HandlersCompose: - def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None: + def __init__(self, signal_handlers: _HANDLER | list[_HANDLER]) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 4f522c7c008bc..6613f04181b15 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -14,7 +14,6 @@ """Houses the methods used to set up the Trainer.""" from datetime import timedelta -from typing import Optional, Union import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -34,13 +33,13 @@ def _init_debugging_flags( trainer: "pl.Trainer", - limit_train_batches: Optional[Union[int, float]], - limit_val_batches: Optional[Union[int, float]], - limit_test_batches: Optional[Union[int, float]], - limit_predict_batches: Optional[Union[int, float]], - fast_dev_run: Union[int, bool], - overfit_batches: Union[int, float], - val_check_interval: Optional[Union[int, float, str, timedelta, dict]], + limit_train_batches: int | float | None, + limit_val_batches: int | float | None, + limit_test_batches: int | float | None, + limit_predict_batches: int | float | None, + fast_dev_run: int | bool, + overfit_batches: int | float, + val_check_interval: int | float | str | timedelta | dict | None, num_sanity_val_steps: int, ) -> None: # init debugging flags @@ -97,7 +96,7 @@ def _init_debugging_flags( trainer.limit_val_batches = overfit_batches -def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: +def _determine_batch_limits(batches: int | float | None, name: str) -> int | float: if batches is None: # batches is optional to know if the user passed a value so that we can show the above info messages only to the # users that set a value explicitly @@ -130,7 +129,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> ) -def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str]]) -> None: +def _init_profiler(trainer: "pl.Trainer", profiler: Profiler | str | None) -> None: if isinstance(profiler, str): PROFILERS = { "simple": SimpleProfiler, @@ -181,7 +180,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None: rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.") -def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: +def _parse_time_interval_seconds(value: str | timedelta | dict) -> float: """Convert a time interval into seconds. This helper parses different representations of a time interval and diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py index 36b1c4099b17c..d1c5aa0c36880 100644 --- a/src/lightning/pytorch/trainer/states.py +++ b/src/lightning/pytorch/trainer/states.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional from lightning.pytorch.utilities.enums import LightningEnum @@ -65,7 +64,7 @@ def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING, self.SANITY_CHECKING) @property - def dataloader_prefix(self) -> Optional[str]: + def dataloader_prefix(self) -> str | None: if self in (self.VALIDATING, self.SANITY_CHECKING): return "val" return self.value @@ -76,8 +75,8 @@ class TrainerState: """Dataclass to encapsulate the current :class:`~lightning.pytorch.trainer.trainer.Trainer` state.""" status: TrainerStatus = TrainerStatus.INITIALIZING - fn: Optional[TrainerFn] = None - stage: Optional[RunningStage] = None + fn: TrainerFn | None = None + stage: RunningStage | None = None @property def finished(self) -> bool: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f2f59e396ab23..e0b825856b0f5 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -26,7 +26,7 @@ from collections.abc import Generator, Iterable from contextlib import contextmanager from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Optional from weakref import proxy import torch @@ -91,47 +91,47 @@ class Trainer: def __init__( self, *, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, - precision: Optional[_PRECISION_INPUT] = None, - logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - callbacks: Optional[Union[list[Callback], Callback]] = None, - fast_dev_run: Union[int, bool] = False, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, + precision: _PRECISION_INPUT | None = None, + logger: Logger | Iterable[Logger] | bool | None = None, + callbacks: list[Callback] | Callback | None = None, + fast_dev_run: int | bool = False, + max_epochs: int | None = None, + min_epochs: int | None = None, max_steps: int = -1, - min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, - limit_train_batches: Optional[Union[int, float]] = None, - limit_val_batches: Optional[Union[int, float]] = None, - limit_test_batches: Optional[Union[int, float]] = None, - limit_predict_batches: Optional[Union[int, float]] = None, - overfit_batches: Union[int, float] = 0.0, - val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None, - check_val_every_n_epoch: Optional[int] = 1, - num_sanity_val_steps: Optional[int] = None, - log_every_n_steps: Optional[int] = None, - enable_checkpointing: Optional[bool] = None, - enable_progress_bar: Optional[bool] = None, - enable_model_summary: Optional[bool] = None, + min_steps: int | None = None, + max_time: str | timedelta | dict[str, int] | None = None, + limit_train_batches: int | float | None = None, + limit_val_batches: int | float | None = None, + limit_test_batches: int | float | None = None, + limit_predict_batches: int | float | None = None, + overfit_batches: int | float = 0.0, + val_check_interval: int | float | str | timedelta | dict[str, int] | None = None, + check_val_every_n_epoch: int | None = 1, + num_sanity_val_steps: int | None = None, + log_every_n_steps: int | None = None, + enable_checkpointing: bool | None = None, + enable_progress_bar: bool | None = None, + enable_model_summary: bool | None = None, accumulate_grad_batches: int = 1, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, - deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, - benchmark: Optional[bool] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, + deterministic: bool | _LITERAL_WARN | None = None, + benchmark: bool | None = None, inference_mode: bool = True, use_distributed_sampler: bool = True, - profiler: Optional[Union[Profiler, str]] = None, + profiler: Profiler | str | None = None, detect_anomaly: bool = False, barebones: bool = False, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, - default_root_dir: Optional[_PATH] = None, + default_root_dir: _PATH | None = None, enable_autolog_hparams: bool = True, - model_registry: Optional[str] = None, + model_registry: str | None = None, ) -> None: r"""Customize every aspect of training via flags. @@ -454,7 +454,7 @@ def __init__( ) # init data flags - self.check_val_every_n_epoch: Optional[int] + self.check_val_every_n_epoch: int | None self._data_connector.on_trainer_init( val_check_interval, reload_dataloaders_every_n_epochs, @@ -473,8 +473,8 @@ def __init__( f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}." ) - self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val - self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = ( + self.gradient_clip_val: int | float | None = gradient_clip_val + self.gradient_clip_algorithm: GradClipAlgorithmType | None = ( GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None ) @@ -498,13 +498,13 @@ def __init__( self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags - self.val_check_batch: Optional[Union[int, float]] = None - self.val_check_interval: Union[int, float] - self.num_sanity_val_steps: Union[int, float] - self.limit_train_batches: Union[int, float] - self.limit_val_batches: Union[int, float] - self.limit_test_batches: Union[int, float] - self.limit_predict_batches: Union[int, float] + self.val_check_batch: int | float | None = None + self.val_check_interval: int | float + self.num_sanity_val_steps: int | float + self.limit_train_batches: int | float + self.limit_val_batches: int | float + self.limit_test_batches: int | float + self.limit_predict_batches: int | float setup._init_debugging_flags( self, limit_train_batches, @@ -522,11 +522,11 @@ def __init__( def fit( self, model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, + train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: LightningDataModule | None = None, + ckpt_path: _PATH | None = None, + weights_only: bool | None = None, ) -> None: r"""Runs the full optimization routine. @@ -593,11 +593,11 @@ def fit( def _fit_impl( self, model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, + train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: LightningDataModule | None = None, + ckpt_path: _PATH | None = None, + weights_only: bool | None = None, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -634,11 +634,11 @@ def _fit_impl( def validate( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[_PATH] = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: _PATH | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - weights_only: Optional[bool] = None, + datamodule: LightningDataModule | None = None, + weights_only: bool | None = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the validation set. @@ -704,12 +704,12 @@ def validate( def _validate_impl( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[_PATH] = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: _PATH | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - weights_only: Optional[bool] = None, - ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + datamodule: LightningDataModule | None = None, + weights_only: bool | None = None, + ) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -752,11 +752,11 @@ def _validate_impl( def test( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[_PATH] = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: _PATH | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - weights_only: Optional[bool] = None, + datamodule: LightningDataModule | None = None, + weights_only: bool | None = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. @@ -823,12 +823,12 @@ def test( def _test_impl( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[_PATH] = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: _PATH | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - weights_only: Optional[bool] = None, - ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + datamodule: LightningDataModule | None = None, + weights_only: bool | None = None, + ) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -871,12 +871,12 @@ def _test_impl( def predict( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, - ) -> Optional[_PREDICT_OUTPUT]: + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + datamodule: LightningDataModule | None = None, + return_predictions: bool | None = None, + ckpt_path: _PATH | None = None, + weights_only: bool | None = None, + ) -> _PREDICT_OUTPUT | None: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. @@ -950,12 +950,12 @@ def predict( def _predict_impl( self, model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, - ) -> Optional[_PREDICT_OUTPUT]: + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + datamodule: LightningDataModule | None = None, + return_predictions: bool | None = None, + ckpt_path: _PATH | None = None, + weights_only: bool | None = None, + ) -> _PREDICT_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -995,9 +995,9 @@ def _predict_impl( def _run( self, model: "pl.LightningModule", - ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, - ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + ckpt_path: _PATH | None = None, + weights_only: bool | None = None, + ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self @@ -1105,7 +1105,7 @@ def _teardown(self) -> None: self._logger_connector.teardown() self._signal_connector.teardown() - def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + def _run_stage(self) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # wait for all to join if on distributed self.strategy.barrier("run-stage") self.lightning_module.zero_grad() @@ -1167,7 +1167,7 @@ def __setup_profiler(self) -> None: self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir) @contextmanager - def init_module(self, empty_init: Optional[bool] = None) -> Generator: + def init_module(self, empty_init: bool | None = None) -> Generator: """Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in the Trainer. @@ -1288,11 +1288,11 @@ def precision(self) -> _PRECISION_INPUT_STR: return self.strategy.precision_plugin.precision @property - def scaler(self) -> Optional[Any]: + def scaler(self) -> Any | None: return getattr(self.precision_plugin, "scaler", None) @property - def model(self) -> Optional[torch.nn.Module]: + def model(self) -> torch.nn.Module | None: """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. To access the pure LightningModule, use @@ -1306,7 +1306,7 @@ def model(self) -> Optional[torch.nn.Module]: """ @property - def log_dir(self) -> Optional[str]: + def log_dir(self) -> str | None: """The directory for the current experiment. Use this to save images to, etc... .. note:: You must call this on all processes. Failing to do so will cause your program to stall forever. @@ -1343,7 +1343,7 @@ def training_step(self, batch, batch_idx): return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: + def distributed_sampler_kwargs(self) -> dict[str, Any] | None: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @@ -1369,7 +1369,7 @@ def default_root_dir(self) -> str: return self._default_root_dir @property - def early_stopping_callback(self) -> Optional[EarlyStopping]: + def early_stopping_callback(self) -> EarlyStopping | None: """The first :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.early_stopping_callbacks @@ -1382,7 +1382,7 @@ def early_stopping_callbacks(self) -> list[EarlyStopping]: return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property - def checkpoint_callback(self) -> Optional[Checkpoint]: + def checkpoint_callback(self) -> Checkpoint | None: """The first :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks @@ -1395,7 +1395,7 @@ def checkpoint_callbacks(self) -> list[Checkpoint]: return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property - def progress_bar_callback(self) -> Optional[ProgressBar]: + def progress_bar_callback(self) -> ProgressBar | None: """An instance of :class:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar` found in the Trainer.callbacks list, or ``None`` if one doesn't exist.""" for c in self.callbacks: @@ -1404,7 +1404,7 @@ def progress_bar_callback(self) -> Optional[ProgressBar]: return None @property - def ckpt_path(self) -> Optional[_PATH]: + def ckpt_path(self) -> _PATH | None: """Set to the path/URL of a checkpoint loaded via :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or @@ -1416,7 +1416,7 @@ def ckpt_path(self) -> Optional[_PATH]: return self._checkpoint_connector._ckpt_path @ckpt_path.setter - def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: + def ckpt_path(self, ckpt_path: _PATH | None) -> None: """Allows you to manage which checkpoint is loaded statefully. .. code-block:: python @@ -1435,7 +1435,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: self._checkpoint_connector._user_managed = bool(ckpt_path) def save_checkpoint( - self, filepath: _PATH, weights_only: Optional[bool] = None, storage_options: Optional[Any] = None + self, filepath: _PATH, weights_only: bool | None = None, storage_options: Any | None = None ) -> None: r"""Runs routine to create a checkpoint. @@ -1562,11 +1562,11 @@ def current_epoch(self) -> int: return self.fit_loop.epoch_progress.current.completed @property - def max_epochs(self) -> Optional[int]: + def max_epochs(self) -> int | None: return self.fit_loop.max_epochs @property - def min_epochs(self) -> Optional[int]: + def min_epochs(self) -> int | None: return self.fit_loop.min_epochs @property @@ -1574,7 +1574,7 @@ def max_steps(self) -> int: return self.fit_loop.max_steps @property - def min_steps(self) -> Optional[int]: + def min_steps(self) -> int | None: return self.fit_loop.min_steps @property @@ -1583,14 +1583,14 @@ def is_last_batch(self) -> bool: return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property - def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]: + def train_dataloader(self) -> TRAIN_DATALOADERS | None: """The training dataloader(s) used during ``trainer.fit()``.""" if (combined_loader := self.fit_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def val_dataloaders(self) -> EVAL_DATALOADERS | None: """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``.""" if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None or ( combined_loader := self.validate_loop._combined_loader @@ -1599,33 +1599,33 @@ def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: return None @property - def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def test_dataloaders(self) -> EVAL_DATALOADERS | None: """The test dataloader(s) used during ``trainer.test()``.""" if (combined_loader := self.test_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def predict_dataloaders(self) -> EVAL_DATALOADERS | None: """The prediction dataloader(s) used during ``trainer.predict()``.""" if (combined_loader := self.predict_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def num_training_batches(self) -> Union[int, float]: + def num_training_batches(self) -> int | float: """The number of training batches that will be used during ``trainer.fit()``.""" return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> list[Union[int, float]]: + def num_sanity_val_batches(self) -> list[int | float]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches # re-compute the `min` in case this is called outside the sanity-checking stage return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property - def num_val_batches(self) -> list[Union[int, float]]: + def num_val_batches(self) -> list[int | float]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop.max_batches @@ -1634,12 +1634,12 @@ def num_val_batches(self) -> list[Union[int, float]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> list[Union[int, float]]: + def num_test_batches(self) -> list[int | float]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property - def num_predict_batches(self) -> list[Union[int, float]]: + def num_predict_batches(self) -> list[int | float]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @@ -1654,7 +1654,7 @@ def _evaluation_loop(self) -> _EvaluationLoop: raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope") @property - def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]: + def _active_loop(self) -> _FitLoop | _EvaluationLoop | _PredictionLoop | None: if self.training: return self.fit_loop if self.sanity_checking or self.evaluating: @@ -1668,12 +1668,12 @@ def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionL """ @property - def logger(self) -> Optional[Logger]: + def logger(self) -> Logger | None: """The first :class:`~lightning.pytorch.loggers.logger.Logger` being used.""" return self.loggers[0] if len(self.loggers) > 0 else None @logger.setter - def logger(self, logger: Optional[Logger]) -> None: + def logger(self, logger: Logger | None) -> None: if not logger: self.loggers = [] else: @@ -1692,7 +1692,7 @@ def loggers(self) -> list[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[list[Logger]]) -> None: + def loggers(self, loggers: list[Logger] | None) -> None: self._loggers = loggers if loggers else [] @property @@ -1732,7 +1732,7 @@ def progress_bar_metrics(self) -> _PBAR_DICT: return self._logger_connector.progress_bar_metrics @property - def _results(self) -> Optional[_ResultCollection]: + def _results(self) -> _ResultCollection | None: active_loop = self._active_loop if active_loop is not None: return active_loop._results @@ -1743,7 +1743,7 @@ def _results(self) -> Optional[_ResultCollection]: """ @property - def estimated_stepping_batches(self) -> Union[int, float]: + def estimated_stepping_batches(self) -> int | float: r"""The estimated number of batches that will ``optimizer.step()`` during training. This accounts for gradient accumulation and the current trainer configuration. This might be used when setting diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 4795dc4a67ff5..5fa5c17f5e08c 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os import uuid from copy import deepcopy -from typing import Any, Optional +from typing import Any import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -34,7 +34,7 @@ def _scale_batch_size( batch_arg_name: str = "batch_size", margin: float = 0.05, max_val: int = 8192, -) -> Optional[int]: +) -> int | None: """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -329,8 +329,8 @@ def _adjust_batch_size( trainer: "pl.Trainer", batch_arg_name: str = "batch_size", factor: float = 1.0, - value: Optional[int] = None, - desc: Optional[str] = None, + value: int | None = None, + desc: str | None = None, max_val: int = 8192, ) -> tuple[int, bool]: """Helper function for adjusting the batch size. diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index 5ef35dcd6d992..5c647f5ff5e02 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -118,7 +118,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: def plot( self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None - ) -> Optional[Union["plt.Figure", "plt.SubFigure"]]: + ) -> Union["plt.Figure", "plt.SubFigure"] | None: """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point @@ -136,7 +136,7 @@ def plot( lrs = self.results["lr"] losses = self.results["loss"] - fig: Optional[Union[plt.Figure, plt.SubFigure]] + fig: plt.Figure | plt.SubFigure | None if ax is None: fig, ax = plt.subplots() else: @@ -159,7 +159,7 @@ def plot( return fig - def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]: + def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> float | None: """This will propose a suggestion for an initial learning rate based on the point with the steepest negative gradient. @@ -203,10 +203,10 @@ def _lr_find( max_lr: float = 1, num_training: int = 100, mode: str = "exponential", - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, update_attr: bool = False, attr_name: str = "", -) -> Optional[_LRFinder]: +) -> _LRFinder | None: """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. @@ -320,7 +320,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: } -def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None: +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float | None) -> None: from lightning.pytorch.loggers.logger import DummyLogger trainer.strategy.lr_scheduler_configs = [] @@ -367,7 +367,7 @@ class _LRCallback(Callback): def __init__( self, num_training: int, - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, progress_bar_refresh_rate: int = 0, beta: float = 0.98, ): @@ -473,7 +473,7 @@ def get_lr(self) -> list[float]: return val @property - def lr(self) -> Union[float, list[float]]: + def lr(self) -> float | list[float]: return self._lr @@ -510,7 +510,7 @@ def get_lr(self) -> list[float]: return val @property - def lr(self) -> Union[float, list[float]]: + def lr(self) -> float | list[float]: return self._lr diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py index f34eb365480b3..8fbbd47c04bb3 100644 --- a/src/lightning/pytorch/tuner/tuning.py +++ b/src/lightning/pytorch/tuner/tuning.py @@ -31,9 +31,9 @@ def __init__(self, trainer: "pl.Trainer") -> None: def scale_batch_size( self, model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: Union[TRAIN_DATALOADERS, "pl.LightningDataModule"] | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, datamodule: Optional["pl.LightningDataModule"] = None, method: Literal["fit", "validate", "test", "predict"] = "fit", mode: str = "power", @@ -43,7 +43,7 @@ def scale_batch_size( batch_arg_name: str = "batch_size", margin: float = 0.05, max_val: int = 8192, - ) -> Optional[int]: + ) -> int | None: """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -119,16 +119,16 @@ def scale_batch_size( def lr_find( self, model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: Union[TRAIN_DATALOADERS, "pl.LightningDataModule"] | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, datamodule: Optional["pl.LightningDataModule"] = None, method: Literal["fit", "validate", "test", "predict"] = "fit", min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = "exponential", - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, update_attr: bool = True, attr_name: str = "", ) -> Optional["_LRFinder"]: @@ -196,9 +196,9 @@ def lr_find( def _check_tuner_configuration( - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: Union[TRAIN_DATALOADERS, "pl.LightningDataModule"] | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, method: Literal["fit", "validate", "test", "predict"] = "fit", ) -> None: supported_methods = ("fit", "validate", "test", "predict") diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index 1e01297248ffa..d85f695b0c3f2 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -17,9 +17,10 @@ import os from argparse import Namespace from ast import literal_eval +from collections.abc import Callable from contextlib import suppress from functools import wraps -from typing import Any, Callable, TypeVar, cast +from typing import Any, TypeVar, cast _T = TypeVar("_T", bound=Callable[..., Any]) diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 9c89c998aa913..d2bedfa4b4195 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from collections.abc import Iterable, Iterator -from typing import Any, Callable, Literal, Optional, Union +from collections.abc import Callable, Iterable, Iterator +from typing import Any, Literal from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict, override @@ -26,7 +26,7 @@ class _ModeIterator(Iterator[_ITERATOR_RETURN]): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: list[int | float] | None = None) -> None: if limits is not None and len(limits) != len(iterables): raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})") self.iterables = iterables @@ -65,7 +65,7 @@ def __getstate__(self) -> dict[str, Any]: class _MaxSizeCycle(_ModeIterator): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: list[int | float] | None = None) -> None: super().__init__(iterables, limits) self._consumed: list[bool] = [] @@ -121,7 +121,7 @@ def __len__(self) -> int: class _Sequential(_ModeIterator): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: list[int | float] | None = None) -> None: super().__init__(iterables, limits) self._iterator_idx = 0 # what would be dataloader_idx @@ -287,8 +287,8 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._iterables = iterables self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode - self._iterator: Optional[_ModeIterator] = None - self._limits: Optional[list[Union[int, float]]] = None + self._iterator: _ModeIterator | None = None + self._limits: list[int | float] | None = None @property def iterables(self) -> Any: @@ -322,12 +322,12 @@ def flattened(self, flattened: list[Any]) -> None: self._flattened = flattened @property - def limits(self) -> Optional[list[Union[int, float]]]: + def limits(self) -> list[int | float] | None: """Optional limits per iterator.""" return self._limits @limits.setter - def limits(self, limits: Optional[Union[int, float, list[Union[int, float]]]]) -> None: + def limits(self, limits: int | float | list[int | float] | None) -> None: if isinstance(limits, (int, float)): limits = [limits] * len(self.flattened) elif isinstance(limits, list) and len(limits) != len(self.flattened): @@ -401,5 +401,5 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: dataloader._iterator = None -def _get_iterables_lengths(iterables: list[Iterable]) -> list[Union[int, float]]: +def _get_iterables_lengths(iterables: list[Iterable]) -> list[int | float]: return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables] diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index b04bc0dfdc2da..745e8ec93b916 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -14,13 +14,12 @@ import inspect from collections.abc import Generator, Iterable, Mapping, Sized from dataclasses import fields -from typing import Any, Optional, Union +from typing import Any, TypeGuard import torch from lightning_utilities.core.apply_func import is_dataclass_instance from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler -from typing_extensions import TypeGuard import lightning.pytorch as pl from lightning.fabric.utilities.data import ( @@ -35,12 +34,12 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn -BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] +BType = Tensor | str | Mapping[Any, "BType"] | Iterable["BType"] warning_cache = WarningCache() -def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: +def _extract_batch_size(batch: BType) -> Generator[int | None, None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: yield 1 @@ -130,7 +129,7 @@ def has_len_all_ranks( def _update_dataloader( - dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None + dataloader: DataLoader, sampler: Sampler | Iterable, mode: RunningStage | None = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) @@ -138,8 +137,8 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + sampler: Sampler | Iterable, + mode: RunningStage | None = None, ) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -232,8 +231,8 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + sampler: Sampler | Iterable, + mode: RunningStage | None = None, ) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index 08a0230f759cf..2e37621e358de 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -13,13 +13,11 @@ # limitations under the License. """Utilities to describe gradients.""" -from typing import Union - import torch from torch.nn import Module -def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> dict[str, float]: +def grad_norm(module: Module, norm_type: float | int | str, group_separator: str = "/") -> dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. The overall norm is computed over all gradients together, as if they diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 5db942b29183f..5b392fd0daa00 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -31,7 +31,8 @@ """ import re -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks.early_stopping import EarlyStopping diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 42074a3735c87..628f311bce6eb 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -18,7 +18,7 @@ import threading import warnings from types import ModuleType, TracebackType -from typing import Any, Optional +from typing import Any from packaging.version import Version from typing_extensions import override @@ -37,7 +37,7 @@ def migrate_checkpoint( - checkpoint: _CHECKPOINT, target_version: Optional[str] = None + checkpoint: _CHECKPOINT, target_version: str | None = None ) -> tuple[_CHECKPOINT, dict[str, list[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. @@ -121,9 +121,9 @@ class _FaultTolerantMode(LightningEnum): def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") @@ -134,7 +134,7 @@ def __exit__( _lock.release() -def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH | None = None) -> _CHECKPOINT: """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. This function is used by the Lightning Trainer when resuming from a checkpoint. @@ -174,7 +174,7 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: checkpoint.setdefault("legacy_pytorch-lightning_version", version) -def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool: +def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: str | None = None) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" target_version = Version(target) is_lte_max_version = max_version is None or target_version <= Version(max_version) diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 15a0d96e383c0..3a6a4158e6030 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -15,18 +15,19 @@ import inspect import logging import os -from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Concatenate, Generic, TypeVar from lightning_utilities.core.imports import RequirementCache from torch import nn -from typing_extensions import Concatenate, ParamSpec, override +from typing_extensions import ParamSpec, override import lightning.pytorch as pl _log = logging.getLogger(__name__) -def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[type[object]] = None) -> bool: +def is_overridden(method_name: str, instance: object | None = None, parent: type[object] | None = None) -> bool: if instance is None: # if `self.lightning_module` was passed as instance, it can be `None` return False @@ -115,7 +116,7 @@ def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: self.method = method @override - def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override] + def __get__(self, instance: _T, cls: type[_T] | None = None) -> Callable[_P, _R_co]: # type: ignore[override] # The wrapper ensures that the method can be inspected, but not called on an instance @functools.wraps(self.method) def wrapper(*args: Any, **kwargs: Any) -> _R_co: diff --git a/src/lightning/pytorch/utilities/model_registry.py b/src/lightning/pytorch/utilities/model_registry.py index 104da2514f5c2..fce46624eecd2 100644 --- a/src/lightning/pytorch/utilities/model_registry.py +++ b/src/lightning/pytorch/utilities/model_registry.py @@ -13,7 +13,6 @@ # limitations under the License. import os import re -from typing import Optional from lightning_utilities import module_available @@ -26,7 +25,7 @@ __doctest_skip__ = ["_determine_model_folder"] -def _is_registry(text: Optional[_PATH]) -> bool: +def _is_registry(text: _PATH | None) -> bool: """Check if a string equals 'registry' or starts with 'registry:'. Args: @@ -50,7 +49,7 @@ def _is_registry(text: Optional[_PATH]) -> bool: return bool(re.match(pattern, text.lower())) -def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str]: +def _parse_registry_model_version(ckpt_path: _PATH | None) -> tuple[str, str]: """Parse the model version from a registry path. Args: @@ -86,7 +85,7 @@ def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str] return model_name, version -def _determine_model_name(ckpt_path: Optional[_PATH], default_model_registry: Optional[str]) -> str: +def _determine_model_name(ckpt_path: _PATH | None, default_model_registry: str | None) -> str: """Determine the model name from the checkpoint path. Args: @@ -142,7 +141,7 @@ def _determine_model_folder(model_name: str, default_root_dir: str) -> str: def find_model_local_ckpt_path( - ckpt_path: Optional[_PATH], default_model_registry: Optional[str], default_root_dir: str + ckpt_path: _PATH | None, default_model_registry: str | None, default_root_dir: str ) -> str: """Find the local checkpoint path for a model.""" model_registry = _determine_model_name(ckpt_path, default_model_registry) @@ -156,7 +155,7 @@ def find_model_local_ckpt_path( return os.path.join(local_model_dir, folder_files[0]) -def download_model_from_registry(ckpt_path: Optional[_PATH], trainer: "pl.Trainer") -> None: +def download_model_from_registry(ckpt_path: _PATH | None, trainer: "pl.Trainer") -> None: """Download a model from the Lightning Model Registry.""" if trainer.local_rank == 0: if not module_available("litmodels"): diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 01b692abdc05f..034906ec1e4e7 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -17,7 +17,7 @@ import logging import math from collections import OrderedDict -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn as nn @@ -76,13 +76,13 @@ def __init__(self, module: nn.Module) -> None: super().__init__() self._module = module self._hook_handle = self._register_hook() - self._in_size: Optional[Union[str, list]] = None - self._out_size: Optional[Union[str, list]] = None + self._in_size: str | list | None = None + self._out_size: str | list | None = None def __del__(self) -> None: self.detach_hook() - def _register_hook(self) -> Optional[RemovableHandle]: + def _register_hook(self) -> RemovableHandle | None: """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. @@ -124,11 +124,11 @@ def detach_hook(self) -> None: self._hook_handle.remove() @property - def in_size(self) -> Union[str, list]: + def in_size(self) -> str | list: return self._in_size or UNKNOWN_SIZE @property - def out_size(self) -> Union[str, list]: + def out_size(self) -> str | list: return self._out_size or UNKNOWN_SIZE @property @@ -431,7 +431,7 @@ def __repr__(self) -> str: return str(self) -def parse_batch_shape(batch: Any) -> Union[str, list]: +def parse_batch_shape(batch: Any) -> str | list: if hasattr(batch, "shape"): return list(batch.shape) diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index da0309b0626bb..36ba28facace8 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -18,8 +18,6 @@ """ -from typing import Optional - from torch import nn @@ -28,7 +26,7 @@ def find_shared_parameters(module: nn.Module) -> list[str]: return _find_shared_parameters(module) -def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[dict] = None, prefix: str = "") -> list[str]: +def _find_shared_parameters(module: nn.Module, tied_parameters: dict | None = None, prefix: str = "") -> list[str]: if tied_parameters is None: tied_parameters = {} for name, param in module._parameters.items(): diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 64aa0209819ab..e3950f7fdede8 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -19,7 +19,7 @@ import types from collections.abc import MutableMapping, Sequence from dataclasses import fields, is_dataclass -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from torch import nn @@ -49,7 +49,7 @@ def clean_namespace(hparams: MutableMapping) -> None: del hparams[k] -def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys(cls: type) -> tuple[str, str | None, str | None]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -71,7 +71,7 @@ def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]] def _get_first_if_any( params: list[inspect.Parameter], param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], - ) -> Optional[str]: + ) -> str | None: for p in params: if p.kind == param_type: return p.name @@ -89,7 +89,7 @@ def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover return local_args -def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]: +def _get_init_args(frame: types.FrameType) -> tuple[Any | None, dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars or frame.f_code.co_name != "__init__": return None, {} @@ -146,9 +146,9 @@ def collect_init_args( def save_hyperparameters( obj: Any, *args: Any, - ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None, - given_hparams: Optional[dict[str, Any]] = None, + ignore: Sequence[str] | str | None = None, + frame: types.FrameType | None = None, + given_hparams: dict[str, Any] | None = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -263,7 +263,7 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) return holders -def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str) -> Optional[Any]: +def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str) -> Any | None: """Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams @@ -286,7 +286,7 @@ def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool: return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]: +def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Any | None: """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. diff --git a/src/lightning/pytorch/utilities/signature_utils.py b/src/lightning/pytorch/utilities/signature_utils.py index 0f41c5948fb46..e474a3b1b221b 100644 --- a/src/lightning/pytorch/utilities/signature_utils.py +++ b/src/lightning/pytorch/utilities/signature_utils.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Callable, Optional +from collections.abc import Callable def is_param_in_hook_signature( - hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None + hook_fx: Callable, param: str, explicit: bool = False, min_args: int | None = None ) -> bool: """ Args: diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 4c5b3bb6b4712..5b8dc20a61be8 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from lightning_utilities.core.imports import RequirementCache @@ -26,12 +25,12 @@ def _runif_reasons( *, min_cuda_gpus: int = 0, - min_torch: Optional[str] = None, - max_torch: Optional[str] = None, - min_python: Optional[str] = None, + min_torch: str | None = None, + max_torch: str | None = None, + min_python: str | None = None, bf16_cuda: bool = False, tpu: bool = False, - mps: Optional[bool] = None, + mps: bool | None = None, skip_windows: bool = False, standalone: bool = False, deepspeed: bool = False, diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index 6abd8aa952296..2b734b0e19290 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -22,10 +22,8 @@ from dataclasses import dataclass from typing import ( Any, - Optional, Protocol, TypedDict, - Union, runtime_checkable, ) @@ -38,11 +36,11 @@ from lightning.fabric.utilities.types import ProcessGroup -_NUMBER = Union[int, float] -_METRIC = Union[Metric, Tensor, _NUMBER] -STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] +_NUMBER = int | float +_METRIC = Metric | Tensor | _NUMBER +STEP_OUTPUT = Tensor | Mapping[str, Any] | None _EVALUATE_OUTPUT = list[Mapping[str, float]] # 1 dict per DataLoader -_PREDICT_OUTPUT = Union[list[Any], list[list[Any]]] +_PREDICT_OUTPUT = list[Any] | list[list[Any]] TRAIN_DATALOADERS = Any # any iterable or collection of iterables EVAL_DATALOADERS = Any # any iterable or collection of iterables @@ -54,11 +52,11 @@ class DistributedDataParallel(Protocol): def __init__( self, module: torch.nn.Module, - device_ids: Optional[list[Union[int, torch.device]]] = None, - output_device: Optional[Union[int, torch.device]] = None, + device_ids: list[int | torch.device] | None = None, + output_device: int | torch.device | None = None, dim: int = 0, broadcast_buffers: bool = True, - process_group: Optional[ProcessGroup] = None, + process_group: ProcessGroup | None = None, bucket_cap_mb: int = 25, find_unused_parameters: bool = False, check_reduction: bool = False, @@ -72,16 +70,16 @@ def no_sync(self) -> Generator: ... # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] -LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau] +LRSchedulerTypeUnion = LRScheduler | ReduceLROnPlateau +LRSchedulerType = type[LRScheduler] | type[ReduceLROnPlateau] +LRSchedulerPLType = LRScheduler | ReduceLROnPlateau @dataclass class LRSchedulerConfig: - scheduler: Union[LRScheduler, ReduceLROnPlateau] + scheduler: LRScheduler | ReduceLROnPlateau # no custom name - name: Optional[str] = None + name: str | None = None # after epoch is over interval: str = "epoch" # every epoch/batch @@ -89,18 +87,18 @@ class LRSchedulerConfig: # most often not ReduceLROnPlateau scheduler reduce_on_plateau: bool = False # value to monitor for ReduceLROnPlateau - monitor: Optional[str] = None + monitor: str | None = None # enforce that the monitor exists for ReduceLROnPlateau strict: bool = True class LRSchedulerConfigType(TypedDict, total=False): scheduler: Required[LRSchedulerTypeUnion] - name: Optional[str] + name: str | None interval: str frequency: int reduce_on_plateau: bool - monitor: Optional[str] + monitor: str | None strict: bool @@ -110,21 +108,20 @@ class OptimizerConfig(TypedDict): class OptimizerLRSchedulerConfig(TypedDict): optimizer: Optimizer - lr_scheduler: Union[LRSchedulerTypeUnion, LRSchedulerConfigType] + lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType monitor: NotRequired[str] -OptimizerLRScheduler = Optional[ - Union[ - Optimizer, - Sequence[Optimizer], - tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], - OptimizerConfig, - OptimizerLRSchedulerConfig, - Sequence[OptimizerConfig], - Sequence[OptimizerLRSchedulerConfig], - ] -] +OptimizerLRScheduler = ( + Optimizer + | Sequence[Optimizer] + | tuple[Sequence[Optimizer], Sequence[LRSchedulerTypeUnion | LRSchedulerConfig]] + | OptimizerConfig + | OptimizerLRSchedulerConfig + | Sequence[OptimizerConfig] + | Sequence[OptimizerLRSchedulerConfig] + | None +) class _SizedIterable(Protocol): diff --git a/tests/parity_fabric/models.py b/tests/parity_fabric/models.py index f65a20460e2f7..15c617e16ffcf 100644 --- a/tests/parity_fabric/models.py +++ b/tests/parity_fabric/models.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Callable +from collections.abc import Callable import torch import torch.nn as nn diff --git a/tests/parity_fabric/test_parity_simple.py b/tests/parity_fabric/test_parity_simple.py index 54c0de7297ac5..4bd5e72bf62bf 100644 --- a/tests/parity_fabric/test_parity_simple.py +++ b/tests/parity_fabric/test_parity_simple.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from collections.abc import Callable from copy import deepcopy -from typing import Callable import pytest import torch diff --git a/tests/parity_pytorch/measure.py b/tests/parity_pytorch/measure.py index b986861ef10df..fd5db2e6023a5 100644 --- a/tests/parity_pytorch/measure.py +++ b/tests/parity_pytorch/measure.py @@ -1,6 +1,6 @@ import gc import time -from typing import Callable +from collections.abc import Callable import torch from tqdm import tqdm diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index 420f711678808..0999d753a94fe 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Union +from typing import Any from unittest.mock import Mock import pytest @@ -53,7 +53,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: str | Path, weights_only: bool) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup return super().load_checkpoint(checkpoint_path, weights_only) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 0bd29b998c598..56b1db12036e8 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -16,7 +16,6 @@ import pickle import sys from collections import defaultdict -from typing import Union from unittest import mock from unittest.mock import ANY, Mock, PropertyMock, call, patch @@ -457,7 +456,7 @@ def training_step(self, batch, batch_idx): ("abc", "abc"), ], ) -def test_tqdm_format_num(input_num: Union[str, int, float], expected: str): +def test_tqdm_format_num(input_num: str | int | float, expected: str): """Check that the specialized tqdm.format_num appends 0 to floats and strings.""" assert Tqdm.format_num(input_num) == expected diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index 290a0921cb06d..49ddc1bca0f58 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -14,7 +14,6 @@ import csv import os import re -from typing import Optional from unittest import mock from unittest.mock import Mock @@ -40,7 +39,7 @@ def test_device_stats_gpu_from_torch(tmp_path): class DebugLogger(CSVLogger): @rank_zero_only - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: fields = [ "allocated_bytes.all.freed", "inactive_split.all.peak", @@ -74,7 +73,7 @@ def test_device_stats_cpu(cpu_stats_mock, tmp_path, cpu_stats): CPU_METRIC_KEYS = (_CPU_VM_PERCENT, _CPU_SWAP_PERCENT, _CPU_PERCENT) class DebugLogger(CSVLogger): - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: enabled = cpu_stats is not False for f in CPU_METRIC_KEYS: has_cpu_metrics = any(f in h for h in metrics) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index ff65e08c8c01a..3a2f9daed5dea 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -16,7 +16,6 @@ import math import os import pickle -from typing import Optional from unittest import mock from unittest.mock import Mock @@ -334,7 +333,7 @@ def test_early_stopping_mode_options(): class EarlyStoppingModel(BoringModel): - def __init__(self, expected_end_epoch: int, early_stop_on_train: bool, dist_diverge_epoch: Optional[int] = None): + def __init__(self, expected_end_epoch: int, early_stop_on_train: bool, dist_diverge_epoch: int | None = None): super().__init__() self.expected_end_epoch = expected_end_epoch self.early_stop_on_train = early_stop_on_train @@ -414,7 +413,7 @@ def test_multiple_early_stopping_callbacks( check_on_train_epoch_end: bool, strategy: str, devices: int, - dist_diverge_epoch: Optional[int], + dist_diverge_epoch: int | None, ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index 1a23efd919171..2b55aa0fd3002 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -14,7 +14,6 @@ import re from collections import OrderedDict from logging import INFO -from typing import Union import pytest import torch @@ -151,7 +150,7 @@ def test_pruning_callback( tmp_path, use_global_unstructured: bool, parameters_to_prune: bool, - pruning_fn: Union[str, pytorch_prune.BasePruningMethod], + pruning_fn: str | pytorch_prune.BasePruningMethod, use_lottery_ticket_hypothesis: bool, ): train_with_pruning_callback( diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index c63dd4e5c2ac9..86dc32ade8f89 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -15,7 +15,6 @@ import os from contextlib import AbstractContextManager from pathlib import Path -from typing import Optional from unittest import mock import pytest @@ -90,7 +89,7 @@ class SwaTestCallback(StochasticWeightAveraging): update_parameters_calls: int = 0 transfer_weights_calls: int = 0 # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0 - first_epoch: Optional[int] = None + first_epoch: int | None = None def update_parameters(self, *args, **kwargs): self.update_parameters_calls += 1 diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py index ec230b2fd6c97..4a3e61d782079 100644 --- a/tests/tests_pytorch/callbacks/test_weight_averaging.py +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pathlib import Path -from typing import Any, Optional +from typing import Any import pytest import torch @@ -84,7 +84,7 @@ def __init__(self, devices: int = 1, **kwargs: Any) -> None: self.swap_calls = 0 self.copy_calls = 0 # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. - self.first_epoch: Optional[int] = None + self.first_epoch: int | None = None def _swap_models(self, *args: Any, **kwargs: Any): self.swap_calls += 1 @@ -130,9 +130,9 @@ def __init__(self, **kwargs: Any) -> None: self.swap_calls = 0 self.copy_calls = 0 # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. - self.first_epoch: Optional[int] = None + self.first_epoch: int | None = None - def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool: + def should_update(self, step_idx: int | None = None, epoch_idx: int | None = None) -> bool: return epoch_idx in (3, 5, 7) def _swap_models(self, *args: Any, **kwargs: Any): @@ -289,7 +289,7 @@ def _train( strategy: str = "auto", accelerator: str = "cpu", devices: int = 1, - checkpoint_path: Optional[str] = None, + checkpoint_path: str | None = None, will_crash: bool = False, ) -> None: deterministic = accelerator == "cpu" diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 449484da970a8..244b5be76a4f8 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -20,7 +20,6 @@ from datetime import timedelta from inspect import signature from pathlib import Path -from typing import Union from unittest import mock from unittest.mock import Mock, call, patch @@ -335,7 +334,7 @@ def test_model_checkpoint_to_yaml(tmp_path, save_top_k: int): @pytest.mark.parametrize( ("logger_version", "expected"), [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")] ) -def test_model_checkpoint_path(tmp_path, logger_version: Union[None, int, str], expected: str): +def test_model_checkpoint_path(tmp_path, logger_version: None | int | str, expected: str): """Test that "version_" prefix is only added when logger's version is an integer.""" model = LogInTwoMethods() logger = TensorBoardLogger(tmp_path, version=logger_version) diff --git a/tests/tests_pytorch/deprecated_api/__init__.py b/tests/tests_pytorch/deprecated_api/__init__.py index cae45411063c9..7eafdc892bc7a 100644 --- a/tests/tests_pytorch/deprecated_api/__init__.py +++ b/tests/tests_pytorch/deprecated_api/__init__.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Optional from lightning_utilities.test.warning import no_warning_call @contextmanager -def no_deprecated_call(match: Optional[str] = None): +def no_deprecated_call(match: str | None = None): with no_warning_call(expected_warning=DeprecationWarning, match=match): yield diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 638d3a2946a74..18a05c52d2085 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -17,7 +17,6 @@ import time import urllib.request from collections.abc import Sequence -from typing import Optional import torch from torch import Tensor @@ -136,7 +135,7 @@ class TrialMNIST(MNIST): """ - def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + def __init__(self, root: str, num_samples: int = 100, digits: Sequence | None = (0, 1, 2), **kwargs): # number of examples per class self.num_samples = num_samples # take just a subset of MNIST dataset diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 124a9120a9197..931eb858d0596 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from copy import deepcopy -from typing import Any, Optional +from typing import Any from unittest.mock import patch import numpy as np @@ -58,7 +58,7 @@ def finalize(self, status): self.finalized_status = status @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Return the root directory where experiment logs get saved, or `None` if the logger does not save data locally.""" return None diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 5e52a2a3b5a82..1160b150d5083 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -19,7 +19,6 @@ from argparse import Namespace from dataclasses import dataclass, field from enum import Enum -from typing import Optional from unittest import mock import cloudpickle @@ -97,7 +96,7 @@ def __init__(self, hparams, *my_args, **my_kwargs): # STANDARD TESTS # ------------------------- def _run_standard_hparams_test( - tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None + tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: bool | None = None ): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model diff --git a/tests/tests_pytorch/plugins/test_async_checkpoint.py b/tests/tests_pytorch/plugins/test_async_checkpoint.py index 0718dab78d75f..89423c80d4d5a 100644 --- a/tests/tests_pytorch/plugins/test_async_checkpoint.py +++ b/tests/tests_pytorch/plugins/test_async_checkpoint.py @@ -1,5 +1,5 @@ import time -from typing import Any, Optional +from typing import Any import pytest import torch @@ -10,15 +10,15 @@ class _CaptureCheckpointIO(CheckpointIO): def __init__(self) -> None: - self.saved: Optional[dict[str, Any]] = None + self.saved: dict[str, Any] | None = None - def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Any | None = None) -> None: # Simulate some delay to increase race window time.sleep(0.05) # Store the received checkpoint object (not a deep copy) to inspect tensor values self.saved = checkpoint - def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint(self, path: str, map_location: Any | None = None) -> dict[str, Any]: raise NotImplementedError def remove_checkpoint(self, path: str) -> None: diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 0f66f215f6864..9f2bbfede36c1 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Optional +from typing import Any from unittest.mock import MagicMock, Mock import pytest @@ -29,11 +29,11 @@ class CustomCheckpointIO(CheckpointIO): - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Any | None = None) -> None: torch.save(checkpoint, path) def load_checkpoint( - self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True + self, path: _PATH, storage_options: Any | None = None, weights_only: bool = True ) -> dict[str, Any]: return torch.load(path, weights_only=True) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index f7c15b5930be8..f6aa86f3e59be 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -5,7 +5,6 @@ from functools import partial from pathlib import Path from re import escape -from typing import Optional from unittest import mock from unittest.mock import ANY, MagicMock, Mock @@ -34,7 +33,7 @@ class TestFSDPModel(BoringModel): def __init__(self): super().__init__() - self.layer: Optional[nn.Module] = None + self.layer: nn.Module | None = None def _init_model(self) -> None: self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) @@ -162,7 +161,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype -def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): +def _run_multiple_stages(trainer, model, model_path: str | None = None): trainer.fit(model) trainer.test(model) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 6ff4bee264a7b..a2bc103fe555c 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -17,10 +17,10 @@ import operator import os import sys +from collections.abc import Callable from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout from io import StringIO from pathlib import Path -from typing import Callable, Optional, Union from unittest import mock from unittest.mock import ANY @@ -118,7 +118,7 @@ def _model_builder(model_param: int) -> Model: def _trainer_builder( - limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[list[Callback], Callback]] = None + limit_train_batches: int, fast_dev_run: bool = False, callbacks: list[Callback] | Callback | None = None ) -> Trainer: return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) @@ -591,7 +591,7 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE)) def test_lightning_cli_torch_modules(cleandir): class TestModule(BoringModel): - def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[torch.nn.Module]] = None): + def __init__(self, activation: torch.nn.Module = None, transform: list[torch.nn.Module] | None = None): super().__init__() self.activation = activation self.transform = transform @@ -684,7 +684,7 @@ def add_arguments_to_parser(self, parser): class CustomAdam(torch.optim.Adam): - def __init__(self, params, num_classes: Optional[int] = None, **kwargs): + def __init__(self, params, num_classes: int | None = None, **kwargs): super().__init__(params, **kwargs) diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index 90f9a3e697535..70c4b24a9728c 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Optional, Union +from typing import Any from unittest.mock import Mock import lightning.pytorch as pl @@ -37,10 +37,10 @@ def __init__(self): def experiment(self) -> Any: return self.exp - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None): + def log_metrics(self, metrics: dict[str, float], step: int | None = None): self.logs.update(metrics) - def version(self) -> Union[int, str]: + def version(self) -> int | str: return 1 def name(self) -> str: @@ -143,7 +143,7 @@ def __init__(self): self.buffer = {} self.logs = {} - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: self.buffer.update(metrics) def finalize(self, status: str) -> None: @@ -155,7 +155,7 @@ def experiment(self) -> Any: return None @property - def version(self) -> Union[int, str]: + def version(self) -> int | str: return 1 @property