# Source code for torchani.aev

import torch

from torch import Tensor
import math
from typing import Tuple, Optional, NamedTuple
import sys
import warnings
import importlib_metadata

# Detect whether this distribution was built with the optional cuaev
# extension by looking for it in the ``Provides`` metadata header.
# ``get_all`` returns ``None`` (not an empty list) when the header is absent,
# so fall back to an empty list instead of raising ``TypeError`` on import.
has_cuaev = 'torchani.cuaev' in (importlib_metadata.metadata(__package__).get_all('Provides') or [])

if has_cuaev:
    # We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
    from . import cuaev  # type: ignore # noqa: F401
else:
    warnings.warn("cuaev not installed")

if sys.version_info[:2] >= (3, 7):
    from torch.jit import Final
else:
    # Older interpreters get a stand-in whose subscription simply hands back
    # the subscripted type, so ``Final[float]`` evaluates to ``float``.
    class FakeFinal:
        def __getitem__(self, x):
            return x
    Final = FakeFinal()


class SpeciesAEV(NamedTuple):
    """Named pair ``(species, aevs)`` returned by ``AEVComputer.forward``."""
    # species tensor exactly as it was passed to the AEV computer
    species: Tensor
    # atomic environment vectors computed for those species/coordinates
    aevs: Tensor


def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
    """Evaluate the cosine cutoff ``0.5 * cos(pi * d / cutoff) + 0.5``.

    The value is 1 at ``d == 0`` and reaches 0 at ``d == cutoff``. No
    clamping is applied: the caller is expected to pass only distances
    that are not larger than ``cutoff``.

    Arguments:
        distances (:class:`torch.Tensor`): distances of any shape.
        cutoff (float): the cutoff radius.

    Returns:
        :class:`torch.Tensor`: tensor with the same shape as ``distances``.
    """
    angular_scale = math.pi / cutoff
    cosine = torch.cos(distances * angular_scale)
    return 0.5 * cosine + 0.5


def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> Tensor:
    """Compute the radial subAEV terms for a set of pair distances

    This corresponds to equation (3) in the `ANI paper`_. Only the
    individual terms are evaluated here; the sum in the equation is left
    to the caller.

    ``distances`` is flattened, so it may have any shape. The output has
    one row per distance and one column per combination of the entries of
    ``EtaR`` and ``ShfR`` (the two constants are expected to already be in
    a shape that broadcasts against each other, e.g. ``(-1, 1)`` and
    ``(1, -1)``).

    .. _ANI paper:
        http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
    """
    # One leading entry per distance, plus two singleton dimensions for the
    # constants to broadcast into.
    flat_distances = distances.view(-1, 1, 1)
    cutoff_factor = cutoff_cosine(flat_distances, Rcr)
    gaussian = torch.exp(-EtaR * (flat_distances - ShfR)**2)
    # The paper's equation has no 0.25 prefactor, but NeuroChem applies one.
    # Consistency with NeuroChem is preferred over the paper here.
    terms = 0.25 * gaussian * cutoff_factor
    # terms is (num_distances, ?, ?) where ? depends on the constants.
    # Collapse everything after the first dimension into one axis
    # (onnx doesn't support negative indices in flatten).
    return terms.flatten(start_dim=1)


