# Source code for torchani.nn

import torch
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional
from . import utils


class SpeciesEnergies(NamedTuple):
    """Pair of tensors returned by ``ANIModel.forward`` and ``Ensemble.forward``."""
    # Species tensor, handed back exactly as it was received by the module.
    species: Tensor
    # Energies computed for that input (one value per molecule in ANIModel,
    # where atomic energies are summed over dim 1).
    energies: Tensor


class SpeciesCoordinates(NamedTuple):
    """Pair of tensors returned by ``SpeciesConverter.forward``."""
    # Species tensor after conversion to internal 0, 1, 2, ... indices.
    species: Tensor
    # Coordinates tensor, passed through unchanged by the converter.
    coordinates: Tensor


class ANIModel(torch.nn.ModuleDict):
    """ANI model that computes energies from species and AEVs.

    Different atom types may have different modules. When computing
    energies, the module for each atom's type is applied to that atom's AEV;
    the per-atom outputs are then summed over the atoms of each molecule to
    obtain molecular energies.

    .. warning::

        The species must be indexed in 0, 1, 2, 3, ..., not the element
        index in periodic table. Check :class:`torchani.SpeciesConverter`
        if you want periodic table indexing.

    .. note:: The resulting energies are in Hartree.

    Arguments:
        modules (:class:`collections.abc.Sequence`): Modules for each atom
            types. Atom types are distinguished by their order in
            :attr:`modules`, which means, for example ``modules[i]`` must be
            the module for atom type ``i``. Different atom types can share a
            module by putting the same reference in :attr:`modules`.
    """

    @staticmethod
    def ensureOrderedDict(modules):
        """Return ``modules`` as an :class:`~collections.OrderedDict`.

        An ``OrderedDict`` is returned untouched (its keys are preserved);
        any other iterable is enumerated and keyed by ``str(position)``.
        """
        if isinstance(modules, OrderedDict):
            return modules
        od = OrderedDict()
        for i, m in enumerate(modules):
            od[str(i)] = m
        return od

    def __init__(self, modules):
        super().__init__(self.ensureOrderedDict(modules))

    def forward(self, species_aev: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Compute molecular energies.

        Arguments:
            species_aev: ``(species, aev)`` where ``species.shape`` must
                equal ``aev.shape[:-1]``.
            cell: accepted for signature compatibility; not used here.
            pbc: accepted for signature compatibility; not used here.

        Returns:
            ``SpeciesEnergies(species, energies)`` where ``energies`` is the
            sum of the atomic energies over dim 1.
        """
        species, aev = species_aev
        assert species.shape == aev.shape[:-1]

        atomic_energies = self._atomic_energies((species, aev))
        # shape of atomic energies is (C, A)
        return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))

    @torch.jit.export
    def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
        """Obtain the atomic energies associated with a given tensor of AEVs.

        The result has the same shape as ``species``. Atoms whose species
        index matches no module (for example ``-1``) keep an energy of zero.
        """
        species, aev = species_aev
        assert species.shape == aev.shape[:-1]
        # Merge the first two dims so every atom is one row.
        species_ = species.flatten()
        aev = aev.flatten(0, 1)

        output = aev.new_zeros(species_.shape)

        # The position of a module in the dict is the species index it serves.
        for i, m in enumerate(self.values()):
            mask = (species_ == i)
            midx = mask.nonzero().flatten()
            # Skip modules with no matching atoms to avoid an empty forward pass.
            if midx.shape[0] > 0:
                input_ = aev.index_select(0, midx)
                output.masked_scatter_(mask, m(input_).flatten())
        output = output.view_as(species)
        return output
class Ensemble(torch.nn.ModuleList):
    """Compute the average output of an ensemble of modules.

    Arguments:
        modules: the member modules. Each is called with the
            ``(species, input)`` tuple and must return a pair whose second
            element is the energies tensor.
    """

    def __init__(self, modules):
        super().__init__(modules)
        # Number of members, fixed at construction; used as the divisor.
        self.size = len(modules)

    def forward(self, species_input: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Average the energies predicted by all members.

        Arguments:
            species_input: ``(species, input)`` tuple handed to every member.
            cell: accepted for signature compatibility; not forwarded to the
                members.
            pbc: accepted for signature compatibility; not forwarded to the
                members.

        Returns:
            ``SpeciesEnergies(species, mean_energies)`` with ``species`` taken
            unchanged from ``species_input``.
        """
        sum_ = 0
        for x in self:
            # Element 1 of each member's output is its energies tensor.
            sum_ += x(species_input)[1]
        species, _ = species_input
        return SpeciesEnergies(species, sum_ / self.size)
class Sequential(torch.nn.ModuleList):
    """Modified Sequential module that accepts a Tuple type as input.

    Arguments:
        *modules: the modules to chain, in call order. Each must accept
            ``(input_, cell=..., pbc=...)``.
    """

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, input_: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        """Feed ``input_`` through every module in order.

        The same ``cell`` and ``pbc`` are passed to every module; the output
        of one module becomes the input of the next. With no modules,
        ``input_`` is returned unchanged.
        """
        for module in self:
            input_ = module(input_, cell=cell, pbc=pbc)
        return input_
class Gaussian(torch.nn.Module):
    """Gaussian activation: applies ``exp(-x**2)`` element-wise."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``exp(-x * x)`` computed element-wise on ``x``."""
        return torch.exp(- x * x)
class SpeciesConverter(torch.nn.Module):
    """Converts tensors with species labeled as atomic numbers into tensors
    labeled with internal torchani indices according to a custom ordering
    scheme. It takes a custom species ordering as initialization parameter. If
    the class is initialized with ['H', 'C', 'N', 'O'] for example, it will
    convert a tensor [1, 1, 6, 7, 1, 8] into a tensor [0, 0, 1, 2, 0, 3]

    Arguments:
        species (:class:`collections.abc.Sequence` of :class:`str`):
            sequence of all supported species, in order (it is recommended to
            order according to atomic number).
    """
    conv_tensor: Tensor

    def __init__(self, species):
        super().__init__()
        # Map each symbol to its position in utils.PERIODIC_TABLE.
        rev_idx = {s: k for k, s in enumerate(utils.PERIODIC_TABLE)}
        maxidx = max(rev_idx.values())
        # One slot more than the table needs: the last entry is never assigned,
        # so indexing with -1 (the last element) yields -1.
        self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
        for i, s in enumerate(species):
            self.conv_tensor[rev_idx[s]] = i

    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        """Convert species from periodic table element index to 0, 1, 2, 3, ... indexing

        Arguments:
            input_: ``(species, coordinates)`` tuple; entries of ``species``
                equal to ``-1`` are exempt from the unknown-species check and
                map to ``-1``.
            cell: accepted for signature compatibility; not used here.
            pbc: accepted for signature compatibility; not used here.

        Returns:
            ``SpeciesCoordinates(converted_species, coordinates)`` with the
            coordinates passed through unchanged.

        Raises:
            ValueError: if any entry other than ``-1`` has no internal index.
        """
        species, coordinates = input_
        converted_species = self.conv_tensor[species]

        # check if unknown species are included
        if converted_species[species.ne(-1)].lt(0).any():
            raise ValueError(f'Unknown species found in {species}')

        return SpeciesCoordinates(converted_species.to(species.device), coordinates)