Note
Go to the end to download the full example code.
Using TorchScript to serialize models#
All built-in models and modules in the torchani library support TorchScript
serialization, which is a native PyTorch feature where a python model is translated into
a PyTorch-specific format. If you use TorchScript you can load the resulting serialized
files in a process where there is no Python dependency.
Scripting an ANI model directly#
Let’s now load the built-in ANI-2x model. The ANI-2x model contains 8 models trained with different initializations and on different splits of a dataset.
# Instantiate the built-in ANI-2x model (eager / regular Python mode)
model = torchani.models.ANI2x()
It is very easy to compile and save the model using torch.jit.
# Compile the eager model to TorchScript ...
compiled_model = torch.jit.script(model)
# ... and serialize the compiled module to a file in the working directory
torch.jit.save(compiled_model, "compiled_model.pt")
For testing purposes, we will now load the model we just saved and see if it produces the same output as the original model:
# Read the serialized TorchScript module back from disk
loaded_compiled_model = torch.jit.load("compiled_model.pt")
We use the molecule below to test:
# Test geometry: a single CH4 molecule, nested as
# 1 molecule x 5 atoms x 3 coordinate components
coordinates = torch.tensor(
    [
        [
            [0.03192167, 0.00638559, 0.01301679],
            [-0.83140486, 0.39370209, -0.26395324],
            [-0.66518241, -0.84461308, 0.20759389],
            [0.45554739, 0.54289633, 0.81170881],
            [0.66091919, -0.16799635, -0.91037834],
        ]
    ]
)
# Atomic numbers, one per atom in the same order as the rows above:
# carbon (6) first, then four hydrogens (1)
species = torch.tensor([[6, 1, 1, 1, 1]])
And here is the result:
# Evaluate the same (species, coordinates) input with the eager model and
# with the module that was reloaded from the serialized file
energies_ensemble = model((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies

# Both values are printed side by side so they can be compared by eye
print(
    "Ensemble energy, eager mode vs loaded jit:",
    energies_ensemble.item(),
    energies_ensemble_jit.item(),
)
Ensemble energy, eager mode vs loaded jit: -40.45979309082031 -40.45979309082031
Customize the model and script#
You can also customize the model you want to export. For example, let’s make the following customizations to the model:
- use double as the dtype instead of float
- ignore periodic boundary conditions
- in addition to energies, optionally return forces and hessians
To do so, you could do the following:
class CustomModule(torch.nn.Module):
    """Scriptable wrapper around ANI-1x that can also return forces and hessians.

    The wrapped model is converted to double precision, and only species and
    coordinates are forwarded to it (no cell / periodic-boundary inputs).
    """

    def __init__(self):
        super().__init__()
        self.model = torchani.models.ANI1x().double()

    def forward(
        self,
        species: Tensor,
        coordinates: Tensor,
        return_forces: bool = False,
        return_hessians: bool = False,
    ) -> tp.Tuple[Tensor, tp.Optional[Tensor], tp.Optional[Tensor]]:
        """Compute energies and, on request, forces and hessians.

        Args:
            species: Tensor identifying the element of each atom.
            coordinates: Tensor of atomic positions. NOTE: when forces or
                hessians are requested, ``requires_grad`` is switched on for
                this tensor *in place*, so the caller's tensor is affected.
            return_forces: If True, also return the forces.
            return_hessians: If True, also return the hessians. The forces
                are computed (and returned) in this case as well, because the
                hessians are derived from them.

        Returns:
            A tuple ``(energies, forces, hessians)``. ``forces`` is ``None``
            unless at least one of the two flags is set; ``hessians`` is
            ``None`` unless ``return_hessians`` is set.
        """
        if return_forces or return_hessians:
            # Gradients w.r.t. the coordinates are needed below
            coordinates.requires_grad_(True)

        energies = self.model((species, coordinates)).energies

        _forces: tp.Optional[Tensor] = None
        _hessians: tp.Optional[Tensor] = None
        if return_forces or return_hessians:
            # create_graph is only enabled when the forces themselves have to
            # be differentiated again to obtain the hessians
            _forces = forces(
                energies, coordinates, retain_graph=True, create_graph=return_hessians
            )
            if return_hessians:
                # Lets TorchScript refine Optional[Tensor] to Tensor
                assert _forces is not None
                _hessians = hessians(_forces, coordinates)
        return energies, _forces, _hessians
# Build the customized module, script it, and round-trip it through a file
custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, "compiled_custom_model.pt")
loaded_compiled_custom_model = torch.jit.load("compiled_custom_model.pt")

# Request forces and hessians (third and fourth positional arguments) from
# the eager module ...
energies_eager, forces_eager, hessians_eager = custom_model(
    species, coordinates, True, True
)
# ... and from the module reloaded from disk, using identical inputs
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(
    species, coordinates, True, True
)
# Show eager and reloaded-jit results next to each other for comparison
print("Energy, eager mode vs loaded jit:", energies_eager.item(), energies_jit.item())
print()
print(
    "Force, eager mode vs loaded jit:\n",
    forces_eager.squeeze(0),
    "\n",
    forces_jit.squeeze(0),
)
print()
# Disable scientific notation and widen the line limit so that each hessian
# row is printed on a single line
torch.set_printoptions(sci_mode=False, linewidth=1000)
print(
    "Hessian, eager mode vs loaded jit:\n",
    hessians_eager.squeeze(0),
    "\n",
    hessians_jit.squeeze(0),
)
# Clean up: remove the serialized model files written by this example
for _artifact in ("compiled_custom_model.pt", "compiled_model.pt"):
    Path(_artifact).unlink()
Energy, eager mode vs loaded jit: -40.45902126308002 -40.45902126308002
Force, eager mode vs loaded jit:
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
Hessian, eager mode vs loaded jit:
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])