@@ -1,7 +1,7 @@
 import os
 from collections.abc import Iterable, Mapping
 from functools import partial
-from typing import Any, Literal, Optional, Union, cast
+from typing import Any, Literal, cast

 import torch
 from lightning_utilities import apply_to_collection
@@ -18,18 +18,18 @@
 class MyCustomTrainer:
     def __init__(
         self,
-        accelerator: Union[str, Accelerator] = "auto",
-        strategy: Union[str, Strategy] = "auto",
-        devices: Union[list[int], str, int] = "auto",
-        precision: Union[str, int] = "32-true",
-        plugins: Optional[Union[str, Any]] = None,
-        callbacks: Optional[Union[list[Any], Any]] = None,
-        loggers: Optional[Union[Logger, list[Logger]]] = None,
-        max_epochs: Optional[int] = 1000,
-        max_steps: Optional[int] = None,
+        accelerator: str | Accelerator = "auto",
+        strategy: str | Strategy = "auto",
+        devices: list[int] | str | int = "auto",
+        precision: str | int = "32-true",
+        plugins: str | Any | None = None,
+        callbacks: list[Any] | Any | None = None,
+        loggers: Logger | list[Logger] | None = None,
+        max_epochs: int | None = 1000,
+        max_steps: int | None = None,
         grad_accum_steps: int = 1,
-        limit_train_batches: Union[int, float] = float("inf"),
-        limit_val_batches: Union[int, float] = float("inf"),
+        limit_train_batches: int | float = float("inf"),
+        limit_val_batches: int | float = float("inf"),
         validation_frequency: int = 1,
         use_distributed_sampler: bool = True,
         checkpoint_dir: str = "./checkpoints",
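The pattern applied throughout this commit is PEP 604 syntax: `X | Y` replaces `Union[X, Y]` and `X | None` replaces `Optional[X]`. A minimal standalone sketch of the equivalence, mirroring the loss-unwrapping logic further down in this file (the function name `unwrap_loss` is illustrative, not part of the trainer):

```python
from __future__ import annotations  # lets `X | Y` appear in annotations even before Python 3.10

from collections.abc import Mapping
from typing import Any

import torch


# Equivalent to the old spelling: Optional[Union[torch.Tensor, Mapping[str, Any]]]
def unwrap_loss(outputs: torch.Tensor | Mapping[str, Any] | None) -> torch.Tensor | None:
    if outputs is None:
        return None
    # A bare tensor is the loss itself; a mapping stores it under "loss".
    return outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
```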
@@ -115,8 +115,8 @@ def __init__(
         self.limit_val_batches = limit_val_batches
         self.validation_frequency = validation_frequency
         self.use_distributed_sampler = use_distributed_sampler
-        self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
-        self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}
+        self._current_train_return: torch.Tensor | Mapping[str, Any] = {}
+        self._current_val_return: torch.Tensor | Mapping[str, Any] | None = {}

         self.checkpoint_dir = checkpoint_dir
         self.checkpoint_frequency = checkpoint_frequency
@@ -126,7 +126,7 @@ def fit(
         model: L.LightningModule,
         train_loader: torch.utils.data.DataLoader,
         val_loader: torch.utils.data.DataLoader,
-        ckpt_path: Optional[str] = None,
+        ckpt_path: str | None = None,
     ):
         """The main entrypoint of the trainer, triggering the actual training.

@@ -196,8 +196,8 @@ def train_loop(
         model: L.LightningModule,
         optimizer: torch.optim.Optimizer,
         train_loader: torch.utils.data.DataLoader,
-        limit_batches: Union[int, float] = float("inf"),
-        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
+        limit_batches: int | float = float("inf"),
+        scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None = None,
     ):
         """The training loop running a single training epoch.

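The `scheduler_cfg` mapping above mixes value types (the scheduler object plus strings, ints, and bools), which is exactly what the new union spells out; it corresponds to the `lr_scheduler` configuration dictionary that `configure_optimizers` can return. A hedged sketch of such a return value (the hyperparameters are illustrative):

```python
import lightning as L
import torch


class LitModel(L.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,  # LRScheduler
                "interval": "epoch",     # str: step the scheduler once per epoch
                "frequency": 1,          # int: at every epoch boundary
                "monitor": "val_loss",   # str: metric watched by ReduceLROnPlateau
            },
        }
```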
@@ -262,8 +262,8 @@ def train_loop(
     def val_loop(
         self,
         model: L.LightningModule,
-        val_loader: Optional[torch.utils.data.DataLoader],
-        limit_batches: Union[int, float] = float("inf"),
+        val_loader: torch.utils.data.DataLoader | None,
+        limit_batches: int | float = float("inf"),
     ):
         """The validation loop running a single validation epoch.

@@ -331,7 +331,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
             batch_idx: index of the current batch w.r.t the current epoch

         """
-        outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)
+        outputs: torch.Tensor | Mapping[str, Any] = model.training_step(batch, batch_idx=batch_idx)

         loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]

@@ -347,7 +347,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
     def step_scheduler(
         self,
         model: L.LightningModule,
-        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
+        scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
         level: Literal["step", "epoch"],
         current_value: int,
     ) -> None:
@@ -387,7 +387,7 @@ def step_scheduler(
             possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()})

         try:
-            monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
+            monitor = possible_monitor_vals[cast(str | None, scheduler_cfg["monitor"])]
         except KeyError as ex:
             possible_keys = list(possible_monitor_vals.keys())
             raise KeyError(
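One subtlety in the hunk above: unlike annotations, the first argument to `typing.cast` is an ordinary runtime expression, so `str | None` here is actually evaluated and builds a `types.UnionType`. That requires Python 3.10+ and is not rescued by `from __future__ import annotations`, which defers annotations only. A small standalone illustration:

```python
from __future__ import annotations  # defers evaluation of annotations only

from typing import cast


def monitor_key(cfg: dict) -> str | None:  # annotation: never evaluated, fine pre-3.10
    # The string form is a forward reference, so nothing is evaluated at runtime here either.
    return cast("str | None", cfg.get("monitor"))


# Outside an annotation, `str | None` is evaluated eagerly and builds a
# types.UnionType, so this line needs Python 3.10+:
value = cast(str | None, "val_loss")
```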
@@ -414,7 +414,7 @@ def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
             return tqdm(iterable, total=total, **kwargs)
         return iterable

-    def load(self, state: Optional[Mapping], path: str) -> None:
+    def load(self, state: Mapping | None, path: str) -> None:
         """Loads a checkpoint from a given file into state.

         Args:
@@ -432,7 +432,7 @@ def load(self, state: Optional[Mapping], path: str) -> None:
         if remainder:
             raise RuntimeError(f"Unused Checkpoint Values: {remainder}")

-    def save(self, state: Optional[Mapping]) -> None:
+    def save(self, state: Mapping | None) -> None:
         """Saves a checkpoint to the ``checkpoint_dir``

         Args:
@@ -447,7 +447,7 @@ def save(self, state: Optional[Mapping]) -> None:
         self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state)

     @staticmethod
-    def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
+    def get_latest_checkpoint(checkpoint_dir: str) -> str | None:
         """Returns the latest checkpoint from the ``checkpoint_dir``

         Args:
@@ -467,8 +467,8 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
     def _parse_optimizers_schedulers(
         self, configure_optim_output
     ) -> tuple[
-        Optional[L.fabric.utilities.types.Optimizable],
-        Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
+        L.fabric.utilities.types.Optimizable | None,
+        Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
     ]:
         """Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

@@ -521,7 +521,7 @@ def _parse_optimizers_schedulers(

     @staticmethod
     def _format_iterable(
-        prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
+        prog_bar, candidates: torch.Tensor | Mapping[str, torch.Tensor | float | int] | None, prefix: str
     ):
         """Adds values as postfix string to progressbar.

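One last note on the new unions: on Python 3.10+ they are first-class runtime objects, so they also work with `isinstance`, complementing checks like the `isinstance(outputs, torch.Tensor)` branch above. (For static checkers `int | float` collapses to `float` under PEP 484's numeric tower, but at runtime `isinstance(1, float)` is `False`, so the explicit union still matters.) A tiny illustration:

```python
limit_batches: int | float = float("inf")

# On 3.10+, a `|` union is a real object and can replace the tuple form:
assert isinstance(limit_batches, int | float)
assert isinstance(limit_batches, (int, float))  # classic equivalent
```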