Skip to content

Commit e9e0ab1

Browse files
author
sfluegel
committed
Merge remote-tracking branch 'upstream/dev' into dev
2 parents f5e5289 + 08582c6 commit e9e0ab1

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

chebai/loggers/custom.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Literal, Optional, Union
2+
from typing import Literal, Optional, Union, List
33
import os
44

55
from lightning.fabric.utilities.types import _PATH
@@ -22,13 +22,16 @@ def __init__(
2222
entity: Optional[str] = None,
2323
offline: bool = False,
2424
log_model: Union[Literal["all"], bool] = False,
25+
verbose_hyperparameters: bool = False,
26+
tags: Optional[List[str]] = None,
2527
**kwargs,
2628
):
2729
if version is None:
2830
version = f"{datetime.now():%y%m%d-%H%M}"
2931
self._version = version
3032
self._name = name
3133
self._fold = fold
34+
self.verbose_hyperparameters = verbose_hyperparameters
3235
super().__init__(
3336
name=self.name,
3437
save_dir=save_dir,
@@ -40,6 +43,8 @@ def __init__(
4043
offline=offline,
4144
**kwargs,
4245
)
46+
if tags:
47+
self.experiment.tags += tuple(tags)
4348

4449
@property
4550
def name(self) -> Optional[str]:

chebai/preprocessing/datasets/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
if self.use_inner_cross_validation:
109109
os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True)
110110
os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
111+
self.save_hyperparameters()
111112

112113
@property
113114
def identifier(self):

chebai/trainer/CustomTrainer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010

1111
from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
12+
from chebai.loggers.custom import CustomLogger
1213

1314
log = logging.getLogger(__name__)
1415

@@ -20,6 +21,35 @@ def __init__(self, *args, **kwargs):
2021
super().__init__(*args, **kwargs)
2122
# instantiation custom logger connector
2223
self._logger_connector.on_trainer_init(self.logger, 1)
24+
# log additional hyperparameters to wandb
25+
if isinstance(self.logger, CustomLogger):
26+
custom_logger = self.logger
27+
assert isinstance(custom_logger, CustomLogger)
28+
if custom_logger.verbose_hyperparameters:
29+
log_kwargs = {}
30+
for key, value in self.init_kwargs.items():
31+
log_key, log_value = self._resolve_logging_argument(key, value)
32+
log_kwargs[log_key] = log_value
33+
self.logger.log_hyperparams(log_kwargs)
34+
35+
def _resolve_logging_argument(self, key, value):
36+
if isinstance(value, list):
37+
key_value_pairs = [
38+
self._resolve_logging_argument(f"{key}_{i}", v)
39+
for i, v in enumerate(value)
40+
]
41+
return key, {k: v for k, v in key_value_pairs}
42+
if not (
43+
isinstance(value, str)
44+
or isinstance(value, float)
45+
or isinstance(value, int)
46+
or value is None
47+
):
48+
params = {"class": value.__class__}
49+
params.update(value.__dict__)
50+
return key, params
51+
else:
52+
return key, value
2353

2454
def predict_from_file(
2555
self,

0 commit comments

Comments
 (0)