|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import importlib.util |
| 8 | + |
| 9 | +from collections.abc import Sequence |
| 10 | + |
| 11 | +from torch import Tensor |
| 12 | + |
| 13 | +from .common import Logger |
| 14 | + |
| 15 | +_has_trackio = importlib.util.find_spec("trackio") is not None |
| 16 | +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None |
| 17 | + |
| 18 | + |
| 19 | +class TrackioLogger(Logger): |
| 20 | + """Wrapper for the trackio logger. |
| 21 | +
|
| 22 | + Args: |
| 23 | + exp_name (str): The name of the experiment. |
| 24 | + project (str): The name of the project. |
| 25 | +
|
| 26 | + Keyword Args: |
| 27 | + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. |
| 28 | + **kwargs: Extra keyword arguments for ``trackio.init``. |
| 29 | +
|
| 30 | + """ |
| 31 | + |
| 32 | + @classmethod |
| 33 | + def __new__(cls, *args, **kwargs): |
| 34 | + return super().__new__(cls) |
| 35 | + |
| 36 | + def __init__( |
| 37 | + self, |
| 38 | + exp_name: str, |
| 39 | + project: str, |
| 40 | + *, |
| 41 | + video_fps: int = 32, |
| 42 | + **kwargs, |
| 43 | + ) -> None: |
| 44 | + if not _has_trackio: |
| 45 | + raise ImportError("trackio could not be imported") |
| 46 | + |
| 47 | + self.video_fps = video_fps |
| 48 | + self._trackio_kwargs = { |
| 49 | + "name": exp_name, |
| 50 | + "project": project, |
| 51 | + "resume": "allow", |
| 52 | + **kwargs, |
| 53 | + } |
| 54 | + |
| 55 | + super().__init__(exp_name=exp_name, log_dir=project) |
| 56 | + |
| 57 | + def _create_experiment(self): |
| 58 | + """Creates a trackio experiment. |
| 59 | +
|
| 60 | + Args: |
| 61 | + exp_name (str): The name of the experiment. |
| 62 | +
|
| 63 | + Returns: |
| 64 | + A trackio.Experiment object. |
| 65 | + """ |
| 66 | + if not _has_trackio: |
| 67 | + raise ImportError("Trackio is not installed") |
| 68 | + import trackio |
| 69 | + |
| 70 | + return trackio.init(**self._trackio_kwargs) |
| 71 | + |
| 72 | + def log_scalar( |
| 73 | + self, name: str, value: float, step: int | None = None |
| 74 | + ) -> None: |
| 75 | + """Logs a scalar value to trackio. |
| 76 | +
|
| 77 | + Args: |
| 78 | + name (str): The name of the scalar. |
| 79 | + value (float): The value of the scalar. |
| 80 | + step (int, optional): The step at which the scalar is logged. |
| 81 | + Defaults to None. |
| 82 | + """ |
| 83 | + self.experiment.log({name: value}, step=step) |
| 84 | + |
| 85 | + def log_video(self, name: str, video: Tensor, **kwargs) -> None: |
| 86 | + """Log videos inputs to trackio. |
| 87 | +
|
| 88 | + Args: |
| 89 | + name (str): The name of the video. |
| 90 | + video (Tensor): The video to be logged. |
| 91 | + **kwargs: Other keyword arguments. By construction, log_video |
| 92 | + supports 'step' (integer indicating the step index), 'format' |
| 93 | + (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are |
| 94 | + passed as-is to the :obj:`experiment.log` method. |
| 95 | + """ |
| 96 | + import trackio |
| 97 | + |
| 98 | + fps = kwargs.pop("fps", self.video_fps) |
| 99 | + format = kwargs.pop("format", "mp4") |
| 100 | + self.experiment.log( |
| 101 | + {name: trackio.Video(video, fps=fps, format=format)}, |
| 102 | + **kwargs, |
| 103 | + ) |
| 104 | + |
| 105 | + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 |
| 106 | + """Logs the hyperparameters of the experiment. |
| 107 | +
|
| 108 | + Args: |
| 109 | + cfg (DictConfig or dict): The configuration of the experiment. |
| 110 | +
|
| 111 | + """ |
| 112 | + if type(cfg) is not dict and _has_omegaconf: |
| 113 | + if not _has_omegaconf: |
| 114 | + raise ImportError( |
| 115 | + "OmegaConf could not be imported. " |
| 116 | + "Cannot log hydra configs without OmegaConf." |
| 117 | + ) |
| 118 | + from omegaconf import OmegaConf |
| 119 | + |
| 120 | + cfg = OmegaConf.to_container(cfg, resolve=True) |
| 121 | + self.experiment.config.update(cfg) |
| 122 | + |
| 123 | + def __repr__(self) -> str: |
| 124 | + return f"TrackioLogger(experiment={self.experiment.__repr__()})" |
| 125 | + |
| 126 | + def log_histogram(self, name: str, data: Sequence, **kwargs): |
| 127 | + raise NotImplementedError("Logging histograms in trackio is not permitted.") |
| 128 | + |
| 129 | + def log_str(self, name: str, value: str, step: int | None = None) -> None: |
| 130 | + """Logs a string value to trackio using a table format for better visualization. |
| 131 | +
|
| 132 | + Args: |
| 133 | + name (str): The name of the string data. |
| 134 | + value (str): The string value to log. |
| 135 | + step (int, optional): The step at which the string is logged. |
| 136 | + Defaults to None. |
| 137 | + """ |
| 138 | + import trackio |
| 139 | + |
| 140 | + # Create a table with a single row |
| 141 | + table = trackio.Table(columns=["text"], data=[[value]]) |
| 142 | + |
| 143 | + if step is not None: |
| 144 | + self.experiment.log({name: value}, step=step) |
| 145 | + else: |
| 146 | + self.experiment.log({name: table}) |
0 commit comments