def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
                  ShfA: Tensor, vectors12: Tensor) -> Tensor:
    """Compute the angular subAEV terms for pairs of neighbor vectors.

    This corresponds to equation (4) in the `ANI paper`_. Only the
    individual terms are evaluated here; the sum in the equation is left
    to the caller.

    ``vectors12`` must be viewable as ``(2, N, 3)``: for each of the ``N``
    triples it holds the two vectors that share the central atom. The
    output has one row per triple; its columns enumerate the combinations
    of the entries of the four constant tensors (which are expected to be
    pre-shaped so that they broadcast against each other).

    .. _ANI paper:
        http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
    """
    # (2, N, 3) plus four singleton dimensions for the constants.
    pair_vectors = vectors12.view(2, -1, 3, 1, 1, 1, 1)
    # dim=-5 is the xyz axis, so this is the length of every vector.
    pair_distances = pair_vectors.norm(2, dim=-5)

    dot_products = pair_vectors.prod(0).sum(1)
    # Clamp the denominator so coincident atoms do not divide by zero.
    length_products = torch.clamp(pair_distances.prod(0), min=1e-10)
    cos_angles = dot_products / length_products
    # 0.95 is multiplied to the cos values to prevent acos from returning NaN.
    angles = torch.acos(0.95 * cos_angles)

    cutoff_factors = cutoff_cosine(pair_distances, Rca)
    angular_factor = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
    mean_distance = pair_distances.sum(0) / 2
    radial_factor = torch.exp(-EtaA * (mean_distance - ShfA) ** 2)
    terms = 2 * angular_factor * radial_factor * cutoff_factors.prod(0)
    # terms is (N, ?, ?, ?, ?) where ? depends on the constants.
    # Collapse everything after the first dimension into one axis
    # (onnx doesn't support negative indices in flatten).
    return terms.flatten(start_dim=1)


def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
    """Enumerate the integer cell shifts needed to find every neighbor pair
    within ``cutoff`` under periodic boundary conditions

    Arguments:
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
            vectors defining unit cell:
            tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
            if pbc is enabled for that direction.
        cutoff (float): the cutoff inside which atoms are considered pairs

    Returns:
        :class:`torch.Tensor`: long tensor of shifts with one row of three
            integers per shifted cell. The center cell is not included,
            and of every ``s`` / ``-s`` pair only one is listed.
    """
    # The norm of each reciprocal vector is the inverse of the spacing between
    # the corresponding cell planes; cutoff / spacing, rounded up, is the
    # number of images required along that direction.
    reciprocal_cell = cell.inverse().t()
    inverse_spacings = reciprocal_cell.norm(2, -1)
    repeats = torch.ceil(cutoff * inverse_spacings).to(torch.long)
    # Directions without periodicity get no images at all.
    repeats = torch.where(pbc, repeats, repeats.new_zeros(()))

    device = cell.device
    r1 = torch.arange(1, repeats[0].item() + 1, device=device)
    r2 = torch.arange(1, repeats[1].item() + 1, device=device)
    r3 = torch.arange(1, repeats[2].item() + 1, device=device)
    zero = torch.zeros(1, dtype=torch.long, device=device)
    neg_r2 = -r2
    neg_r3 = -r3

    # Half of the shift lattice: the first non-zero component is always
    # positive, which is what leaves out the mirrored cells.
    half_space = [
        torch.cartesian_prod(r1, r2, r3),
        torch.cartesian_prod(r1, r2, zero),
        torch.cartesian_prod(r1, r2, neg_r3),
        torch.cartesian_prod(r1, zero, r3),
        torch.cartesian_prod(r1, zero, zero),
        torch.cartesian_prod(r1, zero, neg_r3),
        torch.cartesian_prod(r1, neg_r2, r3),
        torch.cartesian_prod(r1, neg_r2, zero),
        torch.cartesian_prod(r1, neg_r2, neg_r3),
        torch.cartesian_prod(zero, r2, r3),
        torch.cartesian_prod(zero, r2, zero),
        torch.cartesian_prod(zero, r2, neg_r3),
        torch.cartesian_prod(zero, zero, r3),
    ]
    return torch.cat(half_space)


