10 changes: 5 additions & 5 deletions .actions/assistant.py
@@ -22,7 +22,7 @@
from itertools import chain
from os.path import dirname, isfile
from pathlib import Path
from typing import Any, Optional
from typing import Any

from packaging.requirements import Requirement
from packaging.version import Version
@@ -48,7 +48,7 @@
class _RequirementWithComment(Requirement):
strict_cmd = "strict"

def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
def __init__(self, *args: Any, comment: str = "", pip_argument: str | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.comment = comment
assert pip_argument is None or pip_argument # sanity check that it's not an empty str
@@ -285,7 +285,7 @@ def copy_replace_imports(
source_dir: str,
source_imports: Sequence[str],
target_imports: Sequence[str],
target_dir: Optional[str] = None,
target_dir: str | None = None,
lightning_by: str = "",
) -> None:
"""Copy package content with import adjustments."""
@@ -347,7 +347,7 @@ def copy_replace_imports(
source_dir: str,
source_import: str,
target_import: str,
target_dir: Optional[str] = None,
target_dir: str | None = None,
lightning_by: str = "",
) -> None:
"""Copy package content with import adjustments."""
@@ -363,7 +363,7 @@ def pull_docs_files(
target_dir: str = "docs/source-pytorch/XXX",
checkout: str = "refs/tags/1.0.0",
source_dir: str = "docs/source",
single_page: Optional[str] = None,
single_page: str | None = None,
as_orphan: bool = False,
) -> None:
"""Pull docs pages from external source and append to local docs.
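The pattern applied in this file, and repeated across the PR, is PEP 604 union syntax, available since Python 3.10: `X | None` replaces `Optional[X]` and `X | Y` replaces `Union[X, Y]`, so the corresponding `typing` imports can be dropped. A minimal before/after sketch (the `resolve` function is a made-up illustration, not code from this repo):

```python
from typing import Any


# Before (Python 3.9 style):
#   from typing import Any, Optional, Union
#   def resolve(name: Optional[str] = None, retries: Union[int, str] = 3) -> dict[str, Any]: ...

# After (Python 3.10+, PEP 604): unions no longer need typing imports.
def resolve(name: str | None = None, retries: int | str = 3) -> dict[str, Any]:
    return {"name": name or "default", "retries": int(retries)}


print(resolve("ckpt", "5"))  # {'name': 'ckpt', 'retries': 5}
```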
56 changes: 28 additions & 28 deletions examples/fabric/build_your_own_trainer/trainer.py
@@ -1,7 +1,7 @@
import os
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Any, Literal, Optional, Union, cast
from typing import Any, Literal, cast

import torch
from lightning_utilities import apply_to_collection
@@ -18,18 +18,18 @@
class MyCustomTrainer:
def __init__(
self,
accelerator: Union[str, Accelerator] = "auto",
strategy: Union[str, Strategy] = "auto",
devices: Union[list[int], str, int] = "auto",
precision: Union[str, int] = "32-true",
plugins: Optional[Union[str, Any]] = None,
callbacks: Optional[Union[list[Any], Any]] = None,
loggers: Optional[Union[Logger, list[Logger]]] = None,
max_epochs: Optional[int] = 1000,
max_steps: Optional[int] = None,
accelerator: str | Accelerator = "auto",
strategy: str | Strategy = "auto",
devices: list[int] | str | int = "auto",
precision: str | int = "32-true",
plugins: str | Any | None = None,
callbacks: list[Any] | Any | None = None,
loggers: Logger | list[Logger] | None = None,
max_epochs: int | None = 1000,
max_steps: int | None = None,
grad_accum_steps: int = 1,
limit_train_batches: Union[int, float] = float("inf"),
limit_val_batches: Union[int, float] = float("inf"),
limit_train_batches: int | float = float("inf"),
limit_val_batches: int | float = float("inf"),
validation_frequency: int = 1,
use_distributed_sampler: bool = True,
checkpoint_dir: str = "./checkpoints",
@@ -115,8 +115,8 @@ def __init__(
self.limit_val_batches = limit_val_batches
self.validation_frequency = validation_frequency
self.use_distributed_sampler = use_distributed_sampler
self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}
self._current_train_return: torch.Tensor | Mapping[str, Any] = {}
self._current_val_return: torch.Tensor | Mapping[str, Any] | None = {}

self.checkpoint_dir = checkpoint_dir
self.checkpoint_frequency = checkpoint_frequency
@@ -126,7 +126,7 @@ def fit(
model: L.LightningModule,
train_loader: torch.utils.data.DataLoader,
val_loader: torch.utils.data.DataLoader,
ckpt_path: Optional[str] = None,
ckpt_path: str | None = None,
):
"""The main entrypoint of the trainer, triggering the actual training.

@@ -196,8 +196,8 @@ def train_loop(
model: L.LightningModule,
optimizer: torch.optim.Optimizer,
train_loader: torch.utils.data.DataLoader,
limit_batches: Union[int, float] = float("inf"),
scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
limit_batches: int | float = float("inf"),
scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None = None,
):
"""The training loop running a single training epoch.

@@ -262,8 +262,8 @@ def train_loop(
def val_loop(
self,
model: L.LightningModule,
val_loader: Optional[torch.utils.data.DataLoader],
limit_batches: Union[int, float] = float("inf"),
val_loader: torch.utils.data.DataLoader | None,
limit_batches: int | float = float("inf"),
):
"""The validation loop running a single validation epoch.

@@ -331,7 +331,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
batch_idx: index of the current batch w.r.t the current epoch

"""
outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)
outputs: torch.Tensor | Mapping[str, Any] = model.training_step(batch, batch_idx=batch_idx)

loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]

@@ -347,7 +347,7 @@ def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) ->
def step_scheduler(
self,
model: L.LightningModule,
scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
scheduler_cfg: Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
level: Literal["step", "epoch"],
current_value: int,
) -> None:
@@ -387,7 +387,7 @@ def step_scheduler(
possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()})

try:
monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
monitor = possible_monitor_vals[cast(str | None, scheduler_cfg["monitor"])]
except KeyError as ex:
possible_keys = list(possible_monitor_vals.keys())
raise KeyError(
@@ -414,7 +414,7 @@ def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
return tqdm(iterable, total=total, **kwargs)
return iterable

def load(self, state: Optional[Mapping], path: str) -> None:
def load(self, state: Mapping | None, path: str) -> None:
"""Loads a checkpoint from a given file into state.

Args:
@@ -432,7 +432,7 @@ def load(self, state: Optional[Mapping], path: str) -> None:
if remainder:
raise RuntimeError(f"Unused Checkpoint Values: {remainder}")

def save(self, state: Optional[Mapping]) -> None:
def save(self, state: Mapping | None) -> None:
"""Saves a checkpoint to the ``checkpoint_dir``

Args:
@@ -447,7 +447,7 @@ def save(self, state: Optional[Mapping]) -> None:
self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state)

@staticmethod
def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
def get_latest_checkpoint(checkpoint_dir: str) -> str | None:
"""Returns the latest checkpoint from the ``checkpoint_dir``

Args:
@@ -467,8 +467,8 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
def _parse_optimizers_schedulers(
self, configure_optim_output
) -> tuple[
Optional[L.fabric.utilities.types.Optimizable],
Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
L.fabric.utilities.types.Optimizable | None,
Mapping[str, L.fabric.utilities.types.LRScheduler | bool | str | int] | None,
]:
"""Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

@@ -521,7 +521,7 @@ def _parse_optimizers_schedulers(

@staticmethod
def _format_iterable(
prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
prog_bar, candidates: torch.Tensor | Mapping[str, torch.Tensor | float | int] | None, prefix: str
):
"""Adds values as postfix string to progressbar.

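Most hunks in this trainer are annotation-only, but the `cast(str | None, ...)` change is runtime-relevant: the first argument to `typing.cast` is evaluated, and `str | None` is only a valid expression on Python 3.10+, where `|` between types builds a `types.UnionType`. A small sketch of the distinction (the dictionary below is a made-up stand-in for `scheduler_cfg`):

```python
from typing import cast

scheduler_cfg = {"scheduler": object(), "monitor": "val_loss"}

# On Python 3.10+, `str | None` is a real runtime expression, not just an
# annotation, so it can be passed to cast(); cast() returns the value unchanged.
monitor = cast(str | None, scheduler_cfg["monitor"])

print(type(str | None))  # <class 'types.UnionType'>
print(monitor)           # val_loss
```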
4 changes: 2 additions & 2 deletions examples/fabric/reinforcement_learning/rl/utils.py
@@ -1,7 +1,7 @@
import argparse
import math
import os
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Union

import gymnasium as gym
import torch
@@ -160,7 +160,7 @@ def linear_annealing(optimizer: torch.optim.Optimizer, update: int, num_updates:
pg["lr"] = lrnow


def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: Optional[str] = None, prefix: str = ""):
def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str | None = None, prefix: str = ""):
def thunk():
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordEpisodeStatistics(env)
7 changes: 3 additions & 4 deletions examples/fabric/tensor_parallel/model.py
@@ -9,7 +9,6 @@


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
@@ -21,10 +20,10 @@ class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
n_kv_heads: int | None = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5
rope_theta: float = 10000

@@ -248,7 +247,7 @@ def __init__(
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
3 changes: 1 addition & 2 deletions examples/pytorch/basics/autoencoder.py
@@ -18,7 +18,6 @@
"""

from os import path
from typing import Optional

import torch
import torch.nn.functional as F
@@ -46,7 +45,7 @@ def __init__(
nrow: int = 8,
padding: int = 2,
normalize: bool = True,
value_range: Optional[tuple[int, int]] = None,
value_range: tuple[int, int] | None = None,
scale_each: bool = False,
pad_value: int = 0,
) -> None:
3 changes: 1 addition & 2 deletions examples/pytorch/basics/backbone_image_classifier.py
@@ -18,7 +18,6 @@
"""

from os import path
from typing import Optional

import torch
from torch.nn import functional as F
@@ -63,7 +62,7 @@ class LitClassifier(LightningModule):
)
"""

def __init__(self, backbone: Optional[Backbone] = None, learning_rate: float = 0.0001):
def __init__(self, backbone: Backbone | None = None, learning_rate: float = 0.0001):
super().__init__()
self.save_hyperparameters(ignore=["backbone"])
if backbone is None:
3 changes: 1 addition & 2 deletions examples/pytorch/domain_templates/computer_vision_fine_tuning.py
@@ -42,7 +42,6 @@

import logging
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
@@ -91,7 +90,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O


class CatDogImageDataModule(LightningDataModule):
def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8):
def __init__(self, dl_path: str | Path = "data", num_workers: int = 0, batch_size: int = 8):
"""CatDogImageDataModule.

Args:
7 changes: 3 additions & 4 deletions examples/pytorch/domain_templates/imagenet.py
@@ -32,7 +32,6 @@
"""

import os
from typing import Optional

import torch
import torch.nn.functional as F
@@ -65,7 +64,7 @@ def __init__(
self,
data_path: str,
arch: str = "resnet18",
weights: Optional[str] = None,
weights: str | None = None,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
@@ -82,8 +81,8 @@ def __init__(
self.batch_size = batch_size
self.workers = workers
self.model = get_torchvision_model(self.arch, weights=self.weights)
self.train_dataset: Optional[Dataset] = None
self.eval_dataset: Optional[Dataset] = None
self.train_dataset: Dataset | None = None
self.eval_dataset: Dataset | None = None
# ToDo: this number of classes shall be parsed when the dataset is loaded from folder
self.train_acc1 = Accuracy(task="multiclass", num_classes=1000, top_k=1)
self.train_acc5 = Accuracy(task="multiclass", num_classes=1000, top_k=5)
3 changes: 1 addition & 2 deletions examples/pytorch/domain_templates/reinforce_learn_ppo.py
@@ -30,8 +30,7 @@
"""

import argparse
from collections.abc import Iterator
from typing import Callable
from collections.abc import Callable, Iterator

import gym
import torch
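This hunk is a related cleanup rather than a union rewrite: `Callable` now comes from `collections.abc`, since the `typing` aliases for the callable and container ABCs are deprecated (subscriptable `collections.abc` generics landed in Python 3.9). A minimal sketch, with `take` as an illustrative helper rather than repo code:

```python
from collections.abc import Callable, Iterator


def take(source: Iterator[int], step: Callable[[int], int], n: int) -> list[int]:
    """Apply `step` to the first `n` items drawn from `source`."""
    return [step(next(source)) for _ in range(n)]


print(take(iter(range(10)), lambda x: x * 2, 3))  # [0, 2, 4]
```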
5 changes: 2 additions & 3 deletions examples/pytorch/servable_module/production.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass
from io import BytesIO
from os import path
from typing import Optional

import numpy as np
import torch
@@ -56,8 +55,8 @@ def val_dataloader(self, *args, **kwargs):

@dataclass(unsafe_hash=True)
class Image:
height: Optional[int] = None
width: Optional[int] = None
height: int | None = None
width: int | None = None
extension: str = "JPEG"
mode: str = "RGB"
channel_first: bool = False
7 changes: 3 additions & 4 deletions examples/pytorch/tensor_parallel/model.py
@@ -9,7 +9,6 @@


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
@@ -21,10 +20,10 @@ class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
n_kv_heads: int | None = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5
rope_theta: float = 10000

@@ -248,7 +247,7 @@ def __init__(
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime"

[tool.ruff]
line-length = 120
target-version = "py39"
target-version = "py310"
# Exclude a variety of commonly ignored directories.
exclude = [
".git",
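The `target-version` bump is what makes the rest of the diff enforceable: with `target-version = "py310"`, Ruff's pyupgrade-family (UP) rules treat PEP 604 syntax as available and can flag, and auto-fix, the older `Optional`/`Union` spellings, assuming those rules are enabled for this repo. A hypothetical before/after of such a fix:

```python
# Under target-version = "py39", the old spelling is left alone:
#   def lookup(key: str, default: Optional[int] = None) -> Union[int, str]: ...

# Under target-version = "py310", pyupgrade-style rules rewrite it to:
def lookup(key: str, default: int | None = None) -> int | str:
    return default if default is not None else key


print(lookup("epochs"))      # epochs
print(lookup("epochs", 10))  # 10
```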