Using TorchScript to serialize and deploy models
Models in TorchANI’s model zoo support TorchScript. TorchScript is a way to create serializable and optimizable models from PyTorch code. It allows users to save their models from a Python process and load them in a process where there is no Python dependency.
To begin with, let’s first import the modules we will use:
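The import cell is not reproduced in this section; the following is a minimal set of imports consistent with the code used throughout the example (torch, torchani, and the typing helpers used by the custom module further down):

import torch
import torchani
from typing import Tuple, Optional
from torch import Tensor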
Scripting a builtin model directly
Let’s now load the built-in ANI-1ccx model. The built-in ANI-1ccx is an ensemble of 8 networks trained with different initializations.
model = torchani.models.ANI1ccx(periodic_table_index=True)
/opt/hostedtoolcache/Python/3.8.16/x64/lib/python3.8/site-packages/torchani-2.2.3-py3.8.egg/torchani/resources/
It is very easy to compile and save the model using torch.jit.
compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, 'compiled_model.pt')
Besides compiling the ensemble, it is also possible to compile a single network:
compiled_model0 = torch.jit.script(model[0])
torch.jit.save(compiled_model0, 'compiled_model0.pt')
For testing purposes, we will now load the models we just saved and check whether they produce the same output as the original model:
loaded_compiled_model = torch.jit.load('compiled_model.pt')
loaded_compiled_model0 = torch.jit.load('compiled_model0.pt')
We use the molecule below to test:
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]])
# In the periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])
And here is the result:
energies_ensemble = model((species, coordinates)).energies
energies_single = model[0]((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
Ensemble energy, eager mode vs loaded jit: -40.42562057899495 -40.42562057899495
Single network energy, eager mode vs loaded jit: -40.428783473392926 -40.428783473392926
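Because the saved file contains the compiled model itself, it can also be loaded in a separate process that only has PyTorch installed, without importing torchani. A minimal sketch of such a consumer script, assuming compiled_model.pt from above is available on disk:

# deploy.py -- run in an environment with torch but without torchani
import torch

model = torch.jit.load('compiled_model.pt')
species = torch.tensor([[6, 1, 1, 1, 1]])
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                             [-0.83140486, 0.39370209, -0.26395324],
                             [-0.66518241, -0.84461308, 0.20759389],
                             [0.45554739, 0.54289633, 0.81170881],
                             [0.66091919, -0.16799635, -0.91037834]]])
print(model((species, coordinates)).energies.item())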
Customize the model and script
You can also customize the model you want to export. For example, let’s apply the following customizations to the model:
- use double as the dtype instead of float
- ignore periodic boundary conditions
- in addition to energies, optionally return forces and hessians
- index atom species by their atomic number in the periodic table instead of 0, 1, 2, 3, …
To do so, you could define the following module:
class CustomModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torchani.models.ANI1x(periodic_table_index=True).double()
        # self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
        # self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()

    def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
                return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        if return_forces or return_hessians:
            coordinates.requires_grad_(True)

        energies = self.model((species, coordinates)).energies

        forces: Optional[Tensor] = None
        hessians: Optional[Tensor] = None
        if return_forces or return_hessians:
            grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
            assert grad is not None
            forces = -grad
            if return_hessians:
                hessians = torchani.utils.hessian(coordinates, forces=forces)
        return energies, forces, hessians
custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)
print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
/opt/hostedtoolcache/Python/3.8.16/x64/lib/python3.8/site-packages/torchani-2.2.3-py3.8.egg/torchani/resources/
Energy, eager mode vs loaded jit: -40.4590220910793 -40.4590220910793
Force, eager mode vs loaded jit:
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
Hessian, eager mode vs loaded jit:
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])
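The agreement above can also be checked programmatically, and when forces and hessians are not requested the corresponding outputs are simply None. A minimal sketch, assuming the variables from this example are still in scope:

# Programmatic comparison of eager vs. loaded TorchScript outputs
assert torch.allclose(energies, energies_jit)
assert torch.allclose(forces, forces_jit)
assert torch.allclose(hessians, hessians_jit)

# With the default flags, forces and hessians come back as None
e_only, f_none, h_none = loaded_compiled_custom_model(species, coordinates)
print(f_none is None, h_none is None)  # expected: True True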