|
21 | 21 | from torchrl.record.loggers.csv import CSVLogger |
22 | 22 | from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger |
23 | 23 | from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger |
| 24 | +from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger |
24 | 25 | from torchrl.record.loggers.wandb import _has_wandb, WandbLogger |
25 | 26 | from torchrl.record.recorder import PixelRenderTransform, VideoRecorder |
26 | 27 |
|
@@ -455,6 +456,65 @@ def make_env(): |
455 | 456 | env.close() |
456 | 457 |
|
457 | 458 |
|
| 459 | +@pytest.fixture(scope="function") |
| 460 | +def trackio_logger(): |
| 461 | + exp_name = "ramala" |
| 462 | + logger = TrackioLogger(project="test", exp_name=exp_name) |
| 463 | + yield logger |
| 464 | + logger.experiment.finish() |
| 465 | + del logger |
| 466 | + |
| 467 | + |
| 468 | +@pytest.mark.skipif(not _has_trackio, reason="trackio not installed") |
| 469 | +class TestTrackioLogger: |
| 470 | + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) |
| 471 | + def test_log_scalar(self, steps, trackio_logger): |
| 472 | + torch.manual_seed(0) |
| 473 | + |
| 474 | + values = torch.rand(3) |
| 475 | + for i in range(3): |
| 476 | + scalar_name = "foo" |
| 477 | + scalar_value = values[i].item() |
| 478 | + trackio_logger.log_scalar( |
| 479 | + value=scalar_value, |
| 480 | + name=scalar_name, |
| 481 | + step=steps[i] if steps else None, |
| 482 | + ) |
| 483 | + |
| 484 | + def test_log_video(self, trackio_logger): |
| 485 | + torch.manual_seed(0) |
| 486 | + |
| 487 | + # creating a sample video (T, C, H, W), where T - number of frames, |
| 488 | + # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. |
| 489 | + # the first 64 frames are black and the next 64 are white |
| 490 | + video = torch.cat( |
| 491 | + (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255)) |
| 492 | + ) |
| 493 | + video = video[None, :] |
| 494 | + trackio_logger.log_video( |
| 495 | + name="foo", |
| 496 | + video=video, |
| 497 | + fps=4, |
| 498 | + format="mp4", |
| 499 | + ) |
| 500 | + trackio_logger.log_video( |
| 501 | + name="foo_16fps", |
| 502 | + video=video, |
| 503 | + fps=16, |
| 504 | + format="mp4", |
| 505 | + ) |
| 506 | + |
| 507 | + def test_log_hparams(self, trackio_logger, config): |
| 508 | + trackio_logger.log_hparams(config) |
| 509 | + for key, value in config.items(): |
| 510 | + assert trackio_logger.experiment.config[key] == value |
| 511 | + |
| 512 | + def test_log_histogram(self, trackio_logger): |
| 513 | + with pytest.raises(NotImplementedError): |
| 514 | + data = torch.randn(10) |
| 515 | + trackio_logger.log_histogram("hist", data, step=0, bins=2) |
| 516 | + |
| 517 | + |
458 | 518 | if __name__ == "__main__": |
459 | 519 | args, unknown = argparse.ArgumentParser().parse_known_args() |
460 | 520 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) |
0 commit comments