Skip to content

Commit 9c1af68

Browse files
Xmaster6yvmoens
authored andcommitted
test TrackioLogger
1 parent a983306 commit 9c1af68

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

test/test_loggers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchrl.record.loggers.csv import CSVLogger
2222
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
2323
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
24+
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
2425
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
2526
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
2627

@@ -455,6 +456,65 @@ def make_env():
455456
env.close()
456457

457458

459+
@pytest.fixture(scope="function")
460+
def trackio_logger():
461+
exp_name = "ramala"
462+
logger = TrackioLogger(project="test", exp_name=exp_name)
463+
yield logger
464+
logger.experiment.finish()
465+
del logger
466+
467+
468+
@pytest.mark.skipif(not _has_trackio, reason="trackio not installed")
469+
class TestTrackioLogger:
470+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
471+
def test_log_scalar(self, steps, trackio_logger):
472+
torch.manual_seed(0)
473+
474+
values = torch.rand(3)
475+
for i in range(3):
476+
scalar_name = "foo"
477+
scalar_value = values[i].item()
478+
trackio_logger.log_scalar(
479+
value=scalar_value,
480+
name=scalar_name,
481+
step=steps[i] if steps else None,
482+
)
483+
484+
def test_log_video(self, trackio_logger):
485+
torch.manual_seed(0)
486+
487+
# creating a sample video (T, C, H, W), where T - number of frames,
488+
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
489+
# the first 64 frames are black and the next 64 are white
490+
video = torch.cat(
491+
(torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255))
492+
)
493+
video = video[None, :]
494+
trackio_logger.log_video(
495+
name="foo",
496+
video=video,
497+
fps=4,
498+
format="mp4",
499+
)
500+
trackio_logger.log_video(
501+
name="foo_16fps",
502+
video=video,
503+
fps=16,
504+
format="mp4",
505+
)
506+
507+
def test_log_hparams(self, trackio_logger, config):
508+
trackio_logger.log_hparams(config)
509+
for key, value in config.items():
510+
assert trackio_logger.experiment.config[key] == value
511+
512+
def test_log_histogram(self, trackio_logger):
513+
with pytest.raises(NotImplementedError):
514+
data = torch.randn(10)
515+
trackio_logger.log_histogram("hist", data, step=0, bins=2)
516+
517+
458518
if __name__ == "__main__":
459519
args, unknown = argparse.ArgumentParser().parse_known_args()
460520
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)