Using TorchScript to serialize models

All built-in models and modules in the torchani library support TorchScript serialization, a native PyTorch feature in which a Python model is translated into a PyTorch-specific format. If you use TorchScript, you can load the resulting serialized files in a process that has no Python dependency.

# To begin with, let's first import the modules we will use:
from pathlib import Path
import typing as tp

import torch
from torch import Tensor

import torchani
from torchani.grad import hessians, forces

Scripting an ANI model directly

Let’s now load the built-in ANI-2x model. The ANI-2x model is an ensemble of 8 models trained with different initializations and on different splits of a dataset.

# Instantiate the built-in ANI-2x model in ordinary eager (Python) mode; this
# object is what gets compiled with torch.jit.script below.
model = torchani.models.ANI2x()

It is very easy to compile and save the model using torch.jit.

# Compile the eager model to TorchScript, then serialize the resulting
# ScriptModule to a file (torch.jit.save delegates to this same method for paths).
compiled_model = torch.jit.script(model)
compiled_model.save("compiled_model.pt")

For testing purposes, we will now load the model we just saved and check that it produces the same output as the original model:

# Read the serialized TorchScript archive written above back into memory.
loaded_compiled_model = torch.jit.load("compiled_model.pt")

We use the molecule below to test:

# Test molecule (CH4), batched with a leading dimension of size 1.
# Species are atomic numbers: one carbon (C = 6) followed by four hydrogens (H = 1).
species = torch.tensor([[6] + [1] * 4])
# One xyz row per atom, in the same order as `species`.
coordinates = torch.tensor([[
    [0.03192167, 0.00638559, 0.01301679],  # C
    [-0.83140486, 0.39370209, -0.26395324],  # H
    [-0.66518241, -0.84461308, 0.20759389],  # H
    [0.45554739, 0.54289633, 0.81170881],  # H
    [0.66091919, -0.16799635, -0.91037834],  # H
]])

And here is the result:

# Evaluate the eager model first, then the reloaded TorchScript model, on the
# same (species, coordinates) input.
energies_ensemble, energies_ensemble_jit = (
    net((species, coordinates)).energies for net in (model, loaded_compiled_model)
)
# Print both scalar energies side by side so they can be compared.
print(
    "Ensemble energy, eager mode vs loaded jit:",
    *(energy.item() for energy in (energies_ensemble, energies_ensemble_jit)),
)
Ensemble energy, eager mode vs loaded jit: -40.45979309082031 -40.45979309082031

Customize the model and script

You could also customize the model you want to export. For example, let’s apply the following customizations to the model:

  • use double instead of float as the dtype

  • ignore periodic boundary conditions

  • in addition to energies, optionally return forces and Hessians

To do so, you could write the following (note that this example wraps the ANI-1x model rather than ANI-2x):

class CustomModule(torch.nn.Module):
    """Double-precision ANI-1x wrapper that can also return forces and Hessians.

    ``forward`` takes species and coordinates as separate tensors and has no
    cell / PBC arguments, so periodic boundary conditions are not handled.
    The module is written to be compilable with ``torch.jit.script``.
    """

    def __init__(self):
        super().__init__()
        # Built-in ANI-1x model converted to float64 parameters.
        self.model = torchani.models.ANI1x().double()

    def forward(
        self,
        species: Tensor,
        coordinates: Tensor,
        return_forces: bool = False,
        return_hessians: bool = False,
    ) -> tp.Tuple[Tensor, tp.Optional[Tensor], tp.Optional[Tensor]]:
        """Compute energies and, optionally, forces and Hessians.

        Args:
            species: species tensor passed to the wrapped model together with
                ``coordinates`` (atomic numbers in this tutorial).
            coordinates: atomic positions. The caller's tensor is not modified.
                If derivatives are requested and it does not already require
                grad, a detached alias (sharing storage) that requires grad is
                used instead.
            return_forces: also compute forces.
            return_hessians: also compute Hessians. Forces are then computed as
                well, with a differentiable graph so they can be differentiated
                a second time.

        Returns:
            ``(energies, forces, hessians)``. ``forces`` is ``None`` unless
            ``return_forces`` or ``return_hessians`` is set; ``hessians`` is
            ``None`` unless ``return_hessians`` is set.
        """
        need_grad = return_forces or return_hessians
        if need_grad and not coordinates.requires_grad:
            # Calling requires_grad_() on the argument itself would flip the
            # caller's tensor flag in place; work on a detached alias instead.
            coordinates = coordinates.detach().requires_grad_(True)
        energies = self.model((species, coordinates)).energies
        _forces: tp.Optional[Tensor] = None
        _hessians: tp.Optional[Tensor] = None
        if need_grad:
            _forces = forces(
                energies, coordinates, retain_graph=True, create_graph=return_hessians
            )
            if return_hessians:
                # Refines Optional[Tensor] to Tensor for the TorchScript compiler.
                assert _forces is not None
                _hessians = hessians(_forces, coordinates)
        return energies, _forces, _hessians


# Build the customized module, compile it to TorchScript, save it, and reload it.
custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, "compiled_custom_model.pt")
loaded_compiled_custom_model = torch.jit.load("compiled_custom_model.pt")
# Request both forces and Hessians. Passing the boolean flags by name keeps the
# calls self-describing; TorchScript modules accept keyword arguments as well.
energies_eager, forces_eager, hessians_eager = custom_model(
    species, coordinates, return_forces=True, return_hessians=True
)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(
    species, coordinates, return_forces=True, return_hessians=True
)

print("Energy, eager mode vs loaded jit:", energies_eager.item(), energies_jit.item())
print()
print(
    "Force, eager mode vs loaded jit:\n",
    forces_eager.squeeze(0),
    "\n",
    forces_jit.squeeze(0),
)
print()
# Wide, non-scientific formatting so each Hessian row prints on one line.
torch.set_printoptions(sci_mode=False, linewidth=1000)
print(
    "Hessian, eager mode vs loaded jit:\n",
    hessians_eager.squeeze(0),
    "\n",
    hessians_jit.squeeze(0),
)
# Let's delete the files we created, for cleanup
Path("compiled_custom_model.pt").unlink()
Path("compiled_model.pt").unlink()
Energy, eager mode vs loaded jit: -40.45902126308002 -40.45902126308002

Force, eager mode vs loaded jit:
 tensor([[ 0.031, -0.132, -0.053],
        [-0.129,  0.164, -0.077],
        [ 0.086, -0.043,  0.041],
        [ 0.027,  0.006,  0.038],
        [-0.014,  0.005,  0.051]], grad_fn=<SqueezeBackward1>)
 tensor([[ 0.031, -0.132, -0.053],
        [-0.129,  0.164, -0.077],
        [ 0.086, -0.043,  0.041],
        [ 0.027,  0.006,  0.038],
        [-0.014,  0.005,  0.051]], grad_fn=<SqueezeBackward1>)

