From 958f95e59a64d703f80ff4bb7c96df807175a927 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 15:43:58 +0800 Subject: [PATCH 01/37] feat: introduce `logger_map` property. --- .../logger_connector/logger_connector.py | 4 ++- src/lightning/pytorch/trainer/trainer.py | 30 ++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index f6e8885ee050a..f49121c7ac897 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection @@ -82,6 +82,8 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non ) logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment] self.trainer.loggers = [logger_] + elif isinstance(logger, Mapping): + self.trainer.loggers = logger elif isinstance(logger, Iterable): self.trainer.loggers = list(logger) else: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 5768c507e2e3f..87fdabaa1b841 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -23,7 +23,7 @@ import logging import math import os -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, Mapping from contextlib import contextmanager from datetime import timedelta from typing import Any, Optional, Union @@ -1625,13 +1625,12 @@ def logger(self) -> Optional[Logger]: @logger.setter def logger(self, logger: Optional[Logger]) -> None: - if not logger: - self.loggers = [] - else: - self.loggers = [logger] + if not isinstance(logger, Logger) or logger is not None: + raise TypeError(f"The `Trainer.logger` property must be of type `Logger`, get {type(logger)} instead.") + self.loggers = [logger] @property - def loggers(self) -> list[Logger]: + def loggers(self) -> list[Logger] | dict[str, Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1644,7 +1643,24 @@ def loggers(self) -> list[Logger]: @loggers.setter def loggers(self, loggers: Optional[list[Logger]]) -> None: - self._loggers = loggers if loggers else [] + if isinstance(loggers, Mapping): + self._loggers = list(loggers.values()) + self._logger_keys = list(loggers.keys()) + else: + self._loggers = loggers if loggers else [] + self._logger_keys = list(range(len(self._loggers))) + + @property + def logger_map(self) -> dict[str | int, Logger]: + """A mapping of logger keys to :class:`~lightning.pytorch.loggers.logger.Logger` used. + + .. code-block:: python + tb_logger = trainer.logger_map.get("tensorboard", None) + if tb_logger: + tb_logger.log_hyperparams({"lr": 0.001}) + + """ + return dict(zip(self._logger_keys, self._loggers)) @property def callback_metrics(self) -> _OUT_DICT: From 711fb4f949ca370b6f20225c1755ffc261cbced3 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:06:31 +0800 Subject: [PATCH 02/37] revert trainer.logger change. --- src/lightning/pytorch/trainer/trainer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 87fdabaa1b841..917a2bedb5157 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1625,12 +1625,13 @@ def logger(self) -> Optional[Logger]: @logger.setter def logger(self, logger: Optional[Logger]) -> None: - if not isinstance(logger, Logger) or logger is not None: - raise TypeError(f"The `Trainer.logger` property must be of type `Logger`, get {type(logger)} instead.") - self.loggers = [logger] + if not logger: + self.loggers = [] + else: + self.loggers = [logger] @property - def loggers(self) -> list[Logger] | dict[str, Logger]: + def loggers(self) -> list[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1642,7 +1643,7 @@ def loggers(self) -> list[Logger] | dict[str, Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[list[Logger]]) -> None: + def loggers(self, loggers: Optional[list[Logger] | Mapping[str, Logger]]) -> None: if isinstance(loggers, Mapping): self._loggers = list(loggers.values()) self._logger_keys = list(loggers.keys()) From 791dbdc20e463f62c71eeec71b58e2b9f25e06b4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:06:55 +0800 Subject: [PATCH 03/37] add tests. --- .../tests_pytorch/trainer/properties/test_loggers.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 1d07f3e99d412..24a4bccb2b995 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -59,10 +59,12 @@ def test_trainer_loggers_setters(): trainer.logger = None assert trainer.logger is None assert trainer.loggers == [] + assert isinstance(trainer.loggers, list) # Test setters for trainer.loggers trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] + assert isinstance(trainer.loggers, list) trainer.loggers = [logger1] assert trainer.loggers == [logger1] @@ -71,10 +73,18 @@ def test_trainer_loggers_setters(): trainer.loggers = [] assert trainer.loggers == [] assert trainer.logger is None + assert isinstance(trainer.loggers, list) trainer.loggers = None assert trainer.loggers == [] assert trainer.logger is None + assert isinstance(trainer.loggers, list) + + trainer.loggers = {"log1": logger1, "log2": logger2} + assert trainer.loggers == [logger1, logger2] + assert isinstance(trainer.loggers, list) + assert isinstance(trainer.logger_map, dict) + assert trainer.logger_map == {"log1": logger1, "log2": logger2} @pytest.mark.parametrize( @@ -94,3 +104,4 @@ def test_no_logger(tmp_path, logger_value): assert trainer.logger is None assert trainer.loggers == [] assert trainer.log_dir == str(tmp_path) + assert trainer.logger_map == {} From db515295c490aafe3fe2860d81fc20833520ada0 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:36:25 +0800 Subject: [PATCH 04/37] add test. --- tests/tests_pytorch/trainer/properties/test_loggers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 24a4bccb2b995..7b6f87538ac80 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -41,6 +41,12 @@ def test_trainer_loggers_property(): assert trainer.loggers == [trainer.logger] assert isinstance(trainer.logger, TensorBoardLogger) + trainer = Trainer(logger={"log1": logger1, "log2": logger2}) + assert trainer.logger == logger1 + assert trainer.loggers == [logger1, logger2] + assert isinstance(trainer.logger_map, dict) + assert trainer.logger_map == {"log1": logger1, "log2": logger2} + def test_trainer_loggers_setters(): """Test the behavior of setters for trainer.logger and trainer.loggers.""" From 7cb138258ed203121d4f02ca91f5eeac590a4792 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:44:04 +0800 Subject: [PATCH 05/37] add test. --- tests/tests_pytorch/trainer/properties/test_loggers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 7b6f87538ac80..95caec01caa82 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -28,12 +28,14 @@ def test_trainer_loggers_property(): trainer = Trainer(logger=[logger1, logger2]) assert trainer.loggers == [logger1, logger2] + assert trainer.logger_map == {0: logger1, 1: logger2} # trainer.loggers should create a list of size 1 trainer = Trainer(logger=logger1) assert trainer.logger == logger1 assert trainer.loggers == [logger1] + assert trainer.logger_map == {0: logger1} # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) @@ -66,25 +68,30 @@ def test_trainer_loggers_setters(): assert trainer.logger is None assert trainer.loggers == [] assert isinstance(trainer.loggers, list) + assert trainer.logger_map == {} # Test setters for trainer.loggers trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] assert isinstance(trainer.loggers, list) + assert trainer.logger_map == {0: logger1, 1: logger2} trainer.loggers = [logger1] assert trainer.loggers == [logger1] assert trainer.logger == logger1 + assert trainer.logger_map == {0: logger1} trainer.loggers = [] assert trainer.loggers == [] assert trainer.logger is None assert isinstance(trainer.loggers, list) + assert trainer.logger_map == {} trainer.loggers = None assert trainer.loggers == [] assert trainer.logger is None assert isinstance(trainer.loggers, list) + assert trainer.logger_map == {} trainer.loggers = {"log1": logger1, "log2": logger2} assert trainer.loggers == [logger1, logger2] From b1ea7d3062bf75e1c5b6fff0ade568d84de7026a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:48:25 +0800 Subject: [PATCH 06/37] fix pylint --- src/lightning/pytorch/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 917a2bedb5157..f904bde1d551b 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1644,6 +1644,7 @@ def loggers(self) -> list[Logger]: @loggers.setter def loggers(self, loggers: Optional[list[Logger] | Mapping[str, Logger]]) -> None: + self._logger_keys: list[str | int] if isinstance(loggers, Mapping): self._loggers = list(loggers.values()) self._logger_keys = list(loggers.keys()) From 6bbc98d30da23576fb6b7433cf2497e199a96eb3 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 17:49:51 +0800 Subject: [PATCH 07/37] fix pylint --- src/lightning/pytorch/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f904bde1d551b..db8c674780902 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1643,8 +1643,8 @@ def loggers(self) -> list[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[list[Logger] | Mapping[str, Logger]]) -> None: - self._logger_keys: list[str | int] + def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None: + self._logger_keys: list[Union[str, int]] if isinstance(loggers, Mapping): self._loggers = list(loggers.values()) self._logger_keys = list(loggers.keys()) @@ -1653,7 +1653,7 @@ def loggers(self, loggers: Optional[list[Logger] | Mapping[str, Logger]]) -> Non self._logger_keys = list(range(len(self._loggers))) @property - def logger_map(self) -> dict[str | int, Logger]: + def logger_map(self) -> dict[Union[str, int], Logger]: """A mapping of logger keys to :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python From 56ea5e8ceb53273267dd2902db3340c74b60ba4c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 22:12:49 +0800 Subject: [PATCH 08/37] add test. --- tests/tests_pytorch/trainer/properties/test_loggers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 95caec01caa82..2d96f5ba90095 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -37,6 +37,10 @@ def test_trainer_loggers_property(): assert trainer.loggers == [logger1] assert trainer.logger_map == {0: logger1} + trainer.loggers.append(logger2) + assert trainer.loggers == [logger1, logger2] + assert trainer.logger_map == {0: logger1} + # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) From d03198557710ae8118c85df42824374c2256c052 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 22:52:18 +0800 Subject: [PATCH 09/37] refactor loggers setter. --- src/lightning/pytorch/trainer/trainer.py | 4 ++-- .../trainer/properties/test_loggers.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index db8c674780902..7be34f0b9cf5e 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1650,7 +1650,7 @@ def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) self._logger_keys = list(loggers.keys()) else: self._loggers = loggers if loggers else [] - self._logger_keys = list(range(len(self._loggers))) + self._logger_keys = None @property def logger_map(self) -> dict[Union[str, int], Logger]: @@ -1662,7 +1662,7 @@ def logger_map(self) -> dict[Union[str, int], Logger]: tb_logger.log_hyperparams({"lr": 0.001}) """ - return dict(zip(self._logger_keys, self._loggers)) + return dict(zip(self._logger_keys or [], self._loggers)) @property def callback_metrics(self) -> _OUT_DICT: diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 2d96f5ba90095..3cc496c9e60b9 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -23,23 +23,24 @@ def test_trainer_loggers_property(): """Test for correct initialization of loggers in Trainer.""" logger1 = CustomLogger() logger2 = CustomLogger() + logger3 = CustomLogger() # trainer.loggers should be a copy of the input list trainer = Trainer(logger=[logger1, logger2]) assert trainer.loggers == [logger1, logger2] - assert trainer.logger_map == {0: logger1, 1: logger2} + assert trainer.logger_map == {} # trainer.loggers should create a list of size 1 trainer = Trainer(logger=logger1) assert trainer.logger == logger1 assert trainer.loggers == [logger1] - assert trainer.logger_map == {0: logger1} + assert trainer.logger_map == {} trainer.loggers.append(logger2) assert trainer.loggers == [logger1, logger2] - assert trainer.logger_map == {0: logger1} + assert trainer.logger_map == {} # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) @@ -53,11 +54,16 @@ def test_trainer_loggers_property(): assert isinstance(trainer.logger_map, dict) assert trainer.logger_map == {"log1": logger1, "log2": logger2} + trainer.loggers.append(logger3) + assert trainer.loggers == [logger1, logger2, logger3] + assert trainer.logger_map == {"log1": logger1, "log2": logger2} + def test_trainer_loggers_setters(): """Test the behavior of setters for trainer.logger and trainer.loggers.""" logger1 = CustomLogger() logger2 = CustomLogger() + logger3 = CustomLogger() trainer = Trainer() assert type(trainer.logger) is TensorBoardLogger @@ -78,12 +84,12 @@ def test_trainer_loggers_setters(): trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] assert isinstance(trainer.loggers, list) - assert trainer.logger_map == {0: logger1, 1: logger2} + assert trainer.logger_map == {} trainer.loggers = [logger1] assert trainer.loggers == [logger1] assert trainer.logger == logger1 - assert trainer.logger_map == {0: logger1} + assert trainer.logger_map == {} trainer.loggers = [] assert trainer.loggers == [] @@ -103,6 +109,9 @@ def test_trainer_loggers_setters(): assert isinstance(trainer.logger_map, dict) assert trainer.logger_map == {"log1": logger1, "log2": logger2} + trainer.loggers.append(logger3) + assert trainer.logger_map == {"log1": logger1, "log2": logger2} + @pytest.mark.parametrize( "logger_value", From e8d36953bda9b1fa3df0da07f107abe52988cb1e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 19 Oct 2025 23:19:46 +0800 Subject: [PATCH 10/37] fix pylint. --- src/lightning/pytorch/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 7be34f0b9cf5e..f89e9d57cbd9d 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1650,7 +1650,7 @@ def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) self._logger_keys = list(loggers.keys()) else: self._loggers = loggers if loggers else [] - self._logger_keys = None + self._logger_keys = [] @property def logger_map(self) -> dict[Union[str, int], Logger]: @@ -1662,7 +1662,7 @@ def logger_map(self) -> dict[Union[str, int], Logger]: tb_logger.log_hyperparams({"lr": 0.001}) """ - return dict(zip(self._logger_keys or [], self._loggers)) + return dict(zip(self._logger_keys, self._loggers)) @property def callback_metrics(self) -> _OUT_DICT: From 01aaa41d2efdb19fc98a36ffe99caa942a4a7add Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 02:10:06 +0800 Subject: [PATCH 11/37] _ListMap integration. --- src/lightning/pytorch/loggers/utilities.py | 111 ++++++++++++- src/lightning/pytorch/trainer/trainer.py | 24 +-- tests/tests_pytorch/loggers/test_utilities.py | 148 +++++++++++++++++- 3 files changed, 260 insertions(+), 23 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index ced8a6f1f2bd3..f3e200462e2f4 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -13,8 +13,9 @@ # limitations under the License. """Utilities for loggers.""" +from collections.abc import Mapping from pathlib import Path -from typing import Any, Union +from typing import Any, Optional, Self, TypeVar, Union from torch import Tensor @@ -100,3 +101,111 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: logger.log_hyperparams(hparams_initial) logger.log_graph(pl_module) logger.save() + + +_T = TypeVar("_T") + + +class _ListMap(list[_T]): + """A hybrid container for loggers allowing both index and name access.""" + + def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None): + if isinstance(loggers, Mapping): + # super inits list with values + if any(not isinstance(x, str) for x in loggers): + raise TypeError("When providing a Mapping, all keys must be of type str.") + super().__init__(loggers.values()) + self._dict = dict(zip(loggers.keys(), range(len(loggers)))) + else: + super().__init__(() if loggers is None else loggers) + self._dict: dict = {} + + def __eq__(self, other): + self_list = list(self) + if isinstance(other, _ListMap): + return self_list == list(other) and self._dict == other._dict + if isinstance(other, list): + return self_list == other + return False + + # --- List-like interface --- + def __getitem__(self, key: Union[int, slice, str]) -> _T: + if isinstance(key, (int, slice)): + return list.__getitem__(self, key) + if isinstance(key, str): + return list.__getitem__(self, self._dict[key]) + raise TypeError("Key must be int / slice (for index) or str (for name).") + + def __add__(self, other: Union[list[_T], Self]) -> list[_T]: + # todo + return list.__add__(self, other) + + def __iadd__(self, other: Union[list[_T], Self]) -> Self: + # todo + return list.__iadd__(self, other) + + def __setitem__(self, key, value): + if isinstance(key, (int, slice)): + # replace element by index + return list.__setitem__(self, key, value) + if isinstance(key, str): + # replace or insert by name + if key in self._dict: + list.__setitem__(self, self._dict[key], value) + else: + self.append(value) + self._dict[key] = len(self) - 1 + return None + raise TypeError("Key must be int or str") + + def __contains__(self, item): + if isinstance(item, str): + return item in self._dict + return list.__contains__(self, item) + + # --- Dict-like interface --- + + def __delitem__(self, key): + if isinstance(key, (int, slice)): + loggers = list.__getitem__(self, key) + super(list, self).__delitem__(key) + for logger in loggers if isinstance(key, slice) else [loggers]: + name = getattr(logger, "name", None) + if name: + self._dict.pop(name, None) + elif isinstance(key, str): + logger = self._dict.pop(key) + self.remove(logger) + else: + raise TypeError("Key must be int or str") + + def keys(self): + return self._dict.keys() + + def values(self): + d = {k: self[v] for k, v in self._dict.items()} + return d.values() + + def items(self): + d = {k: self[v] for k, v in self._dict.items()} + return d.items() + + # --- List and Dict interface --- + def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T: + if isinstance(key, int): + ret = list.pop(self, key) + for str_key, idx in list(self._dict.items()): + if idx == key: + self._dict.pop(str_key) + elif idx > key: + self._dict[str_key] = idx - 1 + return ret + if isinstance(key, str): + if key not in self._dict: + return default + return self.pop(self._dict[key]) + raise TypeError("Key must be int or str") + + def __repr__(self): + ret = super().__repr__() + return f"_ListMap({ret}, keys={list(self._dict.keys())})" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f89e9d57cbd9d..a60e71eaac4a3 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -43,7 +43,7 @@ from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.csv_logs import CSVLogger from lightning.pytorch.loggers.tensorboard import TensorBoardLogger -from lightning.pytorch.loggers.utilities import _log_hyperparams +from lightning.pytorch.loggers.utilities import _ListMap, _log_hyperparams from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop from lightning.pytorch.loops.fit_loop import _FitLoop @@ -1631,7 +1631,7 @@ def logger(self, logger: Optional[Logger]) -> None: self.loggers = [logger] @property - def loggers(self) -> list[Logger]: + def loggers(self) -> _ListMap[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1644,25 +1644,7 @@ def loggers(self) -> list[Logger]: @loggers.setter def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None: - self._logger_keys: list[Union[str, int]] - if isinstance(loggers, Mapping): - self._loggers = list(loggers.values()) - self._logger_keys = list(loggers.keys()) - else: - self._loggers = loggers if loggers else [] - self._logger_keys = [] - - @property - def logger_map(self) -> dict[Union[str, int], Logger]: - """A mapping of logger keys to :class:`~lightning.pytorch.loggers.logger.Logger` used. - - .. code-block:: python - tb_logger = trainer.logger_map.get("tensorboard", None) - if tb_logger: - tb_logger.log_hyperparams({"lr": 0.001}) - - """ - return dict(zip(self._logger_keys, self._loggers)) + self._loggers = _ListMap(loggers) @property def callback_metrics(self) -> _OUT_DICT: diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index d83a95dda9535..86c2e7b99de02 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from lightning.pytorch.loggers import CSVLogger -from lightning.pytorch.loggers.utilities import _version +from lightning.pytorch.loggers.utilities import _ListMap, _version def test_version(tmp_path): @@ -31,3 +33,147 @@ def test_version(tmp_path): assert version == "0_2_1" version = _version(loggers, "-") assert version == "0-2-1" + + +@pytest.mark.parametrize( + "args", + [ + (1, 2), + [1, 2], + {1, 2}, + range(2), + ], +) +def test_listmap_init(args): + """Test initialization with different iterable types.""" + lm = _ListMap(args) + assert len(lm) == len(args) + assert isinstance(lm, list) + + +def test_listmap_append(): + """Test appending loggers to the collection.""" + lm = _ListMap() + lm.append(1) + assert len(lm) == 1 + lm.append(2) + assert len(lm) == 2 + + +def test_listmap_extend(): + # extent + lm = _ListMap([1, 2]) + lm.extend([1, 2, 3]) + assert len(lm) == 5 + assert lm == [1, 2, 1, 2, 3] + + +def test_listmap_pop(): + lm = _ListMap([1, 2, 3, 4]) + item = lm.pop() + assert item == 4 + assert len(lm) == 3 + item = lm.pop(1) + assert item == 2 + assert len(lm) == 2 + assert lm == [1, 3] + + +def test_listmap_getitem(): + """Test getting items from the collection.""" + lm = _ListMap([1, 2]) + assert lm[0] == 1 + assert lm[1] == 2 + + +def test_listmap_setitem(): + """Test setting items in the collection.""" + lm = _ListMap([1, 2, 3]) + lm[0] = 10 + assert lm == [10, 2, 3] + lm[1:3] = [20, 30] + assert lm == [10, 20, 30] + + +def test_listmap_add(): + """Test adding two collections together.""" + lm1 = _ListMap([1, 2]) + lm2 = _ListMap([3, 4]) + combined = lm1 + lm2 + assert isinstance(combined, list) + assert len(combined) == 4 + assert combined[0] == 1 + assert combined[1] == 2 + assert combined[2] == 3 + assert combined[3] == 4 + + combined += lm1 + assert isinstance(combined, list) + assert len(combined) == 6 + for item, expected in zip(combined, [1, 2, 3, 4, 1, 2]): + assert item == expected + + +def test_listmap_remove(): + """Test removing items from the collection.""" + lm = _ListMap([1, 2, 3]) + lm.remove(2) + assert len(lm) == 2 + assert 2 not in lm + + +# Dict type properties tests +def test_listmap_keys(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + keys = lm.keys() + assert set(keys) == {"a", "b", "c"} + assert "a" in lm + assert "d" not in lm + + +def test_listmap_values(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + values = lm.values() + assert set(values) == {1, 2, 3} + + +def test_listmap_dict_items(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + items = lm.items() + assert set(items) == {("a", 1), ("b", 2), ("c", 3)} + + +def test_listmap_dict_pop(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + value = lm.pop("b") + assert value == 2 + assert "b" not in lm + assert len(lm) == 2 + + +def test_listmap_dict_setitem(): + lm = _ListMap({ + "a": 1, + "b": 2, + }) + lm["b"] = 20 + assert lm["b"] == 20 + lm["c"] = 3 + assert lm["c"] == 3 + assert len(lm) == 3 From 022fa9214c6ee78d7870fbb0247e596b06f0eaa2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 17:44:36 +0800 Subject: [PATCH 12/37] fix: fix unittests. --- .../trainer/properties/test_loggers.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 3cc496c9e60b9..a374ab350d9a2 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -23,24 +23,21 @@ def test_trainer_loggers_property(): """Test for correct initialization of loggers in Trainer.""" logger1 = CustomLogger() logger2 = CustomLogger() - logger3 = CustomLogger() + CustomLogger() # trainer.loggers should be a copy of the input list trainer = Trainer(logger=[logger1, logger2]) assert trainer.loggers == [logger1, logger2] - assert trainer.logger_map == {} # trainer.loggers should create a list of size 1 trainer = Trainer(logger=logger1) assert trainer.logger == logger1 assert trainer.loggers == [logger1] - assert trainer.logger_map == {} trainer.loggers.append(logger2) assert trainer.loggers == [logger1, logger2] - assert trainer.logger_map == {} # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) @@ -51,19 +48,15 @@ def test_trainer_loggers_property(): trainer = Trainer(logger={"log1": logger1, "log2": logger2}) assert trainer.logger == logger1 assert trainer.loggers == [logger1, logger2] - assert isinstance(trainer.logger_map, dict) - assert trainer.logger_map == {"log1": logger1, "log2": logger2} - - trainer.loggers.append(logger3) - assert trainer.loggers == [logger1, logger2, logger3] - assert trainer.logger_map == {"log1": logger1, "log2": logger2} + assert trainer.loggers["log1"] is logger1 + assert trainer.loggers["log2"] is logger2 def test_trainer_loggers_setters(): """Test the behavior of setters for trainer.logger and trainer.loggers.""" logger1 = CustomLogger() logger2 = CustomLogger() - logger3 = CustomLogger() + CustomLogger() trainer = Trainer() assert type(trainer.logger) is TensorBoardLogger @@ -78,39 +71,37 @@ def test_trainer_loggers_setters(): assert trainer.logger is None assert trainer.loggers == [] assert isinstance(trainer.loggers, list) - assert trainer.logger_map == {} # Test setters for trainer.loggers trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] assert isinstance(trainer.loggers, list) - assert trainer.logger_map == {} trainer.loggers = [logger1] assert trainer.loggers == [logger1] assert trainer.logger == logger1 - assert trainer.logger_map == {} trainer.loggers = [] assert trainer.loggers == [] assert trainer.logger is None assert isinstance(trainer.loggers, list) - assert trainer.logger_map == {} trainer.loggers = None assert trainer.loggers == [] assert trainer.logger is None assert isinstance(trainer.loggers, list) - assert trainer.logger_map == {} + + trainer.loggers = {} + assert trainer.loggers == [] + assert trainer.logger is None + assert isinstance(trainer.loggers, list) trainer.loggers = {"log1": logger1, "log2": logger2} assert trainer.loggers == [logger1, logger2] assert isinstance(trainer.loggers, list) - assert isinstance(trainer.logger_map, dict) - assert trainer.logger_map == {"log1": logger1, "log2": logger2} - trainer.loggers.append(logger3) - assert trainer.logger_map == {"log1": logger1, "log2": logger2} + assert trainer.loggers["log1"] is logger1 + assert trainer.loggers["log2"] is logger2 @pytest.mark.parametrize( @@ -118,6 +109,7 @@ def test_trainer_loggers_setters(): [ False, [], + {}, ], ) def test_no_logger(tmp_path, logger_value): @@ -130,4 +122,3 @@ def test_no_logger(tmp_path, logger_value): assert trainer.logger is None assert trainer.loggers == [] assert trainer.log_dir == str(tmp_path) - assert trainer.logger_map == {} From c309642151e214c0cc485d90fa0336650bbee7b2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 17:52:28 +0800 Subject: [PATCH 13/37] fix pylint. --- src/lightning/pytorch/loggers/utilities.py | 21 +++++++++++---------- src/lightning/pytorch/trainer/trainer.py | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index f3e200462e2f4..b1b5c0f280c91 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -13,11 +13,12 @@ # limitations under the License. """Utilities for loggers.""" -from collections.abc import Mapping +from collections.abc import ItemsView, KeysView, Mapping, ValuesView from pathlib import Path -from typing import Any, Optional, Self, TypeVar, Union +from typing import Any, Optional, TypeVar, Union from torch import Tensor +from typing_extensions import Self import lightning.pytorch as pl from lightning.pytorch.callbacks import Checkpoint @@ -120,7 +121,7 @@ def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None): super().__init__(() if loggers is None else loggers) self._dict: dict = {} - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: self_list = list(self) if isinstance(other, _ListMap): return self_list == list(other) and self._dict == other._dict @@ -144,7 +145,7 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self: # todo return list.__iadd__(self, other) - def __setitem__(self, key, value): + def __setitem__(self, key: Union[int, slice, str], value: _T) -> None: if isinstance(key, (int, slice)): # replace element by index return list.__setitem__(self, key, value) @@ -158,14 +159,14 @@ def __setitem__(self, key, value): return None raise TypeError("Key must be int or str") - def __contains__(self, item): + def __contains__(self, item: Union[_T, str]) -> bool: if isinstance(item, str): return item in self._dict return list.__contains__(self, item) # --- Dict-like interface --- - def __delitem__(self, key): + def __delitem__(self, key: Union[int, slice, str]) -> None: if isinstance(key, (int, slice)): loggers = list.__getitem__(self, key) super(list, self).__delitem__(key) @@ -179,14 +180,14 @@ def __delitem__(self, key): else: raise TypeError("Key must be int or str") - def keys(self): + def keys(self) -> KeysView[str]: return self._dict.keys() - def values(self): + def values(self) -> ValuesView[_T]: d = {k: self[v] for k, v in self._dict.items()} return d.values() - def items(self): + def items(self) -> ItemsView[str, _T]: d = {k: self[v] for k, v in self._dict.items()} return d.items() @@ -206,6 +207,6 @@ def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T: return self.pop(self._dict[key]) raise TypeError("Key must be int or str") - def __repr__(self): + def __repr__(self) -> str: ret = super().__repr__() return f"_ListMap({ret}, keys={list(self._dict.keys())})" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index a60e71eaac4a3..685bb030490c6 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -494,7 +494,7 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: list[Logger] + self._loggers: _ListMap[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags From 67b888fc3bc7f82ae4dede3c070dcf603ae1c352 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 18:53:51 +0800 Subject: [PATCH 14/37] add reverse impl. --- src/lightning/pytorch/loggers/utilities.py | 22 +++++++++++++---- tests/tests_pytorch/loggers/test_utilities.py | 24 +++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index b1b5c0f280c91..c969eff2333cc 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -13,9 +13,9 @@ # limitations under the License. """Utilities for loggers.""" -from collections.abc import ItemsView, KeysView, Mapping, ValuesView +from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, SupportsIndex, TypeVar, Union from torch import Tensor from typing_extensions import Self @@ -110,7 +110,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: class _ListMap(list[_T]): """A hybrid container for loggers allowing both index and name access.""" - def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None): + def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None): if isinstance(loggers, Mapping): # super inits list with values if any(not isinstance(x, str) for x in loggers): @@ -145,7 +145,7 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self: # todo return list.__iadd__(self, other) - def __setitem__(self, key: Union[int, slice, str], value: _T) -> None: + def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None: if isinstance(key, (int, slice)): # replace element by index return list.__setitem__(self, key, value) @@ -192,7 +192,7 @@ def items(self) -> ItemsView[str, _T]: return d.items() # --- List and Dict interface --- - def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T: + def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T: if isinstance(key, int): ret = list.pop(self, key) for str_key, idx in list(self._dict.items()): @@ -210,3 +210,15 @@ def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T: def __repr__(self) -> str: ret = super().__repr__() return f"_ListMap({ret}, keys={list(self._dict.keys())})" + + def __reversed__(self) -> Iterable[_T]: + return reversed(list(self)) + + def reverse(self) -> None: + for key, idx in self._dict.items(): + self._dict[key] = len(self) - 1 - idx + list.reverse(self) + + def clear(self): + self._dict.clear() + list.clear(self) diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index 86c2e7b99de02..3173b83714192 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -122,6 +122,30 @@ def test_listmap_remove(): assert 2 not in lm +def test_listmap_reverse(): + """Test reversing the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + lm.reverse() + assert lm == [3, 2, 1] + for (key, value), expected in zip(lm.items(), [("1", 1), ("2", 2), ("3", 3)]): + assert (key, value) == expected + + +def test_listmap_reversed(): + """Test reversed iterator of the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + rev_lm = list(reversed(lm)) + assert rev_lm == [3, 2, 1] + + +def test_listmap_clear(): + """Test clearing the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + lm.clear() + assert len(lm) == 0 + assert len(lm.keys()) == 0 + + # Dict type properties tests def test_listmap_keys(): lm = _ListMap({ From 3e9e398a97b98875b3d7199d1c8193b77b2fee29 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 23:05:28 +0800 Subject: [PATCH 15/37] implement list methods. --- src/lightning/pytorch/loggers/utilities.py | 115 +++++++++++++----- tests/tests_pytorch/loggers/test_utilities.py | 53 ++++++-- 2 files changed, 128 insertions(+), 40 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index c969eff2333cc..2d5b4e590f393 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -15,7 +15,7 @@ from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView from pathlib import Path -from typing import Any, Optional, SupportsIndex, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union from torch import Tensor from typing_extensions import Self @@ -23,6 +23,9 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import Checkpoint +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: @@ -122,28 +125,96 @@ def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None): self._dict: dict = {} def __eq__(self, other: Any) -> bool: - self_list = list(self) + list_eq = list.__eq__(self, other) if isinstance(other, _ListMap): - return self_list == list(other) and self._dict == other._dict - if isinstance(other, list): - return self_list == other - return False + dict_eq = self._dict == other._dict + return list_eq and dict_eq + return list_eq + + def copy(self): + new_listmap = _ListMap(self) + new_listmap._dict = self._dict.copy() + return new_listmap + + def extend(self, __iterable: Iterable[_T]) -> None: + if isinstance(__iterable, _ListMap): + offset = len(self) + for key, idx in __iterable._dict.items(): + self._dict[key] = idx + offset + super().extend(__iterable) + + def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T: + if isinstance(key, int): + ret = list.pop(self, key) + for str_key, idx in list(self._dict.items()): + if idx == key: + self._dict.pop(str_key) + elif idx > key: + self._dict[str_key] = idx - 1 + return ret + if isinstance(key, str): + if key not in self._dict: + return default + return self.pop(self._dict[key]) + raise TypeError("Key must be int or str") + + def insert(self, index: SupportsIndex, __object: _T) -> None: + for key, idx in self._dict.items(): + if idx >= index: + self._dict[key] = idx + 1 + list.insert(self, index, __object) + + def remove(self, __object: _T) -> None: + idx = self.index(__object) + name = None + for key, val in self._dict.items(): + if val == idx: + name = key + elif val > idx: + self._dict[key] = val - 1 + if name: + self._dict.pop(name, None) + list.remove(self, __object) + + def sort( + self, + *, + key: Optional[Callable[[_T], "SupportsRichComparison"]] = None, + reverse: bool = False, + ) -> None: + # Create a mapping from item to its name(s) + item_to_names = {} + for name, idx in self._dict.items(): + item = self[idx] + item_to_names.setdefault(item, []).append(name) + # Sort the list + list.sort(self, key=key, reverse=reverse) + # Update _dict with new indices + new_dict = {} + for idx, item in enumerate(self): + if item in item_to_names: + for name in item_to_names[item]: + new_dict[name] = idx + self._dict = new_dict # --- List-like interface --- def __getitem__(self, key: Union[int, slice, str]) -> _T: - if isinstance(key, (int, slice)): - return list.__getitem__(self, key) if isinstance(key, str): - return list.__getitem__(self, self._dict[key]) - raise TypeError("Key must be int / slice (for index) or str (for name).") + return self[self._dict[key]] + return list.__getitem__(self, key) - def __add__(self, other: Union[list[_T], Self]) -> list[_T]: - # todo - return list.__add__(self, other) + def __add__(self, other: Union[list[_T], Self]) -> Self: + new_listmap = self.copy() + new_listmap += other + return new_listmap def __iadd__(self, other: Union[list[_T], Self]) -> Self: - # todo - return list.__iadd__(self, other) + if isinstance(other, _ListMap): + offset = len(self) + for key, idx in other._dict.items(): + self._dict[key] = idx + offset + + return super().__iadd__(other) def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None: if isinstance(key, (int, slice)): @@ -192,20 +263,6 @@ def items(self) -> ItemsView[str, _T]: return d.items() # --- List and Dict interface --- - def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T: - if isinstance(key, int): - ret = list.pop(self, key) - for str_key, idx in list(self._dict.items()): - if idx == key: - self._dict.pop(str_key) - elif idx > key: - self._dict[str_key] = idx - 1 - return ret - if isinstance(key, str): - if key not in self._dict: - return default - return self.pop(self._dict[key]) - raise TypeError("Key must be int or str") def __repr__(self) -> str: ret = super().__repr__() diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index 3173b83714192..2686500c682af 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -98,20 +98,24 @@ def test_listmap_setitem(): def test_listmap_add(): """Test adding two collections together.""" lm1 = _ListMap([1, 2]) - lm2 = _ListMap([3, 4]) + lm2 = _ListMap({"3": 3, "5": 5}) combined = lm1 + lm2 - assert isinstance(combined, list) + assert isinstance(combined, _ListMap) assert len(combined) == 4 - assert combined[0] == 1 - assert combined[1] == 2 - assert combined[2] == 3 - assert combined[3] == 4 + assert combined is not lm1 + assert combined == [1, 2, 3, 5] + assert combined["3"] == 3 + assert combined["5"] == 5 - combined += lm1 - assert isinstance(combined, list) - assert len(combined) == 6 - for item, expected in zip(combined, [1, 2, 3, 4, 1, 2]): - assert item == expected + ori_lm1_id = id(lm1) + + lm1 += lm2 + assert ori_lm1_id == id(lm1) + assert isinstance(lm1, _ListMap) + assert len(lm1) == 4 + assert lm1 == [1, 2, 3, 5] + assert lm1["3"] == 3 + assert lm1["5"] == 5 def test_listmap_remove(): @@ -190,6 +194,13 @@ def test_listmap_dict_pop(): assert "b" not in lm assert len(lm) == 2 + value = lm.pop(0) + assert value == 1 + assert lm["c"] == 3 # still accessible by key + assert len(lm) == 1 + with pytest.raises(KeyError): + lm["a"] # "a" was removed + def test_listmap_dict_setitem(): lm = _ListMap({ @@ -201,3 +212,23 @@ def test_listmap_dict_setitem(): lm["c"] = 3 assert lm["c"] == 3 assert len(lm) == 3 + + +def test_listmap_sort(): + lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7}) + + lm.extend([-1, -2, 5]) + lm.sort(key=lambda x: abs(x)) + assert lm == [1, -1, 2, -2, 3, 5, -7] + assert lm["a"] == 2 + assert lm["b"] == 1 + assert lm["c"] == 3 + assert lm["z"] == -7 + + lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7}) + lm.sort(reverse=True) + assert lm == [3, 2, 1, -7] + assert lm["a"] == 2 + assert lm["b"] == 1 + assert lm["c"] == 3 + assert lm["z"] == -7 From fa83ab24decf6245a93d6c76b67ac1f4e5293fdb Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 23:15:33 +0800 Subject: [PATCH 16/37] implement get method. --- src/lightning/pytorch/loggers/utilities.py | 5 ++++- tests/tests_pytorch/loggers/test_utilities.py | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 2d5b4e590f393..d87f998f6d107 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -262,7 +262,10 @@ def items(self) -> ItemsView[str, _T]: d = {k: self[v] for k, v in self._dict.items()} return d.items() - # --- List and Dict interface --- + def get(self, __key: str, default: Optional[Any] = None) -> _T: + if __key in self._dict: + return self[self._dict[__key]] + return default def __repr__(self) -> str: ret = super().__repr__() diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index 2686500c682af..fac4ecbf5cc44 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -232,3 +232,10 @@ def test_listmap_sort(): assert lm["b"] == 1 assert lm["c"] == 3 assert lm["z"] == -7 + + +def test_listmap_get(): + lm = _ListMap({"a": 1, "b": 2, "c": 3}) + assert lm.get("b") == 2 + assert lm.get("d") is None + assert lm.get("d", 10) == 10 From 085f167cbc57bef5a5df35188db581c7828ef713 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 22 Oct 2025 23:42:03 +0800 Subject: [PATCH 17/37] adding notes. --- src/lightning/pytorch/loggers/utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index d87f998f6d107..f5b9fea2cd7d7 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -212,6 +212,7 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self: if isinstance(other, _ListMap): offset = len(self) for key, idx in other._dict.items(): + # notes: if there are duplicate keys, the ones from other will overwrite self self._dict[key] = idx + offset return super().__iadd__(other) From 0d8f725f18dacc4b0a4e408d3cbb1479bce5c141 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 00:15:46 +0800 Subject: [PATCH 18/37] refactor --- src/lightning/pytorch/loggers/utilities.py | 8 +++++--- tests/tests_pytorch/loggers/test_utilities.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index f5b9fea2cd7d7..20a69c9ba2856 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -121,14 +121,16 @@ def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None): super().__init__(loggers.values()) self._dict = dict(zip(loggers.keys(), range(len(loggers)))) else: + default_dict = {} + if isinstance(loggers, _ListMap): + default_dict = loggers._dict.copy() super().__init__(() if loggers is None else loggers) - self._dict: dict = {} + self._dict: dict = default_dict def __eq__(self, other: Any) -> bool: list_eq = list.__eq__(self, other) if isinstance(other, _ListMap): - dict_eq = self._dict == other._dict - return list_eq and dict_eq + return list_eq and self._dict == other._dict return list_eq def copy(self): diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index fac4ecbf5cc44..0240161074775 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -42,6 +42,7 @@ def test_version(tmp_path): [1, 2], {1, 2}, range(2), + _ListMap({"a": 1, "b": 2}), ], ) def test_listmap_init(args): @@ -49,6 +50,8 @@ def test_listmap_init(args): lm = _ListMap(args) assert len(lm) == len(args) assert isinstance(lm, list) + if isinstance(args, _ListMap): + assert lm == args def test_listmap_append(): From 0e14e093867a68698f6af98f07f19254e38a928e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 13:55:48 +0800 Subject: [PATCH 19/37] docs --- src/lightning/pytorch/loggers/utilities.py | 47 ++++++++++++++++------ 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 20a69c9ba2856..39e2af50c3c51 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -111,20 +111,46 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: class _ListMap(list[_T]): - """A hybrid container for loggers allowing both index and name access.""" + """A hybrid container allowing both index and name access. - def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None): - if isinstance(loggers, Mapping): + This class extends the built-in list to provide dictionary-like access to its elements + using string keys. It maintains an internal mapping of string keys to list indices, + allowing users to retrieve, set, and delete elements by their associated names. + + Args: + __iterable (Union[Iterable[_T], Mapping[str, _T]], optional): An iterable of objects or a mapping + of string keys to __iterable to initialize the container. + + Raises: + TypeError: If a Mapping is provided and any of its keys are not of type str. + + Example: + >>> listmap = _ListMap({'obj1': obj1, 'obj2': obj2}) + >>> listmap['obj1'] # Access by name + obj1 + >>> listmap[0] # Access by index + obj1 + >>> listmap['obj2'] = obj3 # Set by name + >>> listmap[1] # Now returns obj3 + obj3 + >>> listmap.append(obj4) # Append by index + >>> listmap[2] + obj4 + + """ + + def __init__(self, __iterable: Union[Iterable[_T], Mapping[str, _T]] = None): + if isinstance(__iterable, Mapping): # super inits list with values - if any(not isinstance(x, str) for x in loggers): + if any(not isinstance(x, str) for x in __iterable): raise TypeError("When providing a Mapping, all keys must be of type str.") - super().__init__(loggers.values()) - self._dict = dict(zip(loggers.keys(), range(len(loggers)))) + super().__init__(__iterable.values()) + self._dict = dict(zip(__iterable.keys(), range(len(__iterable)))) else: default_dict = {} - if isinstance(loggers, _ListMap): - default_dict = loggers._dict.copy() - super().__init__(() if loggers is None else loggers) + if isinstance(__iterable, _ListMap): + default_dict = __iterable._dict.copy() + super().__init__(() if __iterable is None else __iterable) self._dict: dict = default_dict def __eq__(self, other: Any) -> bool: @@ -274,9 +300,6 @@ def __repr__(self) -> str: ret = super().__repr__() return f"_ListMap({ret}, keys={list(self._dict.keys())})" - def __reversed__(self) -> Iterable[_T]: - return reversed(list(self)) - def reverse(self) -> None: for key, idx in self._dict.items(): self._dict[key] = len(self) - 1 - idx From 2d9f419c2b6a3eb43c68ca4f6a6f0847fcaeceed Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 13:56:33 +0800 Subject: [PATCH 20/37] test: add additional unittests. --- tests/tests_pytorch/loggers/test_utilities.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index 0240161074775..ed1d4f0eee019 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -242,3 +242,21 @@ def test_listmap_get(): assert lm.get("b") == 2 assert lm.get("d") is None assert lm.get("d", 10) == 10 + + +def test_listmap_setitem_append(): + lm = _ListMap({"a": 1, "b": 2}) + lm.append(3) + lm["c"] = 3 + + assert lm == [1, 2, 3, 3] + assert lm["c"] == 3 + + lm.remove(3) + assert lm == [1, 2, 3] + assert lm["c"] == 3 + + lm.remove(3) + assert lm == [1, 2] + with pytest.raises(KeyError): + lm["c"] # "c" was removed From b78daea55c6af2489b102724455b8c635c5c1639 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 18:10:36 +0800 Subject: [PATCH 21/37] fix: fix delete implementation. --- src/lightning/pytorch/loggers/utilities.py | 21 +++++++++++-------- tests/tests_pytorch/loggers/test_utilities.py | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 39e2af50c3c51..34b717c5ba450 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -139,7 +139,7 @@ class _ListMap(list[_T]): """ - def __init__(self, __iterable: Union[Iterable[_T], Mapping[str, _T]] = None): + def __init__(self, __iterable: Union[Mapping[str, _T], Iterable[_T]] = None): if isinstance(__iterable, Mapping): # super inits list with values if any(not isinstance(x, str) for x in __iterable): @@ -268,15 +268,18 @@ def __contains__(self, item: Union[_T, str]) -> bool: def __delitem__(self, key: Union[int, slice, str]) -> None: if isinstance(key, (int, slice)): - loggers = list.__getitem__(self, key) - super(list, self).__delitem__(key) - for logger in loggers if isinstance(key, slice) else [loggers]: - name = getattr(logger, "name", None) - if name: - self._dict.pop(name, None) + list.__delitem__(self, key) + for _key in key.indices(len(self)) if isinstance(key, slice) else [key]: + # update indices in the dict + for str_key, idx in list(self._dict.items()): + if idx == _key: + self._dict.pop(str_key) + elif idx > _key: + self._dict[str_key] = idx - 1 elif isinstance(key, str): - logger = self._dict.pop(key) - self.remove(logger) + if key not in self._dict: + raise KeyError(f"Key '{key}' not found.") + self.__delitem__(self._dict[key]) else: raise TypeError("Key must be int or str") diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index ed1d4f0eee019..fa9d327662661 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -153,6 +153,27 @@ def test_listmap_clear(): assert len(lm.keys()) == 0 +def test_listmap_delitem(): + """Test deleting items from the collection.""" + lm = _ListMap({"a": 1, "b": 2, "c": 3}) + lm.extend([3, 4, 5]) + del lm["b"] + assert len(lm) == 5 + assert "b" not in lm + del lm[0] + assert len(lm) == 4 + assert "a" not in lm + assert lm == [3, 3, 4, 5] + + del lm[-1] + assert len(lm) == 3 + assert lm == [3, 3, 4] + + del lm[-2:] + assert len(lm) == 1 + assert lm == [3] + + # Dict type properties tests def test_listmap_keys(): lm = _ListMap({ From c371b2033634976c62dc05f71d11b24370fef426 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 18:12:05 +0800 Subject: [PATCH 22/37] docs: fix doctest. --- src/lightning/pytorch/loggers/utilities.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 34b717c5ba450..8219eda9db7ca 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -125,17 +125,17 @@ class _ListMap(list[_T]): TypeError: If a Mapping is provided and any of its keys are not of type str. Example: - >>> listmap = _ListMap({'obj1': obj1, 'obj2': obj2}) + >>> listmap = _ListMap({'obj1': 1, 'obj2': 2}) >>> listmap['obj1'] # Access by name - obj1 + 1 >>> listmap[0] # Access by index - obj1 - >>> listmap['obj2'] = obj3 # Set by name + 1 + >>> listmap['obj2'] = 3 # Set by name >>> listmap[1] # Now returns obj3 - obj3 - >>> listmap.append(obj4) # Append by index + 3 + >>> listmap.append(4) # Append by index >>> listmap[2] - obj4 + 4 """ From 41f431102e6c28cee2b0f8ad6ae6881c6b93dd56 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 19:13:41 +0800 Subject: [PATCH 23/37] add unittest case. --- tests/tests_pytorch/loggers/test_utilities.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index fa9d327662661..3cfef2f685bc5 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -120,6 +120,13 @@ def test_listmap_add(): assert lm1["3"] == 3 assert lm1["5"] == 5 + lm3 = _ListMap({"3": 3, "5": 5}) + lm4 = lm2 + lm3 + assert len(lm4) == 4 + assert lm4 == [3, 5, 3, 5] + assert lm4["3"] == 3 + assert lm4["5"] == 5 + def test_listmap_remove(): """Test removing items from the collection.""" From 172ceb327e9a3ffd2d18f9003467f7a775664c05 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 21:47:24 +0800 Subject: [PATCH 24/37] fix: fix mypy --- src/lightning/pytorch/loggers/utilities.py | 33 +++++++++++++++------- src/lightning/pytorch/trainer/trainer.py | 2 +- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 8219eda9db7ca..52a221a9842fe 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union from torch import Tensor -from typing_extensions import Self +from typing_extensions import Self, overload import lightning.pytorch as pl from lightning.pytorch.callbacks import Checkpoint @@ -108,6 +108,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: _T = TypeVar("_T") +_PT = TypeVar("_PT") class _ListMap(list[_T]): @@ -139,19 +140,20 @@ class _ListMap(list[_T]): """ - def __init__(self, __iterable: Union[Mapping[str, _T], Iterable[_T]] = None): + def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None): if isinstance(__iterable, Mapping): # super inits list with values if any(not isinstance(x, str) for x in __iterable): raise TypeError("When providing a Mapping, all keys must be of type str.") super().__init__(__iterable.values()) - self._dict = dict(zip(__iterable.keys(), range(len(__iterable)))) + _dict = dict(zip(__iterable.keys(), range(len(__iterable)))) else: default_dict = {} if isinstance(__iterable, _ListMap): default_dict = __iterable._dict.copy() super().__init__(() if __iterable is None else __iterable) - self._dict: dict = default_dict + _dict: dict = default_dict + self._dict = _dict def __eq__(self, other: Any) -> bool: list_eq = list.__eq__(self, other) @@ -171,7 +173,7 @@ def extend(self, __iterable: Iterable[_T]) -> None: self._dict[key] = idx + offset super().extend(__iterable) - def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T: + def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[_PT] = None) -> Union[_T, _PT]: if isinstance(key, int): ret = list.pop(self, key) for str_key, idx in list(self._dict.items()): @@ -211,7 +213,7 @@ def sort( reverse: bool = False, ) -> None: # Create a mapping from item to its name(s) - item_to_names = {} + item_to_names: dict[_T, list[int]] = {} for name, idx in self._dict.items(): item = self[idx] item_to_names.setdefault(item, []).append(name) @@ -225,8 +227,13 @@ def sort( new_dict[name] = idx self._dict = new_dict - # --- List-like interface --- - def __getitem__(self, key: Union[int, slice, str]) -> _T: + @overload + def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ... + + @overload + def __getitem__(self, key: slice, /) -> list[_T]: ... + + def __getitem__(self, key, /): if isinstance(key, str): return self[self._dict[key]] return list.__getitem__(self, key) @@ -245,7 +252,13 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self: return super().__iadd__(other) - def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None: + @overload + def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... + + @overload + def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... + + def __setitem__(self, key, value, /) -> None: if isinstance(key, (int, slice)): # replace element by index return list.__setitem__(self, key, value) @@ -259,7 +272,7 @@ def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None: return None raise TypeError("Key must be int or str") - def __contains__(self, item: Union[_T, str]) -> bool: + def __contains__(self, item: Union[object, str]) -> bool: if isinstance(item, str): return item in self._dict return list.__contains__(self, item) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 685bb030490c6..f30086ac5d1f5 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1643,7 +1643,7 @@ def loggers(self) -> _ListMap[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None: + def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger], _ListMap[Logger]]]) -> None: self._loggers = _ListMap(loggers) @property From fffe03b02345b56cb17c850128346b6ce79edff3 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Oct 2025 21:54:25 +0800 Subject: [PATCH 25/37] test --- tests/tests_pytorch/loggers/test_utilities.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index fa9d327662661..6dc98125fa978 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -87,6 +87,8 @@ def test_listmap_getitem(): lm = _ListMap([1, 2]) assert lm[0] == 1 assert lm[1] == 2 + assert lm[-1] == 2 + assert lm[0:2] == [1, 2] def test_listmap_setitem(): From c45fa9b9b6738e189c5204642fe1d5da3a916dca Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 24 Oct 2025 08:32:18 +0800 Subject: [PATCH 26/37] fix mypy --- src/lightning/pytorch/loggers/utilities.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 52a221a9842fe..1a0ba1e0b11cc 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -161,7 +161,7 @@ def __eq__(self, other: Any) -> bool: return list_eq and self._dict == other._dict return list_eq - def copy(self): + def copy(self) -> Self: new_listmap = _ListMap(self) new_listmap._dict = self._dict.copy() return new_listmap @@ -173,7 +173,7 @@ def extend(self, __iterable: Iterable[_T]) -> None: self._dict[key] = idx + offset super().extend(__iterable) - def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[_PT] = None) -> Union[_T, _PT]: + def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[_T, _PT] = None) -> Optional[Union[_T, _PT]]: if isinstance(key, int): ret = list.pop(self, key) for str_key, idx in list(self._dict.items()): @@ -279,7 +279,7 @@ def __contains__(self, item: Union[object, str]) -> bool: # --- Dict-like interface --- - def __delitem__(self, key: Union[int, slice, str]) -> None: + def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: if isinstance(key, (int, slice)): list.__delitem__(self, key) for _key in key.indices(len(self)) if isinstance(key, slice) else [key]: @@ -307,7 +307,13 @@ def items(self) -> ItemsView[str, _T]: d = {k: self[v] for k, v in self._dict.items()} return d.items() - def get(self, __key: str, default: Optional[Any] = None) -> _T: + @overload + def get(self, __key: str) -> Optional[_T]: ... + + @overload + def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ... + + def get(self, __key: str, default=None): if __key in self._dict: return self[self._dict[__key]] return default @@ -321,6 +327,6 @@ def reverse(self) -> None: self._dict[key] = len(self) - 1 - idx list.reverse(self) - def clear(self): + def clear(self) -> None: self._dict.clear() list.clear(self) From ec39fe59fea9ce2c0972f6d680a9b17cb7912791 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 24 Oct 2025 08:40:25 +0800 Subject: [PATCH 27/37] fix mypy --- src/lightning/pytorch/loggers/utilities.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 1a0ba1e0b11cc..273431a696d3e 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -173,7 +173,16 @@ def extend(self, __iterable: Iterable[_T]) -> None: self._dict[key] = idx + offset super().extend(__iterable) - def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[_T, _PT] = None) -> Optional[Union[_T, _PT]]: + @overload + def pop(self, key: SupportsIndex = -1, /) -> _T: ... + + @overload + def pop(self, key: str, /, default: _T) -> _T: ... + + @overload + def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... + + def pop(self, key=-1, default=None): if isinstance(key, int): ret = list.pop(self, key) for str_key, idx in list(self._dict.items()): From a2709c2e28a8b29f1ebba6063e427cbf0882b764 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 24 Oct 2025 10:01:43 +0800 Subject: [PATCH 28/37] ref: refactor __delitem__ --- src/lightning/pytorch/loggers/utilities.py | 27 +++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 273431a696d3e..acf73ba07127e 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -161,7 +161,7 @@ def __eq__(self, other: Any) -> bool: return list_eq and self._dict == other._dict return list_eq - def copy(self) -> Self: + def copy(self) -> "_ListMap": new_listmap = _ListMap(self) new_listmap._dict = self._dict.copy() return new_listmap @@ -184,7 +184,7 @@ def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... def pop(self, key=-1, default=None): if isinstance(key, int): - ret = list.pop(self, key) + ret = super().pop(key) for str_key, idx in list(self._dict.items()): if idx == key: self._dict.pop(str_key) @@ -201,7 +201,7 @@ def insert(self, index: SupportsIndex, __object: _T) -> None: for key, idx in self._dict.items(): if idx >= index: self._dict[key] = idx + 1 - list.insert(self, index, __object) + super().insert(index, __object) def remove(self, __object: _T) -> None: idx = self.index(__object) @@ -213,7 +213,7 @@ def remove(self, __object: _T) -> None: self._dict[key] = val - 1 if name: self._dict.pop(name, None) - list.remove(self, __object) + super().remove(__object) def sort( self, @@ -227,7 +227,7 @@ def sort( item = self[idx] item_to_names.setdefault(item, []).append(name) # Sort the list - list.sort(self, key=key, reverse=reverse) + super().sort(key=key, reverse=reverse) # Update _dict with new indices new_dict = {} for idx, item in enumerate(self): @@ -270,11 +270,11 @@ def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... def __setitem__(self, key, value, /) -> None: if isinstance(key, (int, slice)): # replace element by index - return list.__setitem__(self, key, value) + return super().__setitem__(key, value) if isinstance(key, str): # replace or insert by name if key in self._dict: - list.__setitem__(self, self._dict[key], value) + super().__setitem__(self._dict[key], value) else: self.append(value) self._dict[key] = len(self) - 1 @@ -284,13 +284,18 @@ def __setitem__(self, key, value, /) -> None: def __contains__(self, item: Union[object, str]) -> bool: if isinstance(item, str): return item in self._dict - return list.__contains__(self, item) + return super().__contains__(item) # --- Dict-like interface --- def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: + if isinstance(key, str): + if key not in self._dict: + raise KeyError(f"Key '{key}' not found.") + key: int = self._dict[key] + if isinstance(key, (int, slice)): - list.__delitem__(self, key) + super().__delitem__(key) for _key in key.indices(len(self)) if isinstance(key, slice) else [key]: # update indices in the dict for str_key, idx in list(self._dict.items()): @@ -298,10 +303,6 @@ def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: self._dict.pop(str_key) elif idx > _key: self._dict[str_key] = idx - 1 - elif isinstance(key, str): - if key not in self._dict: - raise KeyError(f"Key '{key}' not found.") - self.__delitem__(self._dict[key]) else: raise TypeError("Key must be int or str") From 9d0d39de0e3d9af3965b8406d1124efbfbf9abc8 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 24 Oct 2025 17:52:40 +0800 Subject: [PATCH 29/37] fix: mypy --- src/lightning/pytorch/loggers/utilities.py | 32 ++++++++++++---------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index acf73ba07127e..f53bd3d4bdc18 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -141,6 +141,7 @@ class _ListMap(list[_T]): """ def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None): + _dict: dict[str, int] if isinstance(__iterable, Mapping): # super inits list with values if any(not isinstance(x, str) for x in __iterable): @@ -177,7 +178,7 @@ def extend(self, __iterable: Iterable[_T]) -> None: def pop(self, key: SupportsIndex = -1, /) -> _T: ... @overload - def pop(self, key: str, /, default: _T) -> _T: ... + def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ... @overload def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... @@ -222,14 +223,14 @@ def sort( reverse: bool = False, ) -> None: # Create a mapping from item to its name(s) - item_to_names: dict[_T, list[int]] = {} + item_to_names: dict[_T, list[str]] = {} for name, idx in self._dict.items(): item = self[idx] item_to_names.setdefault(item, []).append(name) # Sort the list super().sort(key=key, reverse=reverse) # Update _dict with new indices - new_dict = {} + new_dict: dict[str, int] = {} for idx, item in enumerate(self): if item in item_to_names: for name in item_to_names[item]: @@ -242,12 +243,12 @@ def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ... @overload def __getitem__(self, key: slice, /) -> list[_T]: ... - def __getitem__(self, key, /): + def __getitem__(self, key): if isinstance(key, str): return self[self._dict[key]] return list.__getitem__(self, key) - def __add__(self, other: Union[list[_T], Self]) -> Self: + def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": new_listmap = self.copy() new_listmap += other return new_listmap @@ -267,7 +268,7 @@ def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... @overload def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... - def __setitem__(self, key, value, /) -> None: + def __setitem__(self, key, value): if isinstance(key, (int, slice)): # replace element by index return super().__setitem__(key, value) @@ -289,14 +290,17 @@ def __contains__(self, item: Union[object, str]) -> bool: # --- Dict-like interface --- def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: + index: Union[SupportsIndex, slice] if isinstance(key, str): if key not in self._dict: raise KeyError(f"Key '{key}' not found.") - key: int = self._dict[key] + index = self._dict[key] + else: + index = key - if isinstance(key, (int, slice)): - super().__delitem__(key) - for _key in key.indices(len(self)) if isinstance(key, slice) else [key]: + if isinstance(index, (int, slice)): + super().__delitem__(index) + for _key in index.indices(len(self)) if isinstance(index, slice) else [index]: # update indices in the dict for str_key, idx in list(self._dict.items()): if idx == _key: @@ -310,12 +314,10 @@ def keys(self) -> KeysView[str]: return self._dict.keys() def values(self) -> ValuesView[_T]: - d = {k: self[v] for k, v in self._dict.items()} - return d.values() + return {k: self[v] for k, v in self._dict.items()}.values() def items(self) -> ItemsView[str, _T]: - d = {k: self[v] for k, v in self._dict.items()} - return d.items() + return {k: self[v] for k, v in self._dict.items()}.items() @overload def get(self, __key: str) -> Optional[_T]: ... @@ -323,7 +325,7 @@ def get(self, __key: str) -> Optional[_T]: ... @overload def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ... - def get(self, __key: str, default=None): + def get(self, __key, default=None): if __key in self._dict: return self[self._dict[__key]] return default From 5edf4b1772c85f79ee1ec6da1784c04dac1af99a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 18:57:15 +0800 Subject: [PATCH 30/37] fix type annotation --- src/lightning/pytorch/loggers/utilities.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index f53bd3d4bdc18..97c970d7b9e88 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -149,11 +149,11 @@ def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = super().__init__(__iterable.values()) _dict = dict(zip(__iterable.keys(), range(len(__iterable)))) else: - default_dict = {} + default_dict: dict[str, int] = {} if isinstance(__iterable, _ListMap): default_dict = __iterable._dict.copy() super().__init__(() if __iterable is None else __iterable) - _dict: dict = default_dict + _dict = default_dict self._dict = _dict def __eq__(self, other: Any) -> bool: @@ -183,7 +183,7 @@ def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ... @overload def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... - def pop(self, key=-1, default=None): + def pop(self, key: int = -1, default: Optional = None) -> _T: if isinstance(key, int): ret = super().pop(key) for str_key, idx in list(self._dict.items()): From 76b5311d321f5e64089eff7bc876c7471e6026c4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 22:02:23 +0800 Subject: [PATCH 31/37] fix typecheck --- src/lightning/pytorch/loggers/utilities.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 97c970d7b9e88..c043a100c9dee 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -253,7 +253,7 @@ def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": new_listmap += other return new_listmap - def __iadd__(self, other: Union[list[_T], Self]) -> Self: + def __iadd__(self, other: Iterable[_T]) -> Self: if isinstance(other, _ListMap): offset = len(self) for key, idx in other._dict.items(): @@ -268,7 +268,7 @@ def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... @overload def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... - def __setitem__(self, key, value): + def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Union[_T, Iterable[_T]], /) -> None: if isinstance(key, (int, slice)): # replace element by index return super().__setitem__(key, value) @@ -325,7 +325,7 @@ def get(self, __key: str) -> Optional[_T]: ... @overload def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ... - def get(self, __key, default=None): + def get(self, __key: str, default: Optional[_PT] = None) -> Optional[Union[_T, _PT]]: if __key in self._dict: return self[self._dict[__key]] return default @@ -337,8 +337,8 @@ def __repr__(self) -> str: def reverse(self) -> None: for key, idx in self._dict.items(): self._dict[key] = len(self) - 1 - idx - list.reverse(self) + return super().reverse() def clear(self) -> None: self._dict.clear() - list.clear(self) + return super().clear() From 8da3ea48624ab36d3dd9658634d11d7c5742ea13 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 22:27:55 +0800 Subject: [PATCH 32/37] fix typecheck --- src/lightning/pytorch/loggers/utilities.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index c043a100c9dee..c61751fa9772c 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -140,8 +140,9 @@ class _ListMap(list[_T]): """ + _dict: dict[str, int] + def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None): - _dict: dict[str, int] if isinstance(__iterable, Mapping): # super inits list with values if any(not isinstance(x, str) for x in __iterable): @@ -183,7 +184,7 @@ def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ... @overload def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... - def pop(self, key: int = -1, default: Optional = None) -> _T: + def pop(self, key: Union[SupportsIndex, str] = -1, default: Any = None) -> _T: if isinstance(key, int): ret = super().pop(key) for str_key, idx in list(self._dict.items()): @@ -199,10 +200,11 @@ def pop(self, key: int = -1, default: Optional = None) -> _T: raise TypeError("Key must be int or str") def insert(self, index: SupportsIndex, __object: _T) -> None: + idx_int = int(index) for key, idx in self._dict.items(): - if idx >= index: + if idx >= idx_int: self._dict[key] = idx + 1 - super().insert(index, __object) + return super().insert(index, __object) def remove(self, __object: _T) -> None: idx = self.index(__object) @@ -275,7 +277,7 @@ def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Union[_T, It if isinstance(key, str): # replace or insert by name if key in self._dict: - super().__setitem__(self._dict[key], value) + self[self._dict[key]] = value else: self.append(value) self._dict[key] = len(self) - 1 From 9aefddea8c76a2757ff9440181a286e25a6ce9ee Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 1 Nov 2025 00:25:51 +0800 Subject: [PATCH 33/37] fix typecheck --- src/lightning/pytorch/loggers/utilities.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index c61751fa9772c..e5633023a5449 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -245,10 +245,10 @@ def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ... @overload def __getitem__(self, key: slice, /) -> list[_T]: ... - def __getitem__(self, key): + def __getitem__(self, key: Union[SupportsIndex, str, slice], /) -> Union[_T, list[_T]]: if isinstance(key, str): return self[self._dict[key]] - return list.__getitem__(self, key) + return super().__getitem__(key) def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": new_listmap = self.copy() @@ -270,7 +270,7 @@ def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... @overload def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... - def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Union[_T, Iterable[_T]], /) -> None: + def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> None: if isinstance(key, (int, slice)): # replace element by index return super().__setitem__(key, value) From 535345e6728fa606d63e6d0f9eb09de1b6490188 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 1 Nov 2025 00:30:59 +0800 Subject: [PATCH 34/37] ignore override --- src/lightning/pytorch/loggers/utilities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index e5633023a5449..2bc4dbc9052fc 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -250,12 +250,12 @@ def __getitem__(self, key: Union[SupportsIndex, str, slice], /) -> Union[_T, lis return self[self._dict[key]] return super().__getitem__(key) - def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": + def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": # type: ignore[override] new_listmap = self.copy() new_listmap += other return new_listmap - def __iadd__(self, other: Iterable[_T]) -> Self: + def __iadd__(self, other: Iterable[_T]) -> Self: # type: ignore[override] if isinstance(other, _ListMap): offset = len(self) for key, idx in other._dict.items(): From 3211f150ebab36c9861783d32a37a807eb19132e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 5 Nov 2025 01:11:43 +0800 Subject: [PATCH 35/37] refactor __eq__ --- src/lightning/pytorch/loggers/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 2bc4dbc9052fc..06b718820ec41 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -158,7 +158,7 @@ def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = self._dict = _dict def __eq__(self, other: Any) -> bool: - list_eq = list.__eq__(self, other) + list_eq = super().__eq__(other) if isinstance(other, _ListMap): return list_eq and self._dict == other._dict return list_eq From 2dff76564c72109b9cdddd07669ce85f6814f65e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 10 Nov 2025 21:03:58 +0800 Subject: [PATCH 36/37] refactor __setitem__ --- src/lightning/pytorch/loggers/utilities.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 06b718820ec41..49ef98284b2b0 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -271,9 +271,6 @@ def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> None: - if isinstance(key, (int, slice)): - # replace element by index - return super().__setitem__(key, value) if isinstance(key, str): # replace or insert by name if key in self._dict: @@ -282,7 +279,7 @@ def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> N self.append(value) self._dict[key] = len(self) - 1 return None - raise TypeError("Key must be int or str") + return super().__setitem__(key, value) def __contains__(self, item: Union[object, str]) -> bool: if isinstance(item, str): From 483ca4aabd883500694d052cdacc4a9fd9c68722 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 10 Nov 2025 21:35:52 +0800 Subject: [PATCH 37/37] fix insert implementation and add unittests --- src/lightning/pytorch/loggers/utilities.py | 3 ++ tests/tests_pytorch/loggers/test_utilities.py | 40 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 49ef98284b2b0..e09bfd917c177 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -201,6 +201,9 @@ def pop(self, key: Union[SupportsIndex, str] = -1, default: Any = None) -> _T: def insert(self, index: SupportsIndex, __object: _T) -> None: idx_int = int(index) + # Check for negative indices + if idx_int < 0: + idx_int += len(self) for key, idx in self._dict.items(): if idx >= idx_int: self._dict[key] = idx + 1 diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index 55c0648b9d56c..0ce8b085ca1c3 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -54,6 +54,11 @@ def test_listmap_init(args): assert lm == args +def test_listmap_init_wrong_type(): + with pytest.raises(TypeError): + _ListMap({1: 2, 3: 4}) + + def test_listmap_append(): """Test appending loggers to the collection.""" lm = _ListMap() @@ -70,6 +75,34 @@ def test_listmap_extend(): assert len(lm) == 5 assert lm == [1, 2, 1, 2, 3] + lm2 = _ListMap({"a": 1, "b": 2}) + lm.extend(lm2) + assert len(lm) == 7 + assert lm == [1, 2, 1, 2, 3, 1, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + +def test_listmap_insert(): + lm = _ListMap({"a": 1, "b": 2}) + lm.insert(1, 3) + assert len(lm) == 3 + assert lm == [1, 3, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + lm.insert(-1, 5) + assert len(lm) == 4 + assert lm == [1, 3, 5, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + lm.insert(-2, 10) + assert len(lm) == 5 + assert lm == [1, 3, 10, 5, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + def test_listmap_pop(): lm = _ListMap([1, 2, 3, 4]) @@ -290,3 +323,10 @@ def test_listmap_setitem_append(): assert lm == [1, 2] with pytest.raises(KeyError): lm["c"] # "c" was removed + + +def test_listmap_repr(): + lm = _ListMap({"a": 1, "b": 2}) + lm.append(3) + repr_str = repr(lm) + assert repr_str == "_ListMap([1, 2, 3], keys=['a', 'b'])"