Using TorchScript to serialize and deploy models

Models in TorchANI’s model zoo support TorchScript. TorchScript is a way to create serializable and optimizable models from PyTorch code. It allows users to save their models from a Python process and load them in a process that has no Python dependency.

To begin with, let’s first import the modules we will use:

import torch
import torchani
from typing import Tuple, Optional
from torch import Tensor

Scripting a built-in model directly

Let’s now load the built-in ANI-1ccx model. It contains 8 models trained with different initializations.

model = torchani.models.ANI1ccx(periodic_table_index=True)

It is very easy to compile and save the model using torch.jit.

compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, 'compiled_model.pt')
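
To get a feel for what torch.jit.script produces without downloading any ANI parameters, the following minimal sketch scripts a tiny stand-in module and prints the TorchScript source recorded for its forward method. The class name `Scale` is hypothetical and not part of TorchANI; it is only an illustration of the scripting step.

```python
import torch
from torch import Tensor


class Scale(torch.nn.Module):
    """Hypothetical stand-in for a real model: multiplies its input by a learned factor."""

    def __init__(self):
        super().__init__()
        self.factor = torch.nn.Parameter(torch.tensor(3.0))

    def forward(self, x: Tensor) -> Tensor:
        return x * self.factor


scripted = torch.jit.script(Scale())
# The `code` attribute holds the TorchScript source generated for forward()
print(scripted.code)
# The scripted module is called exactly like the eager one
result = scripted(torch.tensor([1.0, 2.0]))
print(result)
```

Inspecting `code` this way can help when scripting fails or behaves unexpectedly, because it shows what the compiler actually captured.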

Besides compiling the ensemble, it is also possible to compile a single network:

compiled_model0 = torch.jit.script(model[0])
torch.jit.save(compiled_model0, 'compiled_model0.pt')

For testing purposes, we will now load the models we just saved and check that they produce the same output as the original model:

loaded_compiled_model = torch.jit.load('compiled_model.pt')
loaded_compiled_model0 = torch.jit.load('compiled_model0.pt')

We use the molecule below to test:

coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                             [-0.83140486, 0.39370209, -0.26395324],
                             [-0.66518241, -0.84461308, 0.20759389],
                             [0.45554739, 0.54289633, 0.81170881],
                             [0.66091919, -0.16799635, -0.91037834]]])
# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])

And here is the result:

energies_ensemble = model((species, coordinates)).energies
energies_single = model[0]((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
Ensemble energy, eager mode vs loaded jit: -40.42562057899495 -40.42562057899495
Single network energy, eager mode vs loaded jit: -40.428783473392926 -40.428783473392926
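
Comparing printed numbers by eye works for a single molecule, but a programmatic check is easier to automate. The sketch below shows one way to do this with torch.allclose. It assumes a small toy network (`toy`, a hypothetical stand-in rather than an ANI model) and round-trips it through an in-memory buffer instead of a file, so that it runs without TorchANI installed.

```python
import io

import torch

torch.manual_seed(0)
# Hypothetical toy network in double precision, standing in for a real model
toy = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.CELU(),
    torch.nn.Linear(8, 1),
).double()

# Script, save to an in-memory buffer, and load back
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(toy), buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.rand(5, 3, dtype=torch.float64)
eager_out = toy(x)
loaded_out = loaded(x)

# Largest absolute deviation between eager and loaded outputs
max_diff = (eager_out - loaded_out).abs().max().item()
matches = torch.allclose(eager_out, loaded_out)
print('outputs match:', matches, 'max abs difference:', max_diff)
```

The same comparison could be applied to the energies computed above by passing the eager and loaded tensors to torch.allclose.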

Customize the model and script

You can also customize the model you want to export. For example, suppose we want the following customizations:

  • use double as the dtype instead of float
  • ignore periodic boundary conditions
  • in addition to energies, optionally return forces and Hessians
  • when indexing atom species, use the index in the periodic table instead of 0, 1, 2, 3, …

This can be done as follows:

class CustomModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torchani.models.ANI1x(periodic_table_index=True).double()
        # self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
        # self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()

    def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
                return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        # Derivatives with respect to the coordinates require gradient tracking
        if return_forces or return_hessians:
            coordinates.requires_grad_(True)

        energies = self.model((species, coordinates)).energies

        forces: Optional[Tensor] = None  # noqa: E701
        hessians: Optional[Tensor] = None
        if return_forces or return_hessians:
            # Keep the autograd graph only when second derivatives are needed
            grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
            assert grad is not None
            forces = -grad
            if return_hessians:
                hessians = torchani.utils.hessian(coordinates, forces=forces)
        return energies, forces, hessians


custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)

print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
Energy, eager mode vs loaded jit: -40.4590220910793 -40.4590220910793

Force, eager mode vs loaded jit:
 tensor([[ 0.031, -0.132, -0.053],
        [-0.129,  0.164, -0.077],
        [ 0.086, -0.043,  0.041],
        [ 0.027,  0.006,  0.038],
        [-0.014,  0.005,  0.051]], grad_fn=<SqueezeBackward1>)
 tensor([[ 0.031, -0.132, -0.053],
        [-0.129,  0.164, -0.077],
        [ 0.086, -0.043,  0.041],
        [ 0.027,  0.006,  0.038],
        [-0.014,  0.005,  0.051]], grad_fn=<SqueezeBackward1>)

Hessian, eager mode vs loaded jit:
 tensor([[     3.149,     -0.035,      0.557,     -1.715,      0.685,     -0.522,     -0.588,     -0.382,      0.082,     -0.398,     -0.305,     -0.409,     -0.448,      0.037,      0.292],
        [    -0.035,      1.765,      0.287,      0.796,     -0.430,      0.253,     -0.494,     -0.686,      0.095,     -0.292,     -0.459,     -0.544,      0.025,     -0.189,     -0.092],
        [     0.557,      0.287,      2.097,     -0.556,      0.210,     -0.221,      0.104,      0.128,     -0.289,     -0.380,     -0.525,     -0.935,      0.275,     -0.100,     -0.652],
        [    -1.715,      0.796,     -0.556,      1.756,     -0.815,      0.584,      0.065,     -0.049,      0.014,     -0.048,      0.030,     -0.022,     -0.058,      0.037,     -0.020],
        [     0.685,     -0.430,      0.210,     -0.815,      0.624,     -0.347,      0.191,     -0.208,      0.146,     -0.071,      0.018,     -0.023,      0.009,     -0.003,      0.015],
        [    -0.522,      0.253,     -0.221,      0.584,     -0.347,      0.199,     -0.065,      0.097,      0.032,     -0.084,      0.040,     -0.039,      0.087,     -0.044,      0.029],
        [    -0.588,     -0.494,      0.104,      0.065,      0.191,     -0.065,      0.555,      0.397,     -0.084,     -0.012,     -0.034,      0.019,     -0.020,     -0.060,      0.027],
        [    -0.382,     -0.686,      0.128,     -0.049,     -0.208,      0.097,      0.397,      0.893,     -0.234,      0.013,     -0.012,      0.020,      0.021,      0.014,     -0.011],
        [     0.082,      0.095,     -0.289,      0.014,      0.146,      0.032,     -0.084,     -0.234,      0.223,     -0.043,     -0.070,      0.035,      0.031,      0.064,     -0.001],
        [    -0.398,     -0.292,     -0.380,     -0.048,     -0.071,     -0.084,     -0.012,      0.013,     -0.043,      0.408,      0.316,      0.460,      0.051,      0.034,      0.048],
        [    -0.305,     -0.459,     -0.525,      0.030,      0.018,      0.040,     -0.034,     -0.012,     -0.070,      0.316,      0.454,      0.584,     -0.007,     -0.001,     -0.030],
        [    -0.409,     -0.544,     -0.935,     -0.022,     -0.023,     -0.039,      0.019,      0.020,      0.035,      0.460,      0.584,      1.004,     -0.048,     -0.037,     -0.065],
        [    -0.448,      0.025,      0.275,     -0.058,      0.009,      0.087,     -0.020,      0.021,      0.031,      0.051,     -0.007,     -0.048,      0.475,     -0.049,     -0.346],
        [     0.037,     -0.189,     -0.100,      0.037,     -0.003,     -0.044,     -0.060,      0.014,      0.064,      0.034,     -0.001,     -0.037,     -0.049,      0.180,      0.117],
        [     0.292,     -0.092,     -0.652,     -0.020,      0.015,      0.029,      0.027,     -0.011,     -0.001,      0.048,     -0.030,     -0.065,     -0.346,      0.117,      0.689]])
 tensor([[     3.149,     -0.035,      0.557,     -1.715,      0.685,     -0.522,     -0.588,     -0.382,      0.082,     -0.398,     -0.305,     -0.409,     -0.448,      0.037,      0.292],
        [    -0.035,      1.765,      0.287,      0.796,     -0.430,      0.253,     -0.494,     -0.686,      0.095,     -0.292,     -0.459,     -0.544,      0.025,     -0.189,     -0.092],
        [     0.557,      0.287,      2.097,     -0.556,      0.210,     -0.221,      0.104,      0.128,     -0.289,     -0.380,     -0.525,     -0.935,      0.275,     -0.100,     -0.652],
        [    -1.715,      0.796,     -0.556,      1.756,     -0.815,      0.584,      0.065,     -0.049,      0.014,     -0.048,      0.030,     -0.022,     -0.058,      0.037,     -0.020],
        [     0.685,     -0.430,      0.210,     -0.815,      0.624,     -0.347,      0.191,     -0.208,      0.146,     -0.071,      0.018,     -0.023,      0.009,     -0.003,      0.015],
        [    -0.522,      0.253,     -0.221,      0.584,     -0.347,      0.199,     -0.065,      0.097,      0.032,     -0.084,      0.040,     -0.039,      0.087,     -0.044,      0.029],
        [    -0.588,     -0.494,      0.104,      0.065,      0.191,     -0.065,      0.555,      0.397,     -0.084,     -0.012,     -0.034,      0.019,     -0.020,     -0.060,      0.027],
        [    -0.382,     -0.686,      0.128,     -0.049,     -0.208,      0.097,      0.397,      0.893,     -0.234,      0.013,     -0.012,      0.020,      0.021,      0.014,     -0.011],
        [     0.082,      0.095,     -0.289,      0.014,      0.146,      0.032,     -0.084,     -0.234,      0.223,     -0.043,     -0.070,      0.035,      0.031,      0.064,     -0.001],
        [    -0.398,     -0.292,     -0.380,     -0.048,     -0.071,     -0.084,     -0.012,      0.013,     -0.043,      0.408,      0.316,      0.460,      0.051,      0.034,      0.048],
        [    -0.305,     -0.459,     -0.525,      0.030,      0.018,      0.040,     -0.034,     -0.012,     -0.070,      0.316,      0.454,      0.584,     -0.007,     -0.001,     -0.030],
        [    -0.409,     -0.544,     -0.935,     -0.022,     -0.023,     -0.039,      0.019,      0.020,      0.035,      0.460,      0.584,      1.004,     -0.048,     -0.037,     -0.065],
        [    -0.448,      0.025,      0.275,     -0.058,      0.009,      0.087,     -0.020,      0.021,      0.031,      0.051,     -0.007,     -0.048,      0.475,     -0.049,     -0.346],
        [     0.037,     -0.189,     -0.100,      0.037,     -0.003,     -0.044,     -0.060,      0.014,      0.064,      0.034,     -0.001,     -0.037,     -0.049,      0.180,      0.117],
        [     0.292,     -0.092,     -0.652,     -0.020,      0.015,      0.029,      0.027,     -0.011,     -0.001,      0.048,     -0.030,     -0.065,     -0.346,      0.117,      0.689]])
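
The optional-output pattern used in the custom module can also be tried in isolation. The sketch below is a minimal illustration, assuming a hypothetical toy model `HarmonicModel` with energy E = 0.5 * k * sum(x^2), so the expected force is simply -k * x. It is not an ANI model and does not use TorchANI; it only mirrors the use of an `Optional[Tensor]` return value and of torch.autograd.grad inside a scripted forward.

```python
import io
from typing import Optional, Tuple

import torch
from torch import Tensor


class HarmonicModel(torch.nn.Module):
    """Hypothetical toy model: E = 0.5 * k * sum(x^2) for each molecule."""

    def __init__(self, k: float = 2.0):
        super().__init__()
        self.k = k

    def forward(self, coordinates: Tensor,
                return_forces: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        if return_forces:
            coordinates.requires_grad_(True)
        # Sum over the atom and xyz dimensions, keeping the molecule dimension
        energies = coordinates.pow(2).sum(dim=[1, 2]) * (0.5 * self.k)
        forces: Optional[Tensor] = None
        if return_forces:
            grad = torch.autograd.grad([energies.sum()], [coordinates])[0]
            # TorchScript types grad as Optional[Tensor]; the assert refines it
            assert grad is not None
            forces = -grad
        return energies, forces


# One molecule with two atoms
coordinates = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]], dtype=torch.float64)

# Script, save to an in-memory buffer, and load back
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(HarmonicModel()), buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

energies_only, no_forces = loaded(coordinates)
energies, forces = loaded(coordinates, True)
print('Energy:', energies.item())
print('Forces:\n', forces)
```

Because the analytic force is known here, such a toy model is a convenient way to check the sign convention and the tuple layout before exporting a real model.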
