From 6ff024e0dd3f1b24f9a202d0b6c5dcb104865291 Mon Sep 17 00:00:00 2001 From: Iruos8805 Date: Tue, 4 Nov 2025 23:48:01 +0530 Subject: [PATCH 1/2] feat: allow configurable join_char for WandbLogger metric formatting --- src/lightning/pytorch/loggers/wandb.py | 5 ++ .../loggers/test_wandb_logger_join_char.py | 51 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 tests/tests_pytorch/loggers/test_wandb_logger_join_char.py diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 37ca362fa40c1..01e67ff5b7b05 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -279,6 +279,7 @@ def any_lightning_module_function_or_hook(self): experiment: WandB experiment object. Automatically set when creating a run. checkpoint_name: Name of the model checkpoint artifact being logged. add_file_policy: If "mutable", copies file to tempdirectory before upload. + oin_char: Separator character used to format metric keys before logging. \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. Raises: @@ -306,6 +307,7 @@ def __init__( prefix: str = "", checkpoint_name: Optional[str] = None, add_file_policy: Literal["mutable", "immutable"] = "mutable", + join_char: Optional[str] = None, **kwargs: Any, ) -> None: if not _WANDB_AVAILABLE: @@ -351,6 +353,7 @@ def __init__( self._name = self._wandb_init.get("name") self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name + self.LOGGER_JOIN_CHAR = join_char or self.LOGGER_JOIN_CHAR def __getstate__(self) -> dict[str, Any]: import wandb @@ -440,6 +443,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + join_char = getattr(self, "LOGGER_JOIN_CHAR", ".") + metrics = {k.replace("/", join_char): v for k, v in metrics.items()} if step is not None and not self._wandb_init.get("sync_tensorboard"): self.experiment.log(dict(metrics, **{"trainer/global_step": step})) else: diff --git a/tests/tests_pytorch/loggers/test_wandb_logger_join_char.py b/tests/tests_pytorch/loggers/test_wandb_logger_join_char.py new file mode 100644 index 0000000000000..60faeecaceda3 --- /dev/null +++ b/tests/tests_pytorch/loggers/test_wandb_logger_join_char.py @@ -0,0 +1,51 @@ +from lightning.pytorch.loggers import WandbLogger + + +def test_wandb_logger_custom_join_char(wandb_mock): + """Verify that WandbLogger correctly formats metric keys and logs them according to the specified join_char + value.""" + + # Case 1: Default join_char -> metrics should be logged with hyphens + wandb_mock.run = None + logger = WandbLogger() # default join_char = "-" + logger.log_metrics({"train/loss": 0.5}, step=1) + + print("\nACTUAL:", wandb_mock.init().log.call_args) + + # Expected output: metrics logged using hyphenated keys + wandb_mock.init().log.assert_called_with({"train-loss": 0.5, "trainer/global_step": 1}) + + # Case 2: Custom join_char="/" -> metrics should keep their original keys + wandb_mock.init().log.reset_mock() + logger = WandbLogger(join_char="/") + logger.log_metrics({"train/loss": 0.7}, step=2) + + # Expected output: metrics logged with unmodified key names + wandb_mock.init().log.assert_called_with({"train/loss": 0.7, "trainer/global_step": 2}) + + # Case 3: Custom join_char="_" -> metrics should use underscores + wandb_mock.init().log.reset_mock() + logger = WandbLogger(join_char="_") + logger.log_metrics({"train/loss": 0.9}, step=3) + + # Expected output: metrics logged with underscores in key names + wandb_mock.init().log.assert_called_with({"train_loss": 0.9, "trainer/global_step": 3}) + + +def test_join_char_persists_across_runs(wandb_mock): + """Ensure that a user-defined join_char is consistently applied across multiple logs.""" + wandb_mock.run = None + logger = WandbLogger(join_char="-") + + # Log multiple metric sets with the same logger instance + logger.log_metrics({"val/loss": 0.1}, step=1) + logger.log_metrics({"val/acc": 0.95}, step=2) + + # Retrieve the logged calls for inspection + calls = wandb_mock.init().log.call_args_list + logged_1 = calls[0].args[0] if calls[0].args else calls[0].kwargs + logged_2 = calls[1].args[0] if calls[1].args else calls[1].kwargs + + # Expected output: both logs use the same join_char formatting (hyphens) + assert "val-loss" in logged_1 + assert "val-acc" in logged_2 From cb90319ac66908944ba1f7619eabacc2ece879a5 Mon Sep 17 00:00:00 2001 From: Iruos8805 Date: Tue, 4 Nov 2025 23:56:22 +0530 Subject: [PATCH 2/2] fix: correct minor typo in wandb.py --- src/lightning/pytorch/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 01e67ff5b7b05..bdcb9ae8d0c8d 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -279,7 +279,7 @@ def any_lightning_module_function_or_hook(self): experiment: WandB experiment object. Automatically set when creating a run. checkpoint_name: Name of the model checkpoint artifact being logged. add_file_policy: If "mutable", copies file to tempdirectory before upload. - oin_char: Separator character used to format metric keys before logging. + join_char: Separator character used to format metric keys before logging. \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. Raises: