New regression and classification datasets for ontology pre-training #130
base: dev
Changes from 53 commits
New file (a unified focal loss implementation), 152 lines added (`@@ -0,0 +1,152 @@`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# from https://github.com/itakurah/Focal-loss-PyTorch


class FocalLoss(nn.Module):
    def __init__(
        self,
        gamma=2,
        alpha=None,
        reduction="mean",
        task_type="binary",
        num_classes=None,
    ):
        """
        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
        :param num_classes: Number of classes (only required for multi-class classification)
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Handle alpha for class balancing in multi-class tasks
        if (
            task_type == "multi-class"
            and alpha is not None
            and isinstance(alpha, (list, torch.Tensor))
        ):
            assert (
                num_classes is not None
            ), "num_classes must be specified for multi-class classification"
            if isinstance(alpha, list):
                self.alpha = torch.Tensor(alpha)
            else:
                self.alpha = alpha

    def forward(self, inputs, targets):
        """
        Forward pass to compute the Focal Loss based on the specified task type.
        :param inputs: Predictions (logits) from the model.
                       Shape:
                       - binary/multi-label: (batch_size, num_classes)
                       - multi-class: (batch_size, num_classes)
        :param targets: Ground truth labels.
                        Shape:
                        - binary: (batch_size,)
                        - multi-label: (batch_size, num_classes)
                        - multi-class: (batch_size,)
        """
        if self.task_type == "binary":
            return self.binary_focal_loss(inputs, targets)
        elif self.task_type == "multi-class":
            return self.multi_class_focal_loss(inputs, targets)
        elif self.task_type == "multi-label":
            return self.multi_label_focal_loss(inputs, targets)
        else:
            raise ValueError(
                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'."
            )

    def binary_focal_loss(self, inputs, targets):
        """Focal loss for binary classification."""
        probs = torch.sigmoid(inputs)
        targets = targets.float()

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weighting
        loss = focal_weight * bce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def multi_class_focal_loss(self, inputs, targets):
        """Focal loss for multi-class classification."""
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)

        # Convert logits to probabilities with softmax
        probs = F.softmax(inputs, dim=1)

        # One-hot encode the targets
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()

        # Compute cross-entropy for each class
        ce_loss = -targets_one_hot * torch.log(probs)

        # Compute focal weight
        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided (per-class weighting)
        if self.alpha is not None:
            alpha_t = alpha.gather(0, targets)
            ce_loss = alpha_t.unsqueeze(1) * ce_loss

        # Apply focal loss weight
        loss = focal_weight.unsqueeze(1) * ce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def multi_label_focal_loss(self, inputs, targets):
        """Focal loss for multi-label classification."""
        probs = torch.sigmoid(inputs)

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weight
        loss = focal_weight * bce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
```
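Note (not part of the diff): a minimal usage sketch of the `FocalLoss` class above. The batch size, label count, and alpha values are illustrative assumptions, and the sketch assumes the class defined above is in scope (its module path is not shown in this extract).

```python
import torch

# Multi-label setting: logits and {0, 1} float targets of shape (batch_size, num_classes).
logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 2, (4, 10)).float()

criterion = FocalLoss(gamma=2, alpha=0.25, task_type="multi-label", reduction="mean")
loss = criterion(logits, targets)  # scalar tensor
loss.backward()

# Multi-class setting: integer class targets, per-class alpha, num_classes required.
mc_criterion = FocalLoss(
    gamma=2, alpha=[0.5, 1.0, 2.0], task_type="multi-class", num_classes=3
)
mc_logits = torch.randn(4, 3)
mc_targets = torch.randint(0, 3, (4,))
mc_loss = mc_criterion(mc_logits, mc_targets)
```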
Changes to a second file (the base module's `__init__` and `_execute`):

```diff
@@ -41,7 +41,8 @@ def __init__(
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
+        # super().__init__()
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
@@ -264,7 +265,7 @@ def _execute(
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-            loss_kwargs["current_epoch"] = self.trainer.current_epoch
+            # loss_kwargs["current_epoch"] = self.trainer.current_epoch
         loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
         if isinstance(loss, tuple):
             unnamed_loss_index = 1
```
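Note (not part of the diff): a sketch of the calling convention the second hunk relies on. When `pass_loss_kwargs` is set, `_execute` forwards the collected keyword arguments (e.g. `non_null_labels` or `missing_labels`, which appear in the Electra changes below) directly to the criterion, so the criterion's `forward` must accept or absorb them. The `ToleratingBCE` class here is hypothetical, purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToleratingBCE(nn.Module):
    """Hypothetical criterion whose forward absorbs extra loss kwargs."""

    def forward(self, logits, labels, **loss_kwargs):
        # Keys such as "non_null_labels" can be consumed or ignored here;
        # after this change, "current_epoch" is no longer injected by _execute.
        return F.binary_cross_entropy_with_logits(logits, labels)


criterion = ToleratingBCE()
logits = torch.randn(4, 10)
labels = torch.randint(0, 2, (4, 10)).float()
loss = criterion(logits, labels, non_null_labels=None)  # extra kwarg is absorbed
```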
Changes to a third file (the Electra models):

```diff
@@ -19,7 +19,8 @@
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
-from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss  # noqa
+# TODO: put back in before pull request
+# from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss  # noqa
 
 
 class ElectraPre(ChebaiBaseNet):
@@ -40,6 +41,7 @@ class ElectraPre(ChebaiBaseNet):
 
     def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
         super().__init__(config=config, **kwargs)
+
         self.generator_config = ElectraConfig(**config["generator"])
         self.generator = ElectraForMaskedLM(self.generator_config)
         self.discriminator_config = ElectraConfig(**config["discriminator"])
@@ -224,6 +226,7 @@ def __init__(
         config: Optional[Dict[str, Any]] = None,
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
+        model_type="classification",
         **kwargs: Any,
     ):
         # Remove this property in order to prevent it from being stored as a
@@ -236,6 +239,8 @@ def __init__(
             config["num_labels"] = self.out_dim
         self.config = ElectraConfig(**config, output_attentions=True)
         self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
+        self.model_type = model_type
+        self.pass_loss_kwargs = True
 
         in_d = self.config.hidden_size
         self.output = nn.Sequential(
@@ -262,6 +267,10 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)
 
+        # freeze parameters
+        # for param in self.electra.parameters():
+        #     param.requires_grad = False
+
     def _process_for_loss(
         self,
         model_output: Dict[str, Tensor],
@@ -280,9 +289,16 @@ def _process_for_loss(
             tuple: A tuple containing the processed model output, labels, and loss arguments.
         """
         kwargs_copy = dict(loss_kwargs)
+        output = model_output["logits"]
         if labels is not None:
             labels = labels.float()
-        return model_output["logits"], labels, kwargs_copy
+            if "missing_labels" in kwargs_copy:
+                missing_labels = kwargs_copy.pop("missing_labels")
+                output = output * (~missing_labels).int() - 10000 * missing_labels.int()
+                labels = labels * (~missing_labels).int()
+            if self.model_type == "classification":
+                assert ((labels <= torch.tensor(1.0)) & (labels >= torch.tensor(0.0))).all()
+        return output, labels, kwargs_copy
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
@@ -303,7 +319,25 @@ def _get_prediction_and_labels(
         if "non_null_labels" in loss_kwargs:
             n = loss_kwargs["non_null_labels"]
             d = d[n]
-        return torch.sigmoid(d), labels.int() if labels is not None else None
+        if self.model_type == "classification":
+            # print(self.model_type, ' in electra 324')
+            # for multiclass here softmax instead of sigmoid
+            d = torch.sigmoid(
+                d
+            )  # changing this made a difference for the roc-auc but not the f1, why?
+            if "missing_labels" in loss_kwargs:
+                missing_labels = loss_kwargs["missing_labels"]
+                d = d * (~missing_labels).int().to(
+                    device=d.device
+                )  # we set the prob of missing labels to 0
+                labels = labels * (~missing_labels).int().to(
+                    device=d.device
+                )  # we set the labels of missing labels to 0
+            return d, labels.int() if labels is not None else None
+        elif self.model_type == "regression":
+            return d, labels
+        else:
+            raise ValueError("Please specify a valid model type in your model config.")
 
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         """
```
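Note (not part of the diff): a standalone toy example of the `missing_labels` masking arithmetic used in `_process_for_loss` and `_get_prediction_and_labels` above; the tensor values are made up for illustration.

```python
import torch

# Boolean mask marking label positions whose ground truth is unknown.
missing_labels = torch.tensor([[False, True, False]])
logits = torch.tensor([[2.0, 0.5, -1.0]])
labels = torch.tensor([[1.0, 1.0, 0.0]])

# As in _process_for_loss: push masked logits to a large negative value and
# zero the corresponding labels, so those positions contribute almost nothing
# to a BCE-with-logits style loss.
masked_logits = logits * (~missing_labels).int() - 10000 * missing_labels.int()
masked_labels = labels * (~missing_labels).int()
print(masked_logits)  # tensor([[     2.0, -10000.0,     -1.0]])
print(masked_labels)  # tensor([[1., 0., 0.]])

# As in _get_prediction_and_labels: zero the predicted probabilities of
# missing labels after the sigmoid.
probs = torch.sigmoid(logits) * (~missing_labels).int()
```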
Review comment: why does `weight_epoch_dependent` appear twice here?