Skip to content

Commit 7e5f801

Browse files
authored
Merge pull request ChEB-AI#31 from schnamo/dev
pyyaml instead of yaml, union instead of pipe
2 parents 08582c6 + e9e0ab1 commit 7e5f801

File tree

5 files changed

+29
-25
lines changed

5 files changed

+29
-25
lines changed

chebai/loss/semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import math
66
import torch
7-
from typing import Literal
7+
from typing import Literal, Union
88

99
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
1010
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
@@ -14,7 +14,7 @@
1414
class ImplicationLoss(torch.nn.Module):
1515
def __init__(
1616
self,
17-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
17+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
1818
base_loss: torch.nn.Module = None,
1919
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
2020
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
@@ -114,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
114114
def __init__(
115115
self,
116116
path_to_disjointness,
117-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
117+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
118118
base_loss: torch.nn.Module = None,
119119
disjoint_loss_weight=100,
120120
**kwargs,

chebai/result/analyse_sem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torchmetrics.functional.classification import multilabel_f1_score
1212
import wandb
1313
import gc
14+
from typing import List, Union
1415
from utils import *
1516

1617
DEVICE = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -244,7 +245,7 @@ def analyse_run(
244245
labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations
245246
chebi_version=231,
246247
results_path=os.path.join("_semantic", "eval_results.csv"),
247-
violation_metrics: [str | list[callable]] = "all",
248+
violation_metrics: Union[str, List[callable]] = "all",
248249
verbose_violation_output=False,
249250
):
250251
"""Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),

chebai/trainer/CustomTrainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def predict_from_file(
6464
smiles_strings = [inp.strip() for inp in input.readlines()]
6565
loaded_model.eval()
6666
predictions = self._predict_smiles(loaded_model, smiles_strings)
67-
predictions_df = pd.DataFrame(predictions.detach().numpy())
67+
predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
6868
if classes_path is not None:
6969
with open(classes_path, "r") as f:
7070
predictions_df.columns = [cls.strip() for cls in f.readlines()]
@@ -74,7 +74,10 @@ def predict_from_file(
7474
def _predict_smiles(self, model: LightningModule, smiles: List[str]):
7575
reader = ChemDataReader()
7676
parsed_smiles = [reader._read_data(s) for s in smiles]
77-
x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True)
77+
x = pad_sequence(
78+
[torch.tensor(a, device=model.device) for a in parsed_smiles],
79+
batch_first=True,
80+
)
7881
cls_tokens = (
7982
torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
8083
* CLS_TOKEN

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"iterative-stratification",
4949
"wandb",
5050
"chardet",
51-
"yaml",
51+
"pyyaml",
5252
"torchmetrics",
5353
],
5454
extras_require={"dev": ["black", "isort", "pre-commit"]},

0 commit comments

Comments
 (0)