Skip to content

Commit 38b7936

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 927135d commit 38b7936

File tree

195 files changed

+1580
-1662
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

195 files changed

+1580
-1662
lines changed

.actions/assistant.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from itertools import chain
2323
from os.path import dirname, isfile
2424
from pathlib import Path
25-
from typing import Any, Optional
25+
from typing import Any
2626

2727
from packaging.requirements import Requirement
2828
from packaging.version import Version
@@ -48,7 +48,7 @@
4848
class _RequirementWithComment(Requirement):
4949
strict_string = "# strict"
5050

51-
def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
51+
def __init__(self, *args: Any, comment: str = "", pip_argument: str | None = None, **kwargs: Any) -> None:
5252
super().__init__(*args, **kwargs)
5353
self.comment = comment
5454
assert pip_argument is None or pip_argument # sanity check that it's not an empty str
@@ -284,7 +284,7 @@ def copy_replace_imports(
284284
source_dir: str,
285285
source_imports: Sequence[str],
286286
target_imports: Sequence[str],
287-
target_dir: Optional[str] = None,
287+
target_dir: str | None = None,
288288
lightning_by: str = "",
289289
) -> None:
290290
"""Copy package content with import adjustments."""
@@ -346,7 +346,7 @@ def copy_replace_imports(
346346
source_dir: str,
347347
source_import: str,
348348
target_import: str,
349-
target_dir: Optional[str] = None,
349+
target_dir: str | None = None,
350350
lightning_by: str = "",
351351
) -> None:
352352
"""Copy package content with import adjustments."""
@@ -362,7 +362,7 @@ def pull_docs_files(
362362
target_dir: str = "docs/source-pytorch/XXX",
363363
checkout: str = "refs/tags/1.0.0",
364364
source_dir: str = "docs/source",
365-
single_page: Optional[str] = None,
365+
single_page: str | None = None,
366366
as_orphan: bool = False,
367367
) -> None:
368368
"""Pull docs pages from external source and append to local docs.

examples/fabric/build_your_own_trainer/trainer.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from collections.abc import Iterable, Mapping
33
from functools import partial
4-
from typing import Any, Literal, Optional, Union, cast
4+
from typing import Any, Literal, cast
55

66
import torch
77
from lightning_utilities import apply_to_collection
@@ -18,18 +18,18 @@
1818
class MyCustomTrainer:
1919
def __init__(
2020
self,
21-
accelerator: Union[str, Accelerator] = "auto",
22-
strategy: Union[str, Strategy] = "auto",
23-
devices: Union[list[int], str, int] = "auto",
24-
precision: Union[str, int] = "32-true",
25-
plugins: Optional[Union[str, Any]] = None,
26-
callbacks: Optional[Union[list[Any], Any]] = None,
27-
loggers: Optional[Union[Logger, list[Logger]]] = None,
28-
max_epochs: Optional[int] = 1000,
29-
max_steps: Optional[int] = None,
21+
accelerator: str | Accelerator = "auto",
22+
strategy: str | Strategy = "auto",
23+
devices: list[int] | str | int = "auto",
24+
precision: str | int = "32-true",
25+
plugins: str | Any | None = None,
26+
callbacks: list[Any] | Any | None = None,
27+
loggers: Logger | list[Logger] | None = None,
28+
max_epochs: int | None = 1000,
29+
max_steps: int | None = None,
3030
grad_accum_steps: int = 1,
31-
limit_train_batches: Union[int, float] = float("inf"),
32-
limit_val_batches: Union[int, float] = float("inf"),
31+
limit_train_batches: int | float = float("inf"),
32+
limit_val_batches: int | float = float("inf"),
3333
validation_frequency: int = 1,
3434
use_distributed_sampler: bool = True,
3535
checkpoint_dir: str = "./checkpoints",
@@ -115,8 +115,8 @@ def __init__(
115115
self.limit_val_batches = limit_val_batches
116116
self.validation_frequency = validation_frequency
117117
self.use_distributed_sampler = use_distributed_sampler
118-
self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
119-
self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}
118+
self._current_train_return: torch.Tensor | Mapping[str, Any] = {}
119+
self._current_val_return: torch.Tensor | Mapping[str, Any] | None = {}
120120

121121
self.checkpoint_dir = checkpoint_dir
122122
self.checkpoint_frequency = checkpoint_frequency
@@ -126,7 +126,7 @@ def fit(
126126
model: L.LightningModule,
127127
train_loader: torch.utils.data.DataLoader,
128128
val_loader: torch.utils.data.DataLoader,
129-
ckpt_path: Optional[str] = None,
129+
ckpt_path: str | None = None,
130130
):
131131
"""The main entrypoint of the trainer, triggering the actual training.
132132
@@ -196,8 +196,8 @@ def train_loop(
196196
model: L.LightningModule,
197197
optimizer: torch.optim.Optimizer,
198198
train_loader: torch.utils.data.DataLoader,
199-
limit_batches: Union[int, float] = float("inf"),
200-
scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
199+
limit_batches: int | float = float("inf"),
200+
scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None = None,
201201
):
202202
"""The training loop running a single training epoch.
203203
@@ -262,8 +262,8 @@ def train_loop(
262262
def val_loop(
263263
self,
264264
model: L.LightningModule,
265-
val_loader: Optional[torch.utils.data.DataLoader],
266-
limit_batches: Union[int, float] = float("inf"),
265+
val_loader: torch.utils.data.DataLoader | None,
266+
limit_batches: int | float = float("inf"),
267267
):
268268
"""The validation loop running a single validation epoch.
269269
@@ -331,7 +331,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
331331
batch_idx: index of the current batch w.r.t the current epoch
332332
333333
"""
334-
outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)
334+
outputs: torch.Tensor | Mapping[str, Any] = model.training_step(batch, batch_idx=batch_idx)
335335

336336
loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
337337

@@ -347,7 +347,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
347347
def step_scheduler(
348348
self,
349349
model: L.LightningModule,
350-
scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
350+
scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
351351
level: Literal["step", "epoch"],
352352
current_value: int,
353353
) -> None:
@@ -387,7 +387,7 @@ def step_scheduler(
387387
possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()})
388388

389389
try:
390-
monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
390+
monitor = possible_monitor_vals[cast(str | None, scheduler_cfg["monitor"])]
391391
except KeyError as ex:
392392
possible_keys = list(possible_monitor_vals.keys())
393393
raise KeyError(
@@ -414,7 +414,7 @@ def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
414414
return tqdm(iterable, total=total, **kwargs)
415415
return iterable
416416

417-
def load(self, state: Optional[Mapping], path: str) -> None:
417+
def load(self, state: Mapping | None, path: str) -> None:
418418
"""Loads a checkpoint from a given file into state.
419419
420420
Args:
@@ -432,7 +432,7 @@ def load(self, state: Optional[Mapping], path: str) -> None:
432432
if remainder:
433433
raise RuntimeError(f"Unused Checkpoint Values: {remainder}")
434434