def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
                   shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
    """Find all atom pairs within ``cutoff``, including periodic images

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
            defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
        cutoff (float): the cutoff inside which atoms are considered pairs

    Returns:
        A tuple ``(atom_index12, shifts)``. ``atom_index12`` has shape
        (2, pairs) and indexes atoms in the flattened
        (molecules * atoms) numbering; ``shifts`` has shape (pairs, 3) and
        holds the integer cell shift that belongs to each pair.
    """
    # Padding atoms become NaN, so every distance that involves one of them
    # fails the ``<= cutoff`` comparison further down.
    coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan)
    cell = cell.detach()
    num_mols = padding_mask.shape[0]
    num_atoms = padding_mask.shape[1]
    device = cell.device
    atom_range = torch.arange(num_atoms, device=device)

    # Center cell: every unordered pair once, with a zero shift.
    # torch.triu_indices is faster than combinations
    center_pairs = torch.triu_indices(num_atoms, num_atoms, 1, device=device)
    center_shifts = shifts.new_zeros((center_pairs.shape[1], 3))

    # Shifted cells: every (shift, atom, atom) combination, self pairs
    # included since an atom can neighbor its own image.
    shift_range = torch.arange(shifts.shape[0], device=device)
    shift_atom_atom = torch.cartesian_prod(shift_range, atom_range, atom_range).t()
    outside_pairs = shift_atom_atom[1:]
    outside_shifts = shifts.index_select(0, shift_atom_atom[0])

    # Candidates from all cells, center cell first.
    candidate_shifts = torch.cat([center_shifts, outside_shifts])
    candidate_pairs = torch.cat([center_pairs, outside_pairs], dim=1)
    displacements = candidate_shifts.to(cell.dtype) @ cell

    # Distances for every candidate in every molecule, then the cutoff filter.
    pair_coordinates = coordinates.index_select(1, candidate_pairs.view(-1)).view(num_mols, 2, -1, 3)
    separations = pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...] + displacements
    distances = separations.norm(2, -1)
    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    # Offset that turns per-molecule atom indices into flattened ones.
    molecule_index *= num_atoms
    atom_index12 = candidate_pairs[:, pair_index]
    kept_shifts = candidate_shifts.index_select(0, pair_index)
    return molecule_index + atom_index12, kept_shifts


def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tensor:
    """Find all atom pairs within ``cutoff`` (no periodic boundary conditions)

    Skipping the shift bookkeeping and the duplication of atoms makes this
    faster than the PBC variant.

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cutoff (float): the cutoff inside which atoms are considered pairs

    Returns:
        :class:`torch.Tensor`: tensor of shape (2, pairs) whose entries
            index atoms in the flattened (molecules * atoms) numbering.
    """
    # Padding atoms become NaN, so every distance that involves one of them
    # fails the ``<= cutoff`` comparison further down.
    coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan)
    num_mols = padding_mask.shape[0]
    num_atoms = padding_mask.shape[1]

    # Every unordered pair (i < j) of atom slots, shared by all molecules.
    candidate_pairs = torch.triu_indices(num_atoms, num_atoms, 1, device=coordinates.device)
    gathered = coordinates.index_select(1, candidate_pairs.view(-1))
    pair_coordinates = gathered.view(num_mols, 2, -1, 3)
    separations = pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]
    distances = separations.norm(2, -1)

    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    # Offset that turns per-molecule atom indices into flattened ones.
    molecule_index *= num_atoms
    return candidate_pairs[:, pair_index] + molecule_index


def triu_index(num_species: int) -> Tensor:
    """Build a symmetric lookup table from a species pair to its pair index

    Entry ``[i, j]`` (and ``[j, i]``) is the position of the unordered pair
    ``(i, j)`` in the row-major enumeration of the upper triangle, diagonal
    included.

    Arguments:
        num_species (int): number of distinct species.

    Returns:
        :class:`torch.Tensor`: long tensor of shape
            (num_species, num_species).
    """
    rows, cols = torch.triu_indices(num_species, num_species).unbind(0)
    labels = torch.arange(rows.shape[0], dtype=torch.long)
    table = torch.zeros(num_species, num_species, dtype=torch.long)
    # Write each label at (i, j) and at its mirror (j, i).
    table[rows, cols] = labels
    table[cols, rows] = labels
    return table


