Changes from all commits · 59 commits
ed2ca6c
create new class for solubility data
schnamo Jun 5, 2024
ead3007
adjusting new class
schnamo Jun 5, 2024
5956183
add solubility yml file
schnamo Jun 6, 2024
c3afeed
adjusting solubility class to correctly download solubility data
schnamo Jun 7, 2024
d57b073
make it compatible with classification problem
schnamo Jun 7, 2024
0faca31
onehotencoding for solubility labels
schnamo Jul 4, 2024
4000215
adjust to regression, add yml files for regression
schnamo Jul 17, 2024
0709188
adjust prediction to regression
schnamo Jul 17, 2024
f8bd06a
refactor code
schnamo Jul 23, 2024
21fbde4
regression fix, yml files for mae loss
schnamo Jul 25, 2024
f3bfe08
take out kinect dataset
schnamo Jul 25, 2024
e26925d
adjust learning rate
schnamo Jul 25, 2024
0f2f85f
adjustments for new solu dataset
schnamo Dec 12, 2024
d0da5c2
Merge branch 'dev' of https://github.com/schnamo/python-chebai into s…
schnamo Dec 12, 2024
0d94b44
working on evaluation script, addded a bunch of things earlier for so…
schnamo Dec 19, 2024
45228ba
further adjusting evaluation function for regression
schnamo Dec 20, 2024
dbf8532
regression adjustments
schnamo Dec 20, 2024
fa97f45
fix union expression
schnamo Jan 8, 2025
8b91dce
fix tuple issue to make it backwards compatible
schnamo Jan 8, 2025
677d6ec
wandb
schnamo Jan 13, 2025
2c159e8
fix issue with solubility dataset read in
schnamo Jan 16, 2025
b537b7f
Fix missing label handling
MGlauer Jan 17, 2025
a99e438
add more datasets
schnamo Jan 17, 2025
754de12
Merge commit 'b537b7fd776e6afc535e05a111a0bc6a493ec8e9' of https://gi…
schnamo Jan 17, 2025
9b084cb
merge branches part 2
schnamo Jan 17, 2025
326e9a2
add more datasets
schnamo Jan 18, 2025
c272f45
adjust metrics for classifications, add BBBP
schnamo Jan 18, 2025
dc9e104
more datasets
schnamo Jan 19, 2025
9a3967d
bug fixes and different loss and electra params
schnamo Jan 20, 2025
1bc8736
changes to missing labels: negate labels as well as logits, add them …
schnamo Jan 21, 2025
4885960
try different splits, remove debugging comments
schnamo Feb 13, 2025
baa085f
Merge branch 'dev' of https://github.com/schnamo/python-chebai into s…
schnamo Feb 20, 2025
ba01607
fix issue with input args
schnamo Mar 12, 2025
f74964c
add missing configs
schnamo Mar 19, 2025
59064af
add HIV dataset handling
schnamo Apr 4, 2025
93d47eb
dd MUV dataset
schnamo Apr 4, 2025
87babcc
debugging
schnamo Apr 18, 2025
ebe049e
final updates
schnamo Jul 1, 2025
188f32f
add focal loss
schnamo Sep 29, 2025
dccc2e3
add focal loss
schnamo Sep 29, 2025
d57016f
format for lint
schnamo Sep 29, 2025
41c0b1c
lint fix
schnamo Sep 29, 2025
4c993a2
lint fix
schnamo Sep 29, 2025
4aa1771
add regression to readme
schnamo Sep 29, 2025
d411c9e
fix union expression
schnamo Jan 8, 2025
18d8e02
fix tuple issue to make it backwards compatible
schnamo Jan 8, 2025
af7df07
Merge branch 'sol_final' into dev and adjustments to new logic, adjus…
Oct 23, 2025
ed1d4b4
adjust to current dev branch
Oct 29, 2025
dca60a3
adjust all regression tasks to new logic
Oct 29, 2025
b6f0d23
adjust classification tasks for new logic
Oct 29, 2025
9b29411
lightning cli issue
Oct 29, 2025
d56e226
black-lint fix
schnamo Oct 29, 2025
fc444e0
fix load from checkpoint issues for pretrained models
schnamo Nov 3, 2025
5304b3e
adding decoding of encoded tokens function
schnamo Nov 11, 2025
426f1b0
remove print statements from debugging
schnamo Nov 11, 2025
b0b3113
Merge branch 'dev' of https://github.com/ChEB-AI/python-chebai into dev
schnamo Nov 13, 2025
fb6fdb7
lint fixes
schnamo Nov 13, 2025
81f8025
ruff fixes
schnamo Nov 13, 2025
9a24fd7
black fixes
schnamo Nov 13, 2025
31 changes: 26 additions & 5 deletions README.md
@@ -3,14 +3,30 @@
ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.

## Installation
## News

We now support regression tasks!

## Note for developers

You can install ChEBai via pip:
If you have used ChEBai before PR #39, note that the file structure in which your ChEBI data is saved has changed.
As a result, datasets will be regenerated from scratch; the data itself, however, is unchanged. If you want to keep
the old data (including the old splits), you can use a migration script. It copies the old data to the new location
for a specific ChEBI class (including the ChEBI version and other parameters). The script can be called by
specifying the data module from a config:
```
pip install chebai
python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config]
```
or by specifying the class name (e.g. `ChEBIOver50`) and arguments separately
```
python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]]
```
By default, the new dataset will generate random data splits (with a given seed).
To reuse a fixed data split, provide the path to the CSV file generated during the migration:
`--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv`
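For example, a subsequent training run that reuses the migrated split could look like this (the configs shown are the ones used elsewhere in this README; the splits path is a placeholder):
```
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --data=configs/data/chebi50.yml --data.init_args.splits_file_path=[path-to-processed_data]/splits.csv
```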

Alternatively, you can get the latest development version directly from GitHub:
## Installation

To install ChEBai, follow these steps:

1. Clone the repository:
```
@@ -63,11 +79,16 @@ A command with additional options may look like this:
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
```

### Fine-tuning for Toxicity prediction
### Fine-tuning for classification tasks, e.g., toxicity prediction
```
python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
```

### Fine-tuning for regression tasks, e.g., solubility prediction
```
python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=configs/training/solCur_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
```

### Predicting classes given SMILES strings
```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
27 changes: 26 additions & 1 deletion chebai/cli.py
@@ -60,15 +60,40 @@ def call_data_methods(data: Type[XYBaseDataModule]):
)

for kind in ("train", "val", "test"):
for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
for average in (
"micro-f1",
"macro-f1",
"balanced-accuracy",
"roc-auc",
"f1",
"mse",
"rmse",
"r2",
):
# When using lightning > 2.5.1, keep only the metrics used for your task, e.g. by replacing the tuple above with one of the following:
# for average in ("mse", "rmse","r2"): # for regression
# for average in ("f1", "roc-auc"): # for binary classification
# for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
# for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
parser.link_arguments(
"data.num_of_labels",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
apply_on="instantiate",
)

parser.link_arguments(
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
)
# parser.link_arguments(
# "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
# )
# parser.link_arguments(
# "data", "model.init_args.criterion.init_args.data_extractor"
# )
# parser.link_arguments(
# "data.init_args.chebi_version",
# "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
# )

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
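To make the commented-out, task-specific alternatives in `chebai/cli.py` above concrete, here is a minimal sketch (not part of the PR; `TASK_METRICS` and `link_metric_num_labels` are illustrative names, and the argument paths are copied from the diff) of how the metric tuple could be chosen per task before linking:

```python
from typing import Literal

from lightning.pytorch.cli import LightningArgumentParser

# Illustrative per-task metric sets, mirroring the commented alternatives above.
TASK_METRICS = {
    "regression": ("mse", "rmse", "r2"),
    "binary": ("f1", "roc-auc"),
    "multilabel": ("micro-f1", "macro-f1", "roc-auc"),
}


def link_metric_num_labels(
    parser: LightningArgumentParser,
    task: Literal["regression", "binary", "multilabel"],
) -> None:
    """Link data.num_of_labels into every metric used for the given task."""
    for kind in ("train", "val", "test"):
        for metric in TASK_METRICS[task]:
            parser.link_arguments(
                "data.num_of_labels",
                f"model.init_args.{kind}_metrics.init_args.metrics.{metric}.init_args.num_labels",
                apply_on="instantiate",
            )
```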
152 changes: 152 additions & 0 deletions chebai/loss/focal_loss.py
@@ -0,0 +1,152 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# from https://github.com/itakurah/Focal-loss-PyTorch


class FocalLoss(nn.Module):
def __init__(
self,
gamma=2,
alpha=None,
reduction="mean",
task_type="binary",
num_classes=None,
):
"""
Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
:param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
:param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
:param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
:param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
:param num_classes: Number of classes (only required for multi-class classification)
"""
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.task_type = task_type
self.num_classes = num_classes

# Handle alpha for class balancing in multi-class tasks
if (
task_type == "multi-class"
and alpha is not None
and isinstance(alpha, (list, torch.Tensor))
):
assert (
num_classes is not None
), "num_classes must be specified for multi-class classification"
if isinstance(alpha, list):
self.alpha = torch.Tensor(alpha)
else:
self.alpha = alpha

def forward(self, inputs, targets):
"""
Forward pass to compute the Focal Loss based on the specified task type.
:param inputs: Predictions (logits) from the model.
Shape:
- binary/multi-label: (batch_size, num_classes)
- multi-class: (batch_size, num_classes)
:param targets: Ground truth labels.
Shape:
- binary: (batch_size,)
- multi-label: (batch_size, num_classes)
- multi-class: (batch_size,)
"""
if self.task_type == "binary":
return self.binary_focal_loss(inputs, targets)
elif self.task_type == "multi-class":
return self.multi_class_focal_loss(inputs, targets)
elif self.task_type == "multi-label":
return self.multi_label_focal_loss(inputs, targets)
else:
raise ValueError(
f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'."
)

def binary_focal_loss(self, inputs, targets):
"""Focal loss for binary classification."""
probs = torch.sigmoid(inputs)
targets = targets.float()

# Compute binary cross entropy
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

# Compute focal weight
p_t = probs * targets + (1 - probs) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma

# Apply alpha if provided
if self.alpha is not None:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
bce_loss = alpha_t * bce_loss

# Apply focal loss weighting
loss = focal_weight * bce_loss

if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss

def multi_class_focal_loss(self, inputs, targets):
"""Focal loss for multi-class classification."""
if self.alpha is not None:
alpha = self.alpha.to(inputs.device)

# Convert logits to probabilities with softmax
probs = F.softmax(inputs, dim=1)

# One-hot encode the targets
targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()

# Compute cross-entropy for each class
ce_loss = -targets_one_hot * torch.log(probs)

# Compute focal weight
p_t = torch.sum(probs * targets_one_hot, dim=1) # p_t for each sample
focal_weight = (1 - p_t) ** self.gamma

# Apply alpha if provided (per-class weighting)
if self.alpha is not None:
alpha_t = alpha.gather(0, targets)
ce_loss = alpha_t.unsqueeze(1) * ce_loss

# Apply focal loss weight
loss = focal_weight.unsqueeze(1) * ce_loss

if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss

def multi_label_focal_loss(self, inputs, targets):
"""Focal loss for multi-label classification."""
probs = torch.sigmoid(inputs)

# Compute binary cross entropy
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

# Compute focal weight
p_t = probs * targets + (1 - probs) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma

# Apply alpha if provided
if self.alpha is not None:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
bce_loss = alpha_t * bce_loss

# Apply focal loss weight
loss = focal_weight * bce_loss

if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss
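A minimal usage sketch for the `FocalLoss` class above (not part of the diff; shapes and values are illustrative):

```python
import torch

from chebai.loss.focal_loss import FocalLoss  # module path as added in this PR

# Multi-label case: 8 samples, 12 labels; a scalar alpha weights the
# positive class with 0.25 and the negative class with 0.75.
loss_fn = FocalLoss(gamma=2, alpha=0.25, reduction="mean", task_type="multi-label")

logits = torch.randn(8, 12, requires_grad=True)  # raw model scores
targets = torch.randint(0, 2, (8, 12)).float()   # binary label matrix

loss = loss_fn(logits, targets)
loss.backward()  # gradients flow through both the BCE term and the focal weight
```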
4 changes: 2 additions & 2 deletions chebai/loss/semantic.py
@@ -2,7 +2,7 @@
import math
import os
import pickle
from typing import TYPE_CHECKING, List, Literal, Union
from typing import TYPE_CHECKING, List, Literal, Tuple, Union

import torch

@@ -62,7 +62,7 @@ def __init__(
pos_epsilon: float = 0.01,
multiply_by_softmax: bool = False,
use_sigmoidal_implication: bool = False,
weight_epoch_dependent: Union[bool | tuple[int, int]] = False,
weight_epoch_dependent: Union[bool, Tuple[int, int]] = False,
start_at_epoch: int = 0,
violations_per_cls_aggregator: Literal[
"sum", "max", "mean", "log-sum", "log-max", "log-mean"
4 changes: 2 additions & 2 deletions chebai/models/base.py
@@ -42,7 +42,8 @@
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
if exclude_hyperparameter_logging is None:
exclude_hyperparameter_logging = tuple()
self.criterion = criterion
@@ -273,7 +274,6 @@ def _execute(
loss_kwargs = dict()
if self.pass_loss_kwargs:
loss_kwargs = loss_kwargs_candidates
loss_kwargs["current_epoch"] = self.trainer.current_epoch
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
if isinstance(loss, tuple):
unnamed_loss_index = 1
34 changes: 32 additions & 2 deletions chebai/models/electra.py
@@ -19,6 +19,7 @@

logging.getLogger("pysmiles").setLevel(logging.CRITICAL)


from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa


@@ -40,6 +41,7 @@ class ElectraPre(ChebaiBaseNet):

def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
super().__init__(config=config, **kwargs)

self.generator_config = ElectraConfig(**config["generator"])
self.generator = ElectraForMaskedLM(self.generator_config)
self.discriminator_config = ElectraConfig(**config["discriminator"])
@@ -224,6 +226,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
pretrained_checkpoint: Optional[str] = None,
load_prefix: Optional[str] = None,
model_type="classification",
freeze_electra: bool = False,
**kwargs: Any,
):
@@ -237,6 +240,8 @@
config["num_labels"] = self.out_dim
self.config = ElectraConfig(**config, output_attentions=True)
self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
self.model_type = model_type
self.pass_loss_kwargs = True

in_d = self.config.hidden_size
self.output = nn.Sequential(
@@ -285,9 +290,16 @@ def _process_for_loss(
tuple: A tuple containing the processed model output, labels, and loss arguments.
"""
kwargs_copy = dict(loss_kwargs)
output = model_output["logits"]
if labels is not None:
labels = labels.float()
return model_output["logits"], labels, kwargs_copy
if "missing_labels" in kwargs_copy:
missing_labels = kwargs_copy.pop("missing_labels")
output = output * (~missing_labels).int() - 10000 * missing_labels.int()
labels = labels * (~missing_labels).int()
if self.model_type == "classification":
assert ((labels <= torch.tensor(1.0)) & (labels >= torch.tensor(0.0))).all()
return output, labels, kwargs_copy

def _get_prediction_and_labels(
self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
@@ -308,7 +320,25 @@ def _get_prediction_and_labels(
if "non_null_labels" in loss_kwargs:
n = loss_kwargs["non_null_labels"]
d = d[n]
return torch.sigmoid(d), labels.int() if labels is not None else None
if self.model_type == "classification":
# for multi-class tasks, softmax would be used here instead of sigmoid
d = torch.sigmoid(d)  # note: changing this affected roc-auc but not f1
if "missing_labels" in loss_kwargs:
missing_labels = loss_kwargs["missing_labels"]
d = d * (~missing_labels).int().to(
device=d.device
) # we set the prob of missing labels to 0
labels = labels * (~missing_labels).int().to(
device=d.device
) # we set the labels of missing labels to 0
return d, labels.int() if labels is not None else None
elif self.model_type == "regression":
return d, labels
else:
raise ValueError("Please specify a valid model type in your model config.")

def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
"""
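As a standalone illustration (values are made up, not from the repo), the missing-label handling in `_process_for_loss` and `_get_prediction_and_labels` above behaves like this:

```python
import torch

# One sample with three labels; the second label is missing.
logits = torch.tensor([[2.0, -1.0, 0.5]])
labels = torch.tensor([[1.0, 0.0, 1.0]])
missing = torch.tensor([[False, True, False]])

# For the loss: push missing logits to a large negative value (sigmoid ~ 0)
# and zero the corresponding labels, so a BCE-style loss sees an (almost)
# perfectly predicted negative at masked positions.
masked_logits = logits * (~missing).int() - 10000 * missing.int()
masked_labels = labels * (~missing).int()

# For the metrics: zero both predicted probabilities and labels at missing
# positions, mirroring _get_prediction_and_labels.
probs = torch.sigmoid(logits) * (~missing).int()
metric_labels = (labels * (~missing).int()).int()
```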