import torch
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional
from . import utils
class SpeciesEnergies(NamedTuple):
    """Pair of a species tensor and an energies tensor.

    This is the return type of the ``forward`` methods of
    :class:`ANIModel` and :class:`Ensemble` in this file.
    """
    # Species tensor; the producers in this file return the one they were given.
    species: Tensor
    # Energies computed for the corresponding input.
    energies: Tensor
class SpeciesCoordinates(NamedTuple):
    """Pair of a species tensor and a coordinates tensor.

    This is the return type of :meth:`SpeciesConverter.forward` in this file.
    """
    # Species tensor (SpeciesConverter fills it with converted indices).
    species: Tensor
    # Coordinates tensor (SpeciesConverter passes it through untouched).
    coordinates: Tensor
class ANIModel(torch.nn.ModuleDict):
    """ANI model that computes energies from species and AEVs.

    Different atom types might have different modules. When computing
    energies, for each atom, the module for its corresponding atom type is
    applied to its AEV; after that, the outputs of the modules are reduced
    along the atoms to obtain molecular energies.

    .. warning::

        The species must be indexed in 0, 1, 2, 3, ..., not the element
        index in periodic table. Check :class:`torchani.SpeciesConverter`
        if you want periodic table indexing.

    .. note:: The resulting energies are in Hartree.

    Arguments:
        modules (:class:`collections.abc.Sequence`): Modules for each atom
            type. Atom types are distinguished by their order in
            :attr:`modules`, which means, for example ``modules[i]`` must be
            the module for atom type ``i``. Different atom types can share a
            module by putting the same reference in :attr:`modules`.
            A ``dict`` / :class:`collections.OrderedDict` mapping string
            keys to modules is also accepted; atom type ``i`` is then
            served by the ``i``-th value in iteration order.
    """

    @staticmethod
    def ensureOrderedDict(modules):
        """Return ``modules`` as an :class:`collections.OrderedDict`.

        * An ``OrderedDict`` is returned as is (same object).
        * A plain ``dict`` is copied into an ``OrderedDict`` with the same
          keys and order.
        * Any other iterable is enumerated and keyed by ``str(index)``.
        """
        if isinstance(modules, OrderedDict):
            return modules
        if isinstance(modules, dict):
            # Enumerating a dict would iterate over its *keys*, registering
            # strings instead of modules. Keep the mapping intact instead.
            return OrderedDict(modules)
        od = OrderedDict()
        for i, m in enumerate(modules):
            od[str(i)] = m
        return od

    def __init__(self, modules):
        super().__init__(self.ensureOrderedDict(modules))

    def forward(self, species_aev: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Compute molecular energies from a ``(species, aev)`` pair.

        ``cell`` and ``pbc`` are accepted but not used by this method.

        Returns:
            :class:`SpeciesEnergies` holding the input species and the
            atomic energies summed over dimension 1.
        """
        species, aev = species_aev
        # aev carries one trailing feature dimension on top of species' shape
        assert species.shape == aev.shape[:-1]

        atomic_energies = self._atomic_energies((species, aev))
        # shape of atomic energies is (C, A)
        return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))

    @torch.jit.export
    def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
        """Return per-atom energies, shaped like ``species``.

        Entries whose species index matches none of the registered modules
        (for example ``-1``) keep an energy of zero.
        """
        # Obtain the atomic energies associated with a given tensor of AEV's
        species, aev = species_aev
        assert species.shape == aev.shape[:-1]
        # Work on a flat list of atoms so one mask per atom type suffices.
        species_ = species.flatten()
        aev = aev.flatten(0, 1)

        output = aev.new_zeros(species_.shape)

        # The i-th registered module handles atoms whose species index is i.
        for i, m in enumerate(self.values()):
            mask = (species_ == i)
            midx = mask.nonzero().flatten()
            # Skip modules for atom types that do not occur in this batch.
            if midx.shape[0] > 0:
                input_ = aev.index_select(0, midx)
                output.masked_scatter_(mask, m(input_).flatten())
        output = output.view_as(species)
        return output
class Ensemble(torch.nn.ModuleList):
    """Compute the average output of an ensemble of modules.

    Arguments:
        modules: iterable of modules. Each one is called with the
            ``(species, input)`` tuple given to :meth:`forward`, and element
            ``[1]`` of its result is averaged over all members.
    """

    def __init__(self, modules):
        super().__init__(modules)
        # Count the registered members instead of calling ``len(modules)``:
        # the value is the same for sized inputs, and one-shot iterables
        # such as generators (already consumed above) are supported too.
        self.size = len(self)

    def forward(self, species_input: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Return the input species and the mean of the members' outputs.

        ``cell`` and ``pbc`` are accepted but are not forwarded to the
        member modules.
        """
        sum_ = 0
        for x in self:
            # Element 1 of each member's result is the value being averaged.
            sum_ += x(species_input)[1]
        species, _ = species_input
        return SpeciesEnergies(species, sum_ / self.size)
class Sequential(torch.nn.ModuleList):
    """Sequential container whose stages take a tuple as their input.

    The stages are given as positional arguments and are applied in that
    order. Every stage is called as ``stage(value, cell=cell, pbc=pbc)``
    and its return value becomes the input of the next stage.
    """

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, input_: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        """Feed ``input_`` through every stage and return the final result.

        With no stages registered, ``input_`` is returned unchanged.
        """
        current = input_
        for stage in self:
            current = stage(current, cell=cell, pbc=pbc)
        return current
class Gaussian(torch.nn.Module):
    """Gaussian activation: maps ``x`` elementwise to ``exp(-x * x)``."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the activation elementwise and return a new tensor."""
        exponent = -x * x
        return exponent.exp()
class SpeciesConverter(torch.nn.Module):
    """Converts tensors with species labeled as atomic numbers into tensors
    labeled with internal torchani indices according to a custom ordering
    scheme. It takes a custom species ordering as initialization parameter. If
    the class is initialized with ['H', 'C', 'N', 'O'] for example, it will
    convert a tensor [1, 1, 6, 7, 1, 8] into a tensor [0, 0, 1, 2, 0, 3]

    Arguments:
        species (:class:`collections.abc.Sequence` of :class:`str`):
            sequence of all supported species, in order (it is recommended
            to order according to atomic number).
    """
    conv_tensor: Tensor

    def __init__(self, species):
        super().__init__()
        # Position of every symbol inside utils.PERIODIC_TABLE.
        table_position = {
            symbol: position
            for position, symbol in enumerate(utils.PERIODIC_TABLE)
        }
        largest_position = max(table_position.values())
        # Unsupported entries stay at -1. The table has one slot past the
        # largest position; that slot is never assigned, so indexing with
        # -1 (the last slot) also yields -1.
        lookup = torch.full((largest_position + 2,), -1, dtype=torch.long)
        for internal_index, symbol in enumerate(species):
            lookup[table_position[symbol]] = internal_index
        self.register_buffer('conv_tensor', lookup)

    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        """Convert species from periodic table element index to 0, 1, 2, 3, ... indexing"""
        species, coordinates = input_
        converted_species = self.conv_tensor[species]

        # check if unknown species are included; entries equal to -1 in the
        # input are exempt from this check
        checked_entries = converted_species[species != -1]
        if (checked_entries < 0).any():
            raise ValueError(f'Unknown species found in {species}')

        return SpeciesCoordinates(converted_species.to(species.device), coordinates)