Skip to content

Commit a983306

Browse files
Xmaster6yvmoens
authored andcommitted
TrackioLogger
1 parent a7707ca commit a983306

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

torchrl/record/loggers/trackio.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import importlib.util
8+
9+
from collections.abc import Sequence
10+
11+
from torch import Tensor
12+
13+
from .common import Logger
14+
15+
_has_trackio = importlib.util.find_spec("trackio") is not None
16+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
17+
18+
19+
class TrackioLogger(Logger):
20+
"""Wrapper for the trackio logger.
21+
22+
Args:
23+
exp_name (str): The name of the experiment.
24+
project (str): The name of the project.
25+
26+
Keyword Args:
27+
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
28+
**kwargs: Extra keyword arguments for ``trackio.init``.
29+
30+
"""
31+
32+
@classmethod
33+
def __new__(cls, *args, **kwargs):
34+
return super().__new__(cls)
35+
36+
def __init__(
37+
self,
38+
exp_name: str,
39+
project: str,
40+
*,
41+
video_fps: int = 32,
42+
**kwargs,
43+
) -> None:
44+
if not _has_trackio:
45+
raise ImportError("trackio could not be imported")
46+
47+
self.video_fps = video_fps
48+
self._trackio_kwargs = {
49+
"name": exp_name,
50+
"project": project,
51+
"resume": "allow",
52+
**kwargs,
53+
}
54+
55+
super().__init__(exp_name=exp_name, log_dir=project)
56+
57+
def _create_experiment(self):
58+
"""Creates a trackio experiment.
59+
60+
Args:
61+
exp_name (str): The name of the experiment.
62+
63+
Returns:
64+
A trackio.Experiment object.
65+
"""
66+
if not _has_trackio:
67+
raise ImportError("Trackio is not installed")
68+
import trackio
69+
70+
return trackio.init(**self._trackio_kwargs)
71+
72+
def log_scalar(
73+
self, name: str, value: float, step: int | None = None
74+
) -> None:
75+
"""Logs a scalar value to trackio.
76+
77+
Args:
78+
name (str): The name of the scalar.
79+
value (float): The value of the scalar.
80+
step (int, optional): The step at which the scalar is logged.
81+
Defaults to None.
82+
"""
83+
self.experiment.log({name: value}, step=step)
84+
85+
def log_video(self, name: str, video: Tensor, **kwargs) -> None:
86+
"""Log videos inputs to trackio.
87+
88+
Args:
89+
name (str): The name of the video.
90+
video (Tensor): The video to be logged.
91+
**kwargs: Other keyword arguments. By construction, log_video
92+
supports 'step' (integer indicating the step index), 'format'
93+
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
94+
passed as-is to the :obj:`experiment.log` method.
95+
"""
96+
import trackio
97+
98+
fps = kwargs.pop("fps", self.video_fps)
99+
format = kwargs.pop("format", "mp4")
100+
self.experiment.log(
101+
{name: trackio.Video(video, fps=fps, format=format)},
102+
**kwargs,
103+
)
104+
105+
def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
106+
"""Logs the hyperparameters of the experiment.
107+
108+
Args:
109+
cfg (DictConfig or dict): The configuration of the experiment.
110+
111+
"""
112+
if type(cfg) is not dict and _has_omegaconf:
113+
if not _has_omegaconf:
114+
raise ImportError(
115+
"OmegaConf could not be imported. "
116+
"Cannot log hydra configs without OmegaConf."
117+
)
118+
from omegaconf import OmegaConf
119+
120+
cfg = OmegaConf.to_container(cfg, resolve=True)
121+
self.experiment.config.update(cfg)
122+
123+
def __repr__(self) -> str:
124+
return f"TrackioLogger(experiment={self.experiment.__repr__()})"
125+
126+
def log_histogram(self, name: str, data: Sequence, **kwargs):
127+
raise NotImplementedError("Logging histograms in trackio is not permitted.")
128+
129+
def log_str(self, name: str, value: str, step: int | None = None) -> None:
130+
"""Logs a string value to trackio using a table format for better visualization.
131+
132+
Args:
133+
name (str): The name of the string data.
134+
value (str): The string value to log.
135+
step (int, optional): The step at which the string is logged.
136+
Defaults to None.
137+
"""
138+
import trackio
139+
140+
# Create a table with a single row
141+
table = trackio.Table(columns=["text"], data=[[value]])
142+
143+
if step is not None:
144+
self.experiment.log({name: value}, step=step)
145+
else:
146+
self.experiment.log({name: table})

0 commit comments

Comments
 (0)