435-
def save(self, state: Optional[Mapping]) -> None:
435+
def save(self, state: Mapping | None) -> None:
436436
"""Saves a checkpoint to the ``checkpoint_dir``
437437
438438
Args:
@@ -447,7 +447,7 @@ def save(self, state: Optional[Mapping]) -> None:
447447
self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state)
448448

449449
@staticmethod
450-
def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
450+
def get_latest_checkpoint(checkpoint_dir: str) -> str | None:
451451
"""Returns the latest checkpoint from the ``checkpoint_dir``
452452
453453
Args:
@@ -467,8 +467,8 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
467467
def _parse_optimizers_schedulers(
468468
self, configure_optim_output
469469
) -> tuple[
470-
Optional[L.fabric.utilities.types.Optimizable],
471-
Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
470+
L.fabric.utilities.types.Optimizable | None,
471+
Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
472472
]:
473473
"""Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.
474474
@@ -521,7 +521,7 @@ def _parse_optimizers_schedulers(
521521

522522
@staticmethod
523523
def _format_iterable(
524-
prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
524+
prog_bar, candidates: torch.Tensor | Mapping[str, torch.Tensor | float | int] | None, prefix: str
525525
):
526526
"""Adds values as postfix string to progressbar.
527527

examples/fabric/reinforcement_learning/rl/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
import math
33
import os
4-
from typing import TYPE_CHECKING, Optional, Union
4+
from typing import TYPE_CHECKING, Union
55

66
import gymnasium as gym
77
import torch
@@ -160,7 +160,7 @@ def linear_annealing(optimizer: torch.optim.Optimizer, update: int, num_updates:
160160
pg["lr"] = lrnow
161161

162162

163-
def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: Optional[str] = None, prefix: str = ""):
163+
def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str | None = None, prefix: str = ""):
164164
def thunk():
165165
env = gym.make(env_id, render_mode="rgb_array")
166166
env = gym.wrappers.RecordEpisodeStatistics(env)

examples/fabric/tensor_parallel/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010

1111
from dataclasses import dataclass
12-
from typing import Optional
1312

