diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index ced8a6f1f2bd3..e09bfd917c177 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -13,14 +13,19 @@ # limitations under the License. """Utilities for loggers.""" +from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView from pathlib import Path -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union from torch import Tensor +from typing_extensions import Self, overload import lightning.pytorch as pl from lightning.pytorch.callbacks import Checkpoint +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: @@ -100,3 +105,242 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: logger.log_hyperparams(hparams_initial) logger.log_graph(pl_module) logger.save() + + +_T = TypeVar("_T") +_PT = TypeVar("_PT") + + +class _ListMap(list[_T]): + """A hybrid container allowing both index and name access. + + This class extends the built-in list to provide dictionary-like access to its elements + using string keys. It maintains an internal mapping of string keys to list indices, + allowing users to retrieve, set, and delete elements by their associated names. + + Args: + __iterable (Union[Iterable[_T], Mapping[str, _T]], optional): An iterable of objects or a mapping + of string keys to __iterable to initialize the container. + + Raises: + TypeError: If a Mapping is provided and any of its keys are not of type str. + + Example: + >>> listmap = _ListMap({'obj1': 1, 'obj2': 2}) + >>> listmap['obj1'] # Access by name + 1 + >>> listmap[0] # Access by index + 1 + >>> listmap['obj2'] = 3 # Set by name + >>> listmap[1] # Now returns obj3 + 3 + >>> listmap.append(4) # Append by index + >>> listmap[2] + 4 + + """ + + _dict: dict[str, int] + + def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None): + if isinstance(__iterable, Mapping): + # super inits list with values + if any(not isinstance(x, str) for x in __iterable): + raise TypeError("When providing a Mapping, all keys must be of type str.") + super().__init__(__iterable.values()) + _dict = dict(zip(__iterable.keys(), range(len(__iterable)))) + else: + default_dict: dict[str, int] = {} + if isinstance(__iterable, _ListMap): + default_dict = __iterable._dict.copy() + super().__init__(() if __iterable is None else __iterable) + _dict = default_dict + self._dict = _dict + + def __eq__(self, other: Any) -> bool: + list_eq = super().__eq__(other) + if isinstance(other, _ListMap): + return list_eq and self._dict == other._dict + return list_eq + + def copy(self) -> "_ListMap": + new_listmap = _ListMap(self) + new_listmap._dict = self._dict.copy() + return new_listmap + + def extend(self, __iterable: Iterable[_T]) -> None: + if isinstance(__iterable, _ListMap): + offset = len(self) + for key, idx in __iterable._dict.items(): + self._dict[key] = idx + offset + super().extend(__iterable) + + @overload + def pop(self, key: SupportsIndex = -1, /) -> _T: ... + + @overload + def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ... + + @overload + def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... + + def pop(self, key: Union[SupportsIndex, str] = -1, default: Any = None) -> _T: + if isinstance(key, int): + ret = super().pop(key) + for str_key, idx in list(self._dict.items()): + if idx == key: + self._dict.pop(str_key) + elif idx > key: + self._dict[str_key] = idx - 1 + return ret + if isinstance(key, str): + if key not in self._dict: + return default + return self.pop(self._dict[key]) + raise TypeError("Key must be int or str") + + def insert(self, index: SupportsIndex, __object: _T) -> None: + idx_int = int(index) + # Check for negative indices + if idx_int < 0: + idx_int += len(self) + for key, idx in self._dict.items(): + if idx >= idx_int: + self._dict[key] = idx + 1 + return super().insert(index, __object) + + def remove(self, __object: _T) -> None: + idx = self.index(__object) + name = None + for key, val in self._dict.items(): + if val == idx: + name = key + elif val > idx: + self._dict[key] = val - 1 + if name: + self._dict.pop(name, None) + super().remove(__object) + + def sort( + self, + *, + key: Optional[Callable[[_T], "SupportsRichComparison"]] = None, + reverse: bool = False, + ) -> None: + # Create a mapping from item to its name(s) + item_to_names: dict[_T, list[str]] = {} + for name, idx in self._dict.items(): + item = self[idx] + item_to_names.setdefault(item, []).append(name) + # Sort the list + super().sort(key=key, reverse=reverse) + # Update _dict with new indices + new_dict: dict[str, int] = {} + for idx, item in enumerate(self): + if item in item_to_names: + for name in item_to_names[item]: + new_dict[name] = idx + self._dict = new_dict + + @overload + def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ... + + @overload + def __getitem__(self, key: slice, /) -> list[_T]: ... + + def __getitem__(self, key: Union[SupportsIndex, str, slice], /) -> Union[_T, list[_T]]: + if isinstance(key, str): + return self[self._dict[key]] + return super().__getitem__(key) + + def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": # type: ignore[override] + new_listmap = self.copy() + new_listmap += other + return new_listmap + + def __iadd__(self, other: Iterable[_T]) -> Self: # type: ignore[override] + if isinstance(other, _ListMap): + offset = len(self) + for key, idx in other._dict.items(): + # notes: if there are duplicate keys, the ones from other will overwrite self + self._dict[key] = idx + offset + + return super().__iadd__(other) + + @overload + def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... + + @overload + def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... + + def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> None: + if isinstance(key, str): + # replace or insert by name + if key in self._dict: + self[self._dict[key]] = value + else: + self.append(value) + self._dict[key] = len(self) - 1 + return None + return super().__setitem__(key, value) + + def __contains__(self, item: Union[object, str]) -> bool: + if isinstance(item, str): + return item in self._dict + return super().__contains__(item) + + # --- Dict-like interface --- + + def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: + index: Union[SupportsIndex, slice] + if isinstance(key, str): + if key not in self._dict: + raise KeyError(f"Key '{key}' not found.") + index = self._dict[key] + else: + index = key + + if isinstance(index, (int, slice)): + super().__delitem__(index) + for _key in index.indices(len(self)) if isinstance(index, slice) else [index]: + # update indices in the dict + for str_key, idx in list(self._dict.items()): + if idx == _key: + self._dict.pop(str_key) + elif idx > _key: + self._dict[str_key] = idx - 1 + else: + raise TypeError("Key must be int or str") + + def keys(self) -> KeysView[str]: + return self._dict.keys() + + def values(self) -> ValuesView[_T]: + return {k: self[v] for k, v in self._dict.items()}.values() + + def items(self) -> ItemsView[str, _T]: + return {k: self[v] for k, v in self._dict.items()}.items() + + @overload + def get(self, __key: str) -> Optional[_T]: ... + + @overload + def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ... + + def get(self, __key: str, default: Optional[_PT] = None) -> Optional[Union[_T, _PT]]: + if __key in self._dict: + return self[self._dict[__key]] + return default + + def __repr__(self) -> str: + ret = super().__repr__() + return f"_ListMap({ret}, keys={list(self._dict.keys())})" + + def reverse(self) -> None: + for key, idx in self._dict.items(): + self._dict[key] = len(self) - 1 - idx + return super().reverse() + + def clear(self) -> None: + self._dict.clear() + return super().clear() diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index f6e8885ee050a..f49121c7ac897 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection @@ -82,6 +82,8 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non ) logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment] self.trainer.loggers = [logger_] + elif isinstance(logger, Mapping): + self.trainer.loggers = logger elif isinstance(logger, Iterable): self.trainer.loggers = list(logger) else: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f2f59e396ab23..3dd23652aebe1 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -23,7 +23,7 @@ import logging import math import os -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, Mapping from contextlib import contextmanager from datetime import timedelta from typing import Any, Optional, Union @@ -43,7 +43,7 @@ from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.csv_logs import CSVLogger from lightning.pytorch.loggers.tensorboard import TensorBoardLogger -from lightning.pytorch.loggers.utilities import _log_hyperparams +from lightning.pytorch.loggers.utilities import _ListMap, _log_hyperparams from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop from lightning.pytorch.loops.fit_loop import _FitLoop @@ -494,7 +494,7 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: list[Logger] + self._loggers: _ListMap[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags @@ -1680,7 +1680,7 @@ def logger(self, logger: Optional[Logger]) -> None: self.loggers = [logger] @property - def loggers(self) -> list[Logger]: + def loggers(self) -> _ListMap[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1692,8 +1692,8 @@ def loggers(self) -> list[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[list[Logger]]) -> None: - self._loggers = loggers if loggers else [] + def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger], _ListMap[Logger]]]) -> None: + self._loggers = _ListMap(loggers) @property def callback_metrics(self) -> _OUT_DICT: diff --git a/tests/tests_pytorch/loggers/test_utilities.py b/tests/tests_pytorch/loggers/test_utilities.py index d83a95dda9535..0ce8b085ca1c3 100644 --- a/tests/tests_pytorch/loggers/test_utilities.py +++ b/tests/tests_pytorch/loggers/test_utilities.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from lightning.pytorch.loggers import CSVLogger -from lightning.pytorch.loggers.utilities import _version +from lightning.pytorch.loggers.utilities import _ListMap, _version def test_version(tmp_path): @@ -31,3 +33,300 @@ def test_version(tmp_path): assert version == "0_2_1" version = _version(loggers, "-") assert version == "0-2-1" + + +@pytest.mark.parametrize( + "args", + [ + (1, 2), + [1, 2], + {1, 2}, + range(2), + _ListMap({"a": 1, "b": 2}), + ], +) +def test_listmap_init(args): + """Test initialization with different iterable types.""" + lm = _ListMap(args) + assert len(lm) == len(args) + assert isinstance(lm, list) + if isinstance(args, _ListMap): + assert lm == args + + +def test_listmap_init_wrong_type(): + with pytest.raises(TypeError): + _ListMap({1: 2, 3: 4}) + + +def test_listmap_append(): + """Test appending loggers to the collection.""" + lm = _ListMap() + lm.append(1) + assert len(lm) == 1 + lm.append(2) + assert len(lm) == 2 + + +def test_listmap_extend(): + # extent + lm = _ListMap([1, 2]) + lm.extend([1, 2, 3]) + assert len(lm) == 5 + assert lm == [1, 2, 1, 2, 3] + + lm2 = _ListMap({"a": 1, "b": 2}) + lm.extend(lm2) + assert len(lm) == 7 + assert lm == [1, 2, 1, 2, 3, 1, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + +def test_listmap_insert(): + lm = _ListMap({"a": 1, "b": 2}) + lm.insert(1, 3) + assert len(lm) == 3 + assert lm == [1, 3, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + lm.insert(-1, 5) + assert len(lm) == 4 + assert lm == [1, 3, 5, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + lm.insert(-2, 10) + assert len(lm) == 5 + assert lm == [1, 3, 10, 5, 2] + assert lm["a"] == 1 + assert lm["b"] == 2 + + +def test_listmap_pop(): + lm = _ListMap([1, 2, 3, 4]) + item = lm.pop() + assert item == 4 + assert len(lm) == 3 + item = lm.pop(1) + assert item == 2 + assert len(lm) == 2 + assert lm == [1, 3] + + +def test_listmap_getitem(): + """Test getting items from the collection.""" + lm = _ListMap([1, 2]) + assert lm[0] == 1 + assert lm[1] == 2 + assert lm[-1] == 2 + assert lm[0:2] == [1, 2] + + +def test_listmap_setitem(): + """Test setting items in the collection.""" + lm = _ListMap([1, 2, 3]) + lm[0] = 10 + assert lm == [10, 2, 3] + lm[1:3] = [20, 30] + assert lm == [10, 20, 30] + + +def test_listmap_add(): + """Test adding two collections together.""" + lm1 = _ListMap([1, 2]) + lm2 = _ListMap({"3": 3, "5": 5}) + combined = lm1 + lm2 + assert isinstance(combined, _ListMap) + assert len(combined) == 4 + assert combined is not lm1 + assert combined == [1, 2, 3, 5] + assert combined["3"] == 3 + assert combined["5"] == 5 + + ori_lm1_id = id(lm1) + + lm1 += lm2 + assert ori_lm1_id == id(lm1) + assert isinstance(lm1, _ListMap) + assert len(lm1) == 4 + assert lm1 == [1, 2, 3, 5] + assert lm1["3"] == 3 + assert lm1["5"] == 5 + + lm3 = _ListMap({"3": 3, "5": 5}) + lm4 = lm2 + lm3 + assert len(lm4) == 4 + assert lm4 == [3, 5, 3, 5] + assert lm4["3"] == 3 + assert lm4["5"] == 5 + + +def test_listmap_remove(): + """Test removing items from the collection.""" + lm = _ListMap([1, 2, 3]) + lm.remove(2) + assert len(lm) == 2 + assert 2 not in lm + + +def test_listmap_reverse(): + """Test reversing the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + lm.reverse() + assert lm == [3, 2, 1] + for (key, value), expected in zip(lm.items(), [("1", 1), ("2", 2), ("3", 3)]): + assert (key, value) == expected + + +def test_listmap_reversed(): + """Test reversed iterator of the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + rev_lm = list(reversed(lm)) + assert rev_lm == [3, 2, 1] + + +def test_listmap_clear(): + """Test clearing the collection.""" + lm = _ListMap({"1": 1, "2": 2, "3": 3}) + lm.clear() + assert len(lm) == 0 + assert len(lm.keys()) == 0 + + +def test_listmap_delitem(): + """Test deleting items from the collection.""" + lm = _ListMap({"a": 1, "b": 2, "c": 3}) + lm.extend([3, 4, 5]) + del lm["b"] + assert len(lm) == 5 + assert "b" not in lm + del lm[0] + assert len(lm) == 4 + assert "a" not in lm + assert lm == [3, 3, 4, 5] + + del lm[-1] + assert len(lm) == 3 + assert lm == [3, 3, 4] + + del lm[-2:] + assert len(lm) == 1 + assert lm == [3] + + +# Dict type properties tests +def test_listmap_keys(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + keys = lm.keys() + assert set(keys) == {"a", "b", "c"} + assert "a" in lm + assert "d" not in lm + + +def test_listmap_values(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + values = lm.values() + assert set(values) == {1, 2, 3} + + +def test_listmap_dict_items(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + items = lm.items() + assert set(items) == {("a", 1), ("b", 2), ("c", 3)} + + +def test_listmap_dict_pop(): + lm = _ListMap({ + "a": 1, + "b": 2, + "c": 3, + }) + value = lm.pop("b") + assert value == 2 + assert "b" not in lm + assert len(lm) == 2 + + value = lm.pop(0) + assert value == 1 + assert lm["c"] == 3 # still accessible by key + assert len(lm) == 1 + with pytest.raises(KeyError): + lm["a"] # "a" was removed + + +def test_listmap_dict_setitem(): + lm = _ListMap({ + "a": 1, + "b": 2, + }) + lm["b"] = 20 + assert lm["b"] == 20 + lm["c"] = 3 + assert lm["c"] == 3 + assert len(lm) == 3 + + +def test_listmap_sort(): + lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7}) + + lm.extend([-1, -2, 5]) + lm.sort(key=lambda x: abs(x)) + assert lm == [1, -1, 2, -2, 3, 5, -7] + assert lm["a"] == 2 + assert lm["b"] == 1 + assert lm["c"] == 3 + assert lm["z"] == -7 + + lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7}) + lm.sort(reverse=True) + assert lm == [3, 2, 1, -7] + assert lm["a"] == 2 + assert lm["b"] == 1 + assert lm["c"] == 3 + assert lm["z"] == -7 + + +def test_listmap_get(): + lm = _ListMap({"a": 1, "b": 2, "c": 3}) + assert lm.get("b") == 2 + assert lm.get("d") is None + assert lm.get("d", 10) == 10 + + +def test_listmap_setitem_append(): + lm = _ListMap({"a": 1, "b": 2}) + lm.append(3) + lm["c"] = 3 + + assert lm == [1, 2, 3, 3] + assert lm["c"] == 3 + + lm.remove(3) + assert lm == [1, 2, 3] + assert lm["c"] == 3 + + lm.remove(3) + assert lm == [1, 2] + with pytest.raises(KeyError): + lm["c"] # "c" was removed + + +def test_listmap_repr(): + lm = _ListMap({"a": 1, "b": 2}) + lm.append(3) + repr_str = repr(lm) + assert repr_str == "_ListMap([1, 2, 3], keys=['a', 'b'])" diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 1d07f3e99d412..a374ab350d9a2 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -23,6 +23,7 @@ def test_trainer_loggers_property(): """Test for correct initialization of loggers in Trainer.""" logger1 = CustomLogger() logger2 = CustomLogger() + CustomLogger() # trainer.loggers should be a copy of the input list trainer = Trainer(logger=[logger1, logger2]) @@ -35,17 +36,27 @@ def test_trainer_loggers_property(): assert trainer.logger == logger1 assert trainer.loggers == [logger1] + trainer.loggers.append(logger2) + assert trainer.loggers == [logger1, logger2] + # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) assert trainer.loggers == [trainer.logger] assert isinstance(trainer.logger, TensorBoardLogger) + trainer = Trainer(logger={"log1": logger1, "log2": logger2}) + assert trainer.logger == logger1 + assert trainer.loggers == [logger1, logger2] + assert trainer.loggers["log1"] is logger1 + assert trainer.loggers["log2"] is logger2 + def test_trainer_loggers_setters(): """Test the behavior of setters for trainer.logger and trainer.loggers.""" logger1 = CustomLogger() logger2 = CustomLogger() + CustomLogger() trainer = Trainer() assert type(trainer.logger) is TensorBoardLogger @@ -59,10 +70,12 @@ def test_trainer_loggers_setters(): trainer.logger = None assert trainer.logger is None assert trainer.loggers == [] + assert isinstance(trainer.loggers, list) # Test setters for trainer.loggers trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] + assert isinstance(trainer.loggers, list) trainer.loggers = [logger1] assert trainer.loggers == [logger1] @@ -71,10 +84,24 @@ def test_trainer_loggers_setters(): trainer.loggers = [] assert trainer.loggers == [] assert trainer.logger is None + assert isinstance(trainer.loggers, list) trainer.loggers = None assert trainer.loggers == [] assert trainer.logger is None + assert isinstance(trainer.loggers, list) + + trainer.loggers = {} + assert trainer.loggers == [] + assert trainer.logger is None + assert isinstance(trainer.loggers, list) + + trainer.loggers = {"log1": logger1, "log2": logger2} + assert trainer.loggers == [logger1, logger2] + assert isinstance(trainer.loggers, list) + + assert trainer.loggers["log1"] is logger1 + assert trainer.loggers["log2"] is logger2 @pytest.mark.parametrize( @@ -82,6 +109,7 @@ def test_trainer_loggers_setters(): [ False, [], + {}, ], ) def test_no_logger(tmp_path, logger_value):