Hessian, eager mode vs loaded jit:
 tensor([[     3.149,     -0.035,      0.557,     -1.715,      0.685,     -0.522,     -0.588,     -0.382,      0.082,     -0.398,     -0.305,     -0.409,     -0.448,      0.037,      0.292],
        [    -0.035,      1.765,      0.287,      0.796,     -0.430,      0.253,     -0.494,     -0.686,      0.095,     -0.292,     -0.459,     -0.544,      0.025,     -0.189,     -0.092],
        [     0.557,      0.287,      2.097,     -0.556,      0.210,     -0.221,      0.104,      0.128,     -0.289,     -0.380,     -0.525,     -0.935,      0.275,     -0.100,     -0.652],
        [    -1.715,      0.796,     -0.556,      1.756,     -0.815,      0.584,      0.065,     -0.049,      0.014,     -0.048,      0.030,     -0.022,     -0.058,      0.037,     -0.020],
        [     0.685,     -0.430,      0.210,     -0.815,      0.624,     -0.347,      0.191,     -0.208,      0.146,     -0.071,      0.018,     -0.023,      0.009,     -0.003,      0.015],
        [    -0.522,      0.253,     -0.221,      0.584,     -0.347,      0.199,     -0.065,      0.097,      0.032,     -0.084,      0.040,     -0.039,      0.087,     -0.044,      0.029],
        [    -0.588,     -0.494,      0.104,      0.065,      0.191,     -0.065,      0.555,      0.397,     -0.084,     -0.012,     -0.034,      0.019,     -0.020,     -0.060,      0.027],
        [    -0.382,     -0.686,      0.128,     -0.049,     -0.208,      0.097,      0.397,      0.893,     -0.234,      0.013,     -0.012,      0.020,      0.021,      0.014,     -0.011],
        [     0.082,      0.095,     -0.289,      0.014,      0.146,      0.032,     -0.084,     -0.234,      0.223,     -0.043,     -0.070,      0.035,      0.031,      0.064,     -0.001],
        [    -0.398,     -0.292,     -0.380,     -0.048,     -0.071,     -0.084,     -0.012,      0.013,     -0.043,      0.408,      0.316,      0.460,      0.051,      0.034,      0.048],
        [    -0.305,     -0.459,     -0.525,      0.030,      0.018,      0.040,     -0.034,     -0.012,     -0.070,      0.316,      0.454,      0.584,     -0.007,     -0.001,     -0.030],
        [    -0.409,     -0.544,     -0.935,     -0.022,     -0.023,     -0.039,      0.019,      0.020,      0.035,      0.460,      0.584,      1.004,     -0.048,     -0.037,     -0.065],
        [    -0.448,      0.025,      0.275,     -0.058,      0.009,      0.087,     -0.020,      0.021,      0.031,      0.051,     -0.007,     -0.048,      0.475,     -0.049,     -0.346],
        [     0.037,     -0.189,     -0.100,      0.037,     -0.003,     -0.044,     -0.060,      0.014,      0.064,      0.034,     -0.001,     -0.037,     -0.049,      0.180,      0.117],
        [     0.292,     -0.092,     -0.652,     -0.020,      0.015,      0.029,      0.027,     -0.011,     -0.001,      0.048,     -0.030,     -0.065,     -0.346,      0.117,      0.689]])
 tensor([[     3.149,     -0.035,      0.557,     -1.715,      0.685,     -0.522,     -0.588,     -0.382,      0.082,     -0.398,     -0.305,     -0.409,     -0.448,      0.037,      0.292],
        [    -0.035,      1.765,      0.287,      0.796,     -0.430,      0.253,     -0.494,     -0.686,      0.095,     -0.292,     -0.459,     -0.544,      0.025,     -0.189,     -0.092],
        [     0.557,      0.287,      2.097,     -0.556,      0.210,     -0.221,      0.104,      0.128,     -0.289,     -0.380,     -0.525,     -0.935,      0.275,     -0.100,     -0.652],
        [    -1.715,      0.796,     -0.556,      1.756,     -0.815,      0.584,      0.065,     -0.049,      0.014,     -0.048,      0.030,     -0.022,     -0.058,      0.037,     -0.020],
        [     0.685,     -0.430,      0.210,     -0.815,      0.624,     -0.347,      0.191,     -0.208,      0.146,     -0.071,      0.018,     -0.023,      0.009,     -0.003,      0.015],
        [    -0.522,      0.253,     -0.221,      0.584,     -0.347,      0.199,     -0.065,      0.097,      0.032,     -0.084,      0.040,     -0.039,      0.087,     -0.044,      0.029],
        [    -0.588,     -0.494,      0.104,      0.065,      0.191,     -0.065,      0.555,      0.397,     -0.084,     -0.012,     -0.034,      0.019,     -0.020,     -0.060,      0.027],
        [    -0.382,     -0.686,      0.128,     -0.049,     -0.208,      0.097,      0.397,      0.893,     -0.234,      0.013,     -0.012,      0.020,      0.021,      0.014,     -0.011],
        [     0.082,      0.095,     -0.289,      0.014,      0.146,      0.032,     -0.084,     -0.234,      0.223,     -0.043,     -0.070,      0.035,      0.031,      0.064,     -0.001],
        [    -0.398,     -0.292,     -0.380,     -0.048,     -0.071,     -0.084,     -0.012,      0.013,     -0.043,      0.408,      0.316,      0.460,      0.051,      0.034,      0.048],
        [    -0.305,     -0.459,     -0.525,      0.030,      0.018,      0.040,     -0.034,     -0.012,     -0.070,      0.316,      0.454,      0.584,     -0.007,     -0.001,     -0.030],
        [    -0.409,     -0.544,     -0.935,     -0.022,     -0.023,     -0.039,      0.019,      0.020,      0.035,      0.460,      0.584,      1.004,     -0.048,     -0.037,     -0.065],
        [    -0.448,      0.025,      0.275,     -0.058,      0.009,      0.087,     -0.020,      0.021,      0.031,      0.051,     -0.007,     -0.048,      0.475,     -0.049,     -0.346],
        [     0.037,     -0.189,     -0.100,      0.037,     -0.003,     -0.044,     -0.060,      0.014,      0.064,      0.034,     -0.001,     -0.037,     -0.049,      0.180,      0.117],
        [     0.292,     -0.092,     -0.652,     -0.020,      0.015,      0.029,      0.027,     -0.011,     -0.001,      0.048,     -0.030,     -0.065,     -0.346,      0.117,      0.689]])