1413
import torch
1514
import torch.nn.functional as F
@@ -21,10 +20,10 @@ class ModelArgs:
2120
dim: int = 4096
2221
n_layers: int = 32
2322
n_heads: int = 32
24-
n_kv_heads: Optional[int] = None
23+
n_kv_heads: int | None = None
2524
vocab_size: int = -1 # defined later by tokenizer
2625
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27-
ffn_dim_multiplier: Optional[float] = None
26+
ffn_dim_multiplier: float | None = None
2827
norm_eps: float = 1e-5
2928
rope_theta: float = 10000
3029

@@ -248,7 +247,7 @@ def __init__(
248247
dim: int,
249248
hidden_dim: int,
250249
multiple_of: int,
251-
ffn_dim_multiplier: Optional[float],
250+
ffn_dim_multiplier: float | None,
252251
):
253252
super().__init__()
254253
hidden_dim = int(2 * hidden_dim / 3)

examples/pytorch/basics/autoencoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"""
1919

2020
from os import path
21-
from typing import Optional
2221

2322
import torch
2423
import torch.nn.functional as F
@@ -46,7 +45,7 @@ def __init__(
4645
nrow: int = 8,
4746
padding: int = 2,
4847
normalize: bool = True,
49-
value_range: Optional[tuple[int, int]] = None,
48+
value_range: tuple[int, int] | None = None,
5049
scale_each: bool = False,
5150
pad_value: int = 0,
5251
) -> None:

examples/pytorch/basics/backbone_image_classifier.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"""
1919

2020
from os import path
21-
from typing import Optional
2221

2322
import torch
2423
from torch.nn import functional as F
@@ -63,7 +62,7 @@ class LitClassifier(LightningModule):
6362
)
6463
"""
6564

66-
def __init__(self, backbone: Optional[Backbone] = None, learning_rate: float = 0.0001):
65+
def __init__(self, backbone: Backbone | None = None, learning_rate: float = 0.0001):
6766
super().__init__()
6867
self.save_hyperparameters(ignore=["backbone"])
6968
if backbone is None:

examples/pytorch/domain_templates/computer_vision_fine_tuning.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
import logging
4444
from pathlib import Path
45-
from typing import Union
4645

4746
import torch
4847
import torch.nn.functional as F
@@ -91,7 +90,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O
9190

9291

9392
class CatDogImageDataModule(LightningDataModule):
94-
def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8):
93+
def __init__(self, dl_path: str | Path = "data", num_workers: int = 0, batch_size: int = 8):
9594
"""CatDogImageDataModule.
9695
9796
Args:

examples/pytorch/domain_templates/imagenet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
"""
3333

3434
import os
35-
from typing import Optional
3635

3736
import torch
3837
import torch.nn.functional as F
@@ -65,7 +64,7 @@ def __init__(
6564
self,
6665
data_path: str,
6766
arch: str = "resnet18",
68-
weights: Optional[str] = None,
67+
weights: str | None = None,
6968
lr: float = 0.1,
7069
momentum: float = 0.9,
7170
weight_decay: float = 1e-4,
@@ -82,8 +81,8 @@ def __init__(
8281
self.batch_size = batch_size
8382
self.workers = workers
8483
self.model = get_torchvision_model(self.arch, weights=self.weights)
85-
self.train_dataset: Optional[Dataset] = None
86-
self.eval_dataset: Optional[Dataset] = None
84+
self.train_dataset: Dataset | None = None
85+
self.eval_dataset: Dataset | None = None
8786
# ToDo: this number of classes hall be parsed when the dataset is loaded from folder
8887
self.train_acc1 = Accuracy(task="multiclass", num_classes=1000, top_k=1)
8988
self.train_acc5 = Accuracy(task="multiclass", num_classes=1000, top_k=5)

examples/pytorch/domain_templates/reinforce_learn_ppo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
"""
3131

3232
import argparse
33-
from collections.abc import Iterator
34-
from typing import Callable
33+
from collections.abc import Callable, Iterator
3534

3635
import gym
3736
import torch

examples/pytorch/servable_module/production.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from dataclasses import dataclass
33
from io import BytesIO
44
from os import path
5-
from typing import Optional
65

76
import numpy as np
87
import torch
@@ -56,8 +55,8 @@ def val_dataloader(self, *args, **kwargs):
5655

5756
@dataclass(unsafe_hash=True)
5857
class Image:
59-
height: Optional[int] = None
60-
width: Optional[int] = None
58+
height: int | None = None
59+
width: int | None = None
6160
extension: str = "JPEG"
6261
mode: str = "RGB"
6362
channel_first: bool = False

0 commit comments

Comments
 (0)