def cumsum_from_zero(input_: Tensor) -> Tensor:
    """Exclusive cumulative sum along dim 0

    Element ``k`` of the result is the sum of ``input_[:k]``, so the first
    element is always zero and the last input element is never added.

    Arguments:
        input_ (:class:`torch.Tensor`): tensor to accumulate.

    Returns:
        :class:`torch.Tensor`: tensor with the shape and dtype of ``input_``.
    """
    shifted_sum = torch.zeros_like(input_)
    # Write the running sum of all but the last element directly behind the
    # leading zero.
    torch.cumsum(input_[:-1], dim=0, out=shifted_sum[1:])
    return shifted_sum


def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Turn a list of neighbor pairs into (central atom, pair, pair) triples

    Input: a (2, pairs) tensor of indices for pairs of atoms that are close
    to each other. Each pair only appears once, i.e. only one of the pairs
    (1, 2) and (2, 1) exists.

    Output: a tuple of three tensors

    * the central atom of every triple, shape (triples,)
    * for both legs of every triple, the column of ``atom_index12`` that
      holds the pair (central atom, neighbor), shape (2, triples)
    * for both legs, ``+1`` if the central atom sits in the first row of
      that column and ``-1`` if it sits in the second, shape (2, triples)

    For example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
    (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
    central atom 0, 1, 2, 3, 4 and for central atom 0, its pairs of
    neighbors are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
    """
    num_pairs = atom_index12.shape[1]

    # Treat every endpoint of every pair as a potential central atom and
    # group equal atoms together by sorting.
    endpoints = atom_index12.view(-1)
    sorted_endpoints, unsort = endpoints.sort()
    device = endpoints.device

    # One entry per distinct atom, with the number of pairs it takes part in.
    centers, counts = torch.unique_consecutive(sorted_endpoints, return_inverse=False, return_counts=True)

    # An atom with c neighbors is the center of c * (c - 1) / 2 triples.
    pair_sizes = torch.div(counts * (counts - 1), 2, rounding_mode='trunc')
    center_of_triple = torch.repeat_interleave(pair_sizes)
    central_atom_index = centers.index_select(0, center_of_triple)

    # Combinations inside each group. The strict lower triangle is enumerated
    # row by row, so its first c * (c - 1) / 2 entries are exactly the pairs
    # among the first c elements; one template sized for the largest group
    # can therefore be truncated per group with a mask.
    max_count = counts.max().item() if counts.numel() > 0 else 0
    num_centers = pair_sizes.shape[0]
    template = torch.tril_indices(max_count, max_count, -1, device=device).unsqueeze(1).expand(-1, num_centers, -1)
    keep = (torch.arange(template.shape[2], device=device) < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = template.flatten(1, 2)[:, keep]
    # Move from within-group positions to positions in the sorted array.
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, center_of_triple)

    # Undo the sort to get positions in the flattened input.
    local_index12 = unsort[sorted_local_index12]

    # Positions below num_pairs come from the first row of atom_index12,
    # the rest from the second row.
    sign12 = ((local_index12 < num_pairs).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % num_pairs, sign12


def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
                constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
                sizes: Tuple[int, int, int, int, int], cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor:
    """Compute the full AEV (radial part followed by angular part) of every atom

    Arguments:
        species (:class:`torch.Tensor`): tensor of shape (molecules, atoms)
            of species indices; entries equal to -1 are treated as padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        triu_index (:class:`torch.Tensor`): lookup table indexed as
            ``triu_index[species1, species2]`` that gives the angular
            sub-AEV slot of a species pair.
        constants (tuple): ``(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA)``
        sizes (tuple): ``(num_species, radial_sublength, radial_length,
            angular_sublength, angular_length)``
        cell_shifts (tuple, optional): ``(cell, shifts)`` to use periodic
            boundary conditions, or ``None`` to skip them.

    Returns:
        :class:`torch.Tensor`: tensor of shape
            (molecules, atoms, radial_length + angular_length)
    """
    Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
    num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes
    num_molecules = species.shape[0]
    num_atoms = species.shape[1]
    num_species_pairs = angular_length // angular_sublength
    # Keep the (molecules, atoms, 3) tensor for the neighbor search and use a
    # flattened (molecules * atoms, 3) copy for lookups by flat atom index.
    coordinates_ = coordinates
    coordinates = coordinates_.flatten(0, 1)

    # PBC calculation is bypassed if there are no shifts
    if cell_shifts is None:
        atom_index12 = neighbor_pairs_nopbc(species == -1, coordinates_, Rcr)
        selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
        vec = selected_coordinates[0] - selected_coordinates[1]
    else:
        cell, shifts = cell_shifts
        atom_index12, shifts = neighbor_pairs(species == -1, coordinates_, cell, shifts, Rcr)
        # Integer cell shifts converted to cartesian displacements.
        shift_values = shifts.to(cell.dtype) @ cell
        selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
        vec = selected_coordinates[0] - selected_coordinates[1] + shift_values

    species = species.flatten()
    species12 = species[atom_index12]

    distances = vec.norm(2, -1)

    # compute radial aev
    radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
    radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
    # Row of the accumulator for each endpoint of a pair: the atom's own flat
    # index selects the block, the species of the *other* atom (hence the
    # flip) selects the slot inside it.
    index12 = atom_index12 * num_species + species12.flip(0)
    # Every pair contributes the same term to both of its atoms.
    radial_aev.index_add_(0, index12[0], radial_terms_)
    radial_aev.index_add_(0, index12[1], radial_terms_)
    radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)

    # Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources
    # Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca
    even_closer_indices = (distances <= Rca).nonzero().flatten()
    atom_index12 = atom_index12.index_select(1, even_closer_indices)
    species12 = species12.index_select(1, even_closer_indices)
    vec = vec.index_select(0, even_closer_indices)

    # compute angular aev
    central_atom_index, pair_index12, sign12 = triple_by_molecule(atom_index12)
    species12_small = species12[:, pair_index12]
    # vec is (first atom - second atom) of a pair; multiplying by the sign
    # flips it whenever the central atom is the second one, so both vectors
    # of a triple are oriented the same way relative to the central atom.
    vec12 = vec.index_select(0, pair_index12.view(-1)).view(2, -1, 3) * sign12.unsqueeze(-1)
    # Species of the non-central atom of each leg: the second atom of the
    # pair when the central atom is the first one, and vice versa.
    species12_ = torch.where(sign12 == 1, species12_small[1], species12_small[0])
    angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec12)
    angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength))
    # The central atom selects the block, the (unordered) species pair of its
    # two neighbors selects the slot inside it.
    index = central_atom_index * num_species_pairs + triu_index[species12_[0], species12_[1]]
    angular_aev.index_add_(0, index, angular_terms_)
    angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
    return torch.cat([radial_aev, angular_aev], dim=-1)


def jit_unused_if_no_cuaev(condition=has_cuaev):
    """Decorator factory that hides a function from TorchScript when
    ``condition`` is false.

    With a true ``condition`` (by default: cuaev is installed) the decorated
    function is returned untouched; otherwise it is wrapped with
    :func:`torch.jit.unused`.
    """
    def decorator(func):
        return func if condition else torch.jit.unused(func)
    return decorator


class AEVComputer(torch.nn.Module):
    r"""The AEV computer that takes coordinates as input and outputs aevs.

    Arguments:
        Rcr (float): :math:`R_C` in equation (2) when used at equation (3)
            in the `ANI paper`_.
        Rca (float): :math:`R_C` in equation (2) when used at equation (4)
            in the `ANI paper`_.
        EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
            equation (3) in the `ANI paper`_.
        ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
            equation (3) in the `ANI paper`_.
        EtaA (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
            equation (4) in the `ANI paper`_.
        Zeta (:class:`torch.Tensor`): The 1D tensor of :math:`\zeta` in
            equation (4) in the `ANI paper`_.
        ShfA (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
            equation (4) in the `ANI paper`_.
        ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
            equation (4) in the `ANI paper`_.
        num_species (int): Number of supported atom types.
        use_cuda_extension (bool): Whether to use cuda extension for faster
            calculation (needs cuaev installed).

    .. _ANI paper:
        http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
    """
    Rcr: Final[float]
    Rca: Final[float]
    num_species: Final[int]

    radial_sublength: Final[int]
    radial_length: Final[int]
    angular_sublength: Final[int]
    angular_length: Final[int]
    aev_length: Final[int]
    sizes: Final[Tuple[int, int, int, int, int]]
    triu_index: Tensor
    use_cuda_extension: Final[bool]

    def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension=False):
        super().__init__()
        self.Rcr = Rcr
        self.Rca = Rca
        assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
        self.num_species = num_species

        # cuda aev
        if use_cuda_extension:
            assert has_cuaev, "AEV cuda extension is not installed"
        self.use_cuda_extension = use_cuda_extension

        # convert constant tensors to a ready-to-broadcast shape
        # shape convention (..., EtaR, ShfR)
        self.register_buffer('EtaR', EtaR.view(-1, 1))
        self.register_buffer('ShfR', ShfR.view(1, -1))
        # shape convention (..., EtaA, Zeta, ShfA, ShfZ)
        self.register_buffer('EtaA', EtaA.view(-1, 1, 1, 1))
        self.register_buffer('Zeta', Zeta.view(1, -1, 1, 1))
        self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
        self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))

        # The length of radial subaev of a single species
        self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
        # The length of full radial aev
        self.radial_length = self.num_species * self.radial_sublength
        # The length of angular subaev of a single species
        self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * self.ShfA.numel() * self.ShfZ.numel()
        # The length of full angular aev
        self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength
        # The length of full aev
        self.aev_length = self.radial_length + self.angular_length
        self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length

        self.register_buffer('triu_index', triu_index(num_species).to(device=self.EtaR.device))

        # Set up default cell and compute default shifts.
        # These values are used when cell and pbc switch are not given.
        cutoff = max(self.Rcr, self.Rca)
        default_cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
        default_pbc = torch.zeros(3, dtype=torch.bool, device=self.EtaR.device)
        default_shifts = compute_shifts(default_cell, default_pbc, cutoff)
        self.register_buffer('default_cell', default_cell)
        self.register_buffer('default_shifts', default_shifts)

        # Should create only when use_cuda_extension is True.
        # However jit needs to know cuaev_computer's Type even when use_cuda_extension is False,
        # because it is enabled when cuaev is available
        if has_cuaev:
            self.init_cuaev_computer()
        # When has_cuaev is true, and use_cuda_extension is false, and user enable use_cuda_extension afterwards,
        # then another init_cuaev_computer will be needed
        self.cuaev_enabled = bool(self.use_cuda_extension)

    @jit_unused_if_no_cuaev()
    def init_cuaev_computer(self):
        """Create the cuaev computer object from the current constants."""
        self.cuaev_computer = torch.classes.cuaev.CuaevComputer(
            self.Rcr, self.Rca, self.EtaR.flatten(), self.ShfR.flatten(), self.EtaA.flatten(),
            self.Zeta.flatten(), self.ShfA.flatten(), self.ShfZ.flatten(), self.num_species)

    @jit_unused_if_no_cuaev()
    def compute_cuaev(self, species, coordinates):
        """Compute AEVs with the cuaev extension (species are passed as int32)."""
        species_int = species.to(torch.int32)
        aev = torch.ops.cuaev.run(coordinates, species_int, self.cuaev_computer)
        return aev

    @classmethod
    def cover_linearly(cls, radial_cutoff: float, angular_cutoff: float,
                       radial_eta: float, angular_eta: float,
                       radial_dist_divisions: int, angular_dist_divisions: int,
                       zeta: float, angle_sections: int, num_species: int,
                       angular_start: float = 0.9, radial_start: float = 0.9):
        r""" Provides a convenient way to linearly fill cutoffs

        This is a user friendly constructor that builds an
        :class:`torchani.AEVComputer` where the subdivisions along the
        distance dimension for the angular and radial sub-AEVs, and the angle
        sections for the angular sub-AEV, are linearly covered with shifts. By
        default the distance shifts start at 0.9 Angstroms.

        To reproduce the ANI-1x AEV's the signature ``(5.2, 3.5, 16.0, 8.0, 16, 4, 32.0, 8, 4)``
        can be used.
        """
        # This is intended to be self documenting code that explains the way
        # the AEV parameters for ANI1x were chosen. This is not necessarily the
        # best or most optimal way but it is a relatively sensible default.
        Rcr = radial_cutoff
        Rca = angular_cutoff
        EtaR = torch.tensor([float(radial_eta)])
        EtaA = torch.tensor([float(angular_eta)])
        Zeta = torch.tensor([float(zeta)])

        # linspace includes the end point; it is dropped so that no shift
        # sits exactly on the cutoff.
        ShfR = torch.linspace(radial_start, radial_cutoff, radial_dist_divisions + 1)[:-1]
        ShfA = torch.linspace(angular_start, angular_cutoff, angular_dist_divisions + 1)[:-1]
        # Angle shifts are offset by half a section width.
        angle_start = math.pi / (2 * angle_sections)

        ShfZ = (torch.linspace(0, math.pi, angle_sections + 1) + angle_start)[:-1]

        return cls(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

    def constants(self):
        """Return the AEV constants in the order ``compute_aev`` unpacks them."""
        return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA

    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesAEV:
        """Compute AEVs

        Arguments:
            input_ (tuple): Can be one of the following two cases:

                If you don't care about periodic boundary conditions at all,
                then input can be a tuple of two tensors: species, coordinates.
                species must have shape ``(N, A)``, coordinates must have shape
                ``(N, A, 3)`` where ``N`` is the number of molecules in a batch,
                and ``A`` is the number of atoms.

                .. warning::

                    The species must be indexed in 0, 1, 2, 3, ..., not the element
                    index in periodic table. Check :class:`torchani.SpeciesConverter`
                    if you want periodic table indexing.

                .. note:: The coordinates, and cell are in Angstrom.

                If you want to apply periodic boundary conditions, then the input
                would be a tuple of two tensors (species, coordinates) and two keyword
                arguments `cell=...` , and `pbc=...` where species and coordinates are
                the same as described above, cell is a tensor of shape (3, 3) of the
                three vectors defining unit cell:

                .. code-block:: python

                    tensor([[x1, y1, z1],
                            [x2, y2, z2],
                            [x3, y3, z3]])

                and pbc is boolean vector of size 3 storing if pbc is enabled
                for that direction.

        Returns:
            NamedTuple: Species and AEVs. species are the species from the input
            unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length)``
        """
        species, coordinates = input_
        assert species.dim() == 2
        assert species.shape == coordinates.shape[:-1]
        assert coordinates.shape[-1] == 3

        if self.use_cuda_extension:
            assert (cell is None and pbc is None), "cuaev currently does not support PBC"
            # if use_cuda_extension is enabled after initialization
            # NOTE(review): cuaev_enabled is not updated after this call, so
            # the computer is rebuilt on every forward in that situation --
            # confirm whether that is intended.
            if not self.cuaev_enabled:
                self.init_cuaev_computer()
            aev = self.compute_cuaev(species, coordinates)
            return SpeciesAEV(species, aev)

        if cell is None and pbc is None:
            aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
        else:
            assert (cell is not None and pbc is not None)
            cutoff = max(self.Rcr, self.Rca)
            shifts = compute_shifts(cell, pbc, cutoff)
            aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, (cell, shifts))

        return SpeciesAEV(species, aev)