Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 51 additions & 17 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import re
import tempfile
from typing import Optional, List, Tuple, Dict
from typing import Optional, List, Tuple, Dict, Union
import torch
from torch.autograd import grad
from torch import nn, Tensor
Expand Down Expand Up @@ -494,7 +494,7 @@ class Ensemble(torch.nn.ModuleList):

Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns the standard deviation of model ouputs and derivatives.
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
Expand All @@ -505,32 +505,66 @@ def __init__(self, modules: List[nn.Module], return_std: bool = False):

def forward(
self,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
Note that this function will fail if derivative=False.

"""
Compute the output of the model.

This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
certain requirements:

.. code:: python

a[1] = a[2] = b[2] = 0
a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff
a[0] >= 2*b[0]
a[0] >= 2*c[0]
b[1] >= 2*c[1]


These requirements correspond to a particular rotation of the system and
reduced form of the vectors, as well as the requirement that the cutoff be
no larger than half the box width.

Args:
*args: Positional arguments to forward to the models.
**kwargs: Keyword arguments to forward to the models.
Returns:
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).
z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,).
pos (Tensor): Atomic positions in the molecule. Shape: (N, 3).
batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,).
box (Tensor, optional): Box vectors. Shape (3, 3).
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.

Returns:
Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: The mean output of the models, the mean derivatives, the std of the outputs if return_std is true, the std of the derivatives if return_std is true.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

say "the mean negative derivatives"?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe something like this should stay in this docstring:

Average predictions over all models in the ensemble.
        The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
        ```


"""


y = []
neg_dy = []
for model in self:
res = model(*args, **kwargs)
res = model(z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)
y.append(res[0])
neg_dy.append(res[1])
y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)
y_mean = torch.mean(y, dim=0)
neg_dy_mean = torch.mean(neg_dy, dim=0)
y_std = torch.std(y, dim=0)
neg_dy_std = torch.std(neg_dy, dim=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean
return y_mean, neg_dy_mean, None, None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No point in returning None, None. The only reason for the if was to return less variables.
Remove the if and always return y_mean, neg_dy_mean, y_std, neg_dy_std

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And remove the class option and also from the loader.