# -*- coding: utf-8 -*-
"""Tools for interfacing with `ASE`_.

.. _ASE:
    https://wiki.fysik.dtu.dk/ase
"""
import torch
from . import utils
import ase.calculators.calculator
import ase.units
class Calculator(ase.calculators.calculator.Calculator):
    """TorchANI calculator for ASE.

    Wraps a model that maps ``(species, coordinates)`` to an object with an
    ``energies`` attribute, and exposes energy, forces and stress through
    the ASE calculator interface.

    Arguments:
        species (:class:`collections.abc.Sequence` of :class:`str`):
            sequence of all supported species, in order.
        model (:class:`torch.nn.Module`): neural network potential model
            that converts coordinates into energies.
        overwrite (bool): After wrapping atoms into central box, whether
            to replace the original positions stored in :class:`ase.Atoms`
            object with the wrapped positions.

    Note:
        The constructor turns off ``requires_grad`` on every parameter of
        ``model`` in place, and takes the calculator's device and dtype
        from the model's first parameter (so the model is expected to have
        at least one parameter).
    """

    implemented_properties = ['energy', 'forces', 'stress', 'free_energy']

    def __init__(self, species, model, overwrite=False):
        super().__init__()
        self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
        self.model = model
        # The model is only used for inference here, so gradients with
        # respect to its parameters are never needed.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.overwrite = overwrite

        # Inputs built in ``calculate`` are placed on the same device and
        # use the same dtype as the model's first parameter.
        a_parameter = next(self.model.parameters())
        self.device = a_parameter.device
        self.dtype = a_parameter.dtype

        # If the model exposes ``periodic_table_index`` we follow it;
        # otherwise we fall back to False and rely on ``species_to_tensor``
        # to translate chemical symbols into the indices the model expects.
        self.periodic_table_index = getattr(model, 'periodic_table_index', False)

    # NOTE: the list default mirrors the original signature; ``properties``
    # is only read (never mutated) inside this method.
    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=ase.calculators.calculator.all_changes):
        """Evaluate the model and fill ``self.results``.

        ``'energy'`` and ``'free_energy'`` are always stored (with the same
        value). ``'forces'`` and ``'stress'`` are stored only when they are
        listed in ``properties``.

        Arguments:
            atoms: passed through to the base-class ``calculate``. When
                ``overwrite`` is set, periodic boundary conditions are
                enabled and ``atoms`` is not ``None``, its positions are
                replaced by the wrapped positions.
            properties: names of the properties to compute.
            system_changes: passed through to the base-class ``calculate``.
        """
        super().calculate(atoms, properties, system_changes)
        want_forces = 'forces' in properties
        want_stress = 'stress' in properties

        cell = torch.tensor(self.atoms.get_cell(complete=True).array,
                            dtype=self.dtype, device=self.device)
        pbc = torch.tensor(self.atoms.get_pbc(), dtype=torch.bool,
                           device=self.device)
        # Periodic handling is used as soon as any direction is periodic.
        pbc_enabled = pbc.any().item()

        if self.periodic_table_index:
            species = torch.tensor(self.atoms.get_atomic_numbers(),
                                   dtype=torch.long, device=self.device)
        else:
            species = self.species_to_tensor(
                self.atoms.get_chemical_symbols()).to(self.device)
        species = species.unsqueeze(0)  # add the leading batch dimension

        coordinates = torch.tensor(self.atoms.get_positions(),
                                   dtype=self.dtype, device=self.device)
        coordinates = coordinates.requires_grad_(want_forces)

        if pbc_enabled:
            coordinates = utils.map2central(cell, coordinates, pbc)
            if self.overwrite and atoms is not None:
                atoms.set_positions(
                    coordinates.detach().cpu().reshape(-1, 3).numpy())

        if want_stress:
            # Route the coordinates (and, below, the cell) through an
            # identity matrix that tracks gradients; the energy derivative
            # with respect to this matrix is what the stress is built from.
            scaling = torch.eye(3, requires_grad=True, dtype=self.dtype,
                                device=self.device)
            coordinates = coordinates @ scaling
        coordinates = coordinates.unsqueeze(0)  # add the batch dimension

        if pbc_enabled:
            if want_stress:
                cell = cell @ scaling
            energy = self.model((species, coordinates), cell=cell, pbc=pbc).energies
        else:
            energy = self.model((species, coordinates)).energies

        # Scale by ase.units.Hartree (presumably Hartree -> ASE energy
        # units). Done out of place so the tensor returned by the model is
        # not modified.
        energy = energy * ase.units.Hartree
        # Extract the Python scalar once; both entries share the value.
        energy_value = energy.item()
        self.results['energy'] = energy_value
        self.results['free_energy'] = energy_value

        if want_forces:
            # Forces are the negative energy gradient. The graph must be
            # kept alive when the stress gradient is still to be taken.
            forces = -torch.autograd.grad(energy.squeeze(), coordinates,
                                          retain_graph=want_stress)[0]
            self.results['forces'] = forces.squeeze(0).to('cpu').numpy()

        if want_stress:
            volume = self.atoms.get_volume()
            stress = torch.autograd.grad(energy.squeeze(), scaling)[0] / volume
            self.results['stress'] = stress.cpu().numpy()