diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 3283867e9bc..5b82885f967 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -29,6 +29,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - ray diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml index 2eac1112692..432eb99020c 100644 --- a/.github/unittest/linux_distributed/scripts/environment.yml +++ b/.github/unittest/linux_distributed/scripts/environment.yml @@ -28,6 +28,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - ray diff --git a/.github/unittest/linux_sota/scripts/environment.yml b/.github/unittest/linux_sota/scripts/environment.yml index 848720a7bbb..a3ad87752f7 100644 --- a/.github/unittest/linux_sota/scripts/environment.yml +++ b/.github/unittest/linux_sota/scripts/environment.yml @@ -25,6 +25,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - vmas diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index c47436d11a8..b086bde3c07 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -402,6 +402,7 @@ Loggers csv.CSVLogger mlflow.MLFlowLogger tensorboard.TensorboardLogger + trackio.TrackioLogger wandb.WandbLogger get_logger generate_exp_name diff --git a/test/test_loggers.py b/test/test_loggers.py index 3ddcd6b5a5e..4991f3619db 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -21,6 +21,7 @@ from torchrl.record.loggers.csv import CSVLogger from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger +from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger from torchrl.record.loggers.wandb import _has_wandb, WandbLogger from torchrl.record.recorder import PixelRenderTransform, VideoRecorder @@ -455,6 +456,78 @@ def make_env(): env.close() +@pytest.fixture() +def trackio_logger(): + exp_name = "ramala" + logger = TrackioLogger(project="test", exp_name=exp_name) + yield logger + logger.experiment.finish() + del logger + + +@pytest.mark.skipif(not _has_trackio, reason="trackio not installed") +class TestTrackioLogger: + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_scalar(self, steps, trackio_logger): + torch.manual_seed(0) + + values = torch.rand(3) + for i in range(3): + scalar_name = "foo" + scalar_value = values[i].item() + trackio_logger.log_scalar( + value=scalar_value, + name=scalar_name, + step=steps[i] if steps else None, + ) + + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_str(self, steps, trackio_logger): + for i in range(3): + trackio_logger.log_str( + name="foo", + value="bar", + step=steps[i] if steps else None, + ) + + def test_log_video(self, trackio_logger): + torch.manual_seed(0) + + # creating a sample video (T, C, H, W), where T - number of frames, + # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. + # the first 64 frames are black and the next 64 are white + video = torch.cat( + (torch.zeros(128, 3, 32, 32), torch.full((128, 3, 32, 32), 255)) + ) + video = video[None, :] + trackio_logger.log_video( + name="foo", + video=video, + fps=4, + format="mp4", + ) + trackio_logger.log_video( + name="foo_16fps", + video=video, + fps=16, + format="mp4", + ) + + def test_log_hparams(self, trackio_logger, config): + trackio_logger.log_hparams(config) + for key, value in config.items(): + assert trackio_logger.experiment.config[key] == value + + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_histogram(self, steps, trackio_logger): + torch.manual_seed(0) + for i in range(3): + data = torch.randn(100) + trackio_logger.log_histogram( + "hist", data, step=steps[i] if steps else None, bins=10 + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/record/loggers/trackio.py b/torchrl/record/loggers/trackio.py new file mode 100644 index 00000000000..67c094d9609 --- /dev/null +++ b/torchrl/record/loggers/trackio.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util + +from collections.abc import Sequence + +import numpy as np + +from torch import Tensor + +from .common import Logger + +_has_trackio = importlib.util.find_spec("trackio") is not None +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None + + +class TrackioLogger(Logger): + """Wrapper for the trackio logger. + + Args: + exp_name (str): The name of the experiment. + project (str): The name of the project. + + Keyword Args: + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. + **kwargs: Extra keyword arguments for ``trackio.init``. + + """ + + @classmethod + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + exp_name: str, + project: str, + *, + video_fps: int = 32, + **kwargs, + ) -> None: + if not _has_trackio: + raise ImportError("trackio could not be imported") + + self.video_fps = video_fps + self._trackio_kwargs = { + "name": exp_name, + "project": project, + "resume": "allow", + **kwargs, + } + + super().__init__(exp_name=exp_name, log_dir=project) + + def _create_experiment(self): + """Creates a trackio experiment. + + Args: + exp_name (str): The name of the experiment. + + Returns: + A trackio.Experiment object. + """ + if not _has_trackio: + raise ImportError("Trackio is not installed") + import trackio + + return trackio.init(**self._trackio_kwargs) + + def log_scalar(self, name: str, value: float, step: int | None = None) -> None: + """Logs a scalar value to trackio. + + Args: + name (str): The name of the scalar. + value (float): The value of the scalar. + step (int, optional): The step at which the scalar is logged. + Defaults to None. + """ + self.experiment.log({name: value}, step=step) + + def log_video(self, name: str, video: Tensor, **kwargs) -> None: + """Log videos inputs to trackio. + + Args: + name (str): The name of the video. + video (Tensor): The video to be logged. + **kwargs: Other keyword arguments. By construction, log_video + supports 'step' (integer indicating the step index), 'format' + (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are + passed as-is to the :obj:`experiment.log` method. + """ + import trackio + + fps = kwargs.pop("fps", self.video_fps) + format = kwargs.pop("format", "mp4") + self.experiment.log( + { + name: trackio.Video( + video.numpy().astype(np.uint8), fps=fps, format=format + ) + }, + **kwargs, + ) + + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 + """Logs the hyperparameters of the experiment. + + Args: + cfg (DictConfig or dict): The configuration of the experiment. + + """ + if type(cfg) is not dict and _has_omegaconf: + if not _has_omegaconf: + raise ImportError( + "OmegaConf could not be imported. " + "Cannot log hydra configs without OmegaConf." + ) + from omegaconf import OmegaConf + + cfg = OmegaConf.to_container(cfg, resolve=True) + self.experiment.config.update(cfg) + + def __repr__(self) -> str: + return f"TrackioLogger(experiment={self.experiment.__repr__()})" + + def log_histogram(self, name: str, data: Sequence, **kwargs): + """Add histogram to log. + + Args: + name (str): Data identifier + data (torch.Tensor, numpy.ndarray): Values to build histogram + + Keyword Args: + step (int): Global step value to record + bins (int): Number of bins to use for the histogram + + """ + import trackio + + num_bins = kwargs.pop("bins", None) + step = kwargs.pop("step", None) + self.experiment.log( + {name: trackio.Histogram(data, num_bins=num_bins)}, step=step + ) + + def log_str(self, name: str, value: str, step: int | None = None) -> None: + """Logs a string value to trackio using a table format for better visualization. + + Args: + name (str): The name of the string data. + value (str): The string value to log. + step (int, optional): The step at which the string is logged. + Defaults to None. + """ + import trackio + + # Create a table with a single row + table = trackio.Table(columns=["text"], data=[[value]]) + self.experiment.log({name: table}, step=step) diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py index 5fe443db301..08d65e3c675 100644 --- a/torchrl/record/loggers/utils.py +++ b/torchrl/record/loggers/utils.py @@ -35,7 +35,7 @@ def get_logger( If empty, ``None`` is returned. logger_name (str): Name to be used as a log_dir experiment_name (str): Name of the experiment - kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs` + kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs` """ if logger_type == "tensorboard": from torchrl.record.loggers.tensorboard import TensorboardLogger @@ -63,6 +63,14 @@ def get_logger( exp_name=experiment_name, **mlflow_kwargs, ) + elif logger_type == "trackio": + from torchrl.record.loggers.trackio import TrackioLogger + + trackio_kwargs = kwargs.get("trackio_kwargs", {}) + project = trackio_kwargs.pop("project", "torchrl") + logger = TrackioLogger( + project=project, exp_name=experiment_name, **trackio_kwargs + ) elif logger_type in ("", None): return None else: