Using TorchScript to serialize models#
All built-in models and modules in the torchani library support TorchScript serialization, a native PyTorch feature that translates a Python model into a PyTorch-specific intermediate representation. A TorchScript-serialized model can be loaded and executed in a process that has no Python dependency, for example from C++ via libtorch.
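Before turning to ANI models, here is a minimal sketch of the general TorchScript round trip on a toy module (the Doubler module and file name are purely illustrative, not part of torchani):

import torch

class Doubler(torch.nn.Module):
    # A trivial module used only to illustrate the script/save/load cycle.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x

scripted = torch.jit.script(Doubler())   # translate the Python model to TorchScript
torch.jit.save(scripted, "doubler.pt")   # serialize to a self-contained file
reloaded = torch.jit.load("doubler.pt")  # loadable even from C++ via torch::jit::load
assert torch.equal(reloaded(torch.ones(3)), scripted(torch.ones(3)))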
Scripting an ANI model directly#
Let’s now load the built-in ANI-2x model. ANI-2x is an ensemble of 8 networks trained with different initializations and on different splits of the dataset.
from pathlib import Path
import typing as tp

import torch
from torch import Tensor

import torchani
# Gradient helpers used later in this example; the exact module path is
# assumed here and may differ between torchani versions.
from torchani.grad import forces, hessians

model = torchani.models.ANI2x()
Compiling and saving the model with torch.jit is straightforward:
compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, "compiled_model.pt")
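If you are curious, you can inspect the TorchScript generated for the model’s forward method:

# Print the code that TorchScript compiled for forward()
print(compiled_model.code)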
For testing purposes, we will now load the model we just saved and check that it produces the same output as the original model:
loaded_compiled_model = torch.jit.load("compiled_model.pt")
We use the methane molecule below to test:
coordinates = torch.tensor(
    [
        [
            [0.03192167, 0.00638559, 0.01301679],
            [-0.83140486, 0.39370209, -0.26395324],
            [-0.66518241, -0.84461308, 0.20759389],
            [0.45554739, 0.54289633, 0.81170881],
            [0.66091919, -0.16799635, -0.91037834],
        ]
    ]
)
# Atomic numbers from the periodic table: C = 6, H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])
And here is the result:
energies_ensemble = model((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
print(
    "Ensemble energy, eager mode vs loaded jit:",
    energies_ensemble.item(),
    energies_ensemble_jit.item(),
)
Ensemble energy, eager mode vs loaded jit: -40.45979309082031 -40.45979309082031
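Rather than comparing by eye, you can also assert that the two results agree to within floating point tolerance:

# Raises an error if the eager and TorchScript energies diverge
torch.testing.assert_close(energies_ensemble, energies_ensemble_jit)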
Customize the model and script#
You can also customize the model before exporting it. For example, let’s apply the following customizations:

- use double (float64) as the dtype instead of float
- ignore periodic boundary conditions
- in addition to energies, optionally return forces and hessians

You could do that as follows:
class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Use float64 weights for higher-precision energies and derivatives
        self.model = torchani.models.ANI1x().double()

    def forward(
        self,
        species: Tensor,
        coordinates: Tensor,
        return_forces: bool = False,
        return_hessians: bool = False,
    ) -> tp.Tuple[Tensor, tp.Optional[Tensor], tp.Optional[Tensor]]:
        if return_forces or return_hessians:
            # Gradients w.r.t. coordinates are needed for forces and hessians
            coordinates.requires_grad_(True)
        energies = self.model((species, coordinates)).energies
        _forces: tp.Optional[Tensor] = None
        _hessians: tp.Optional[Tensor] = None
        if return_forces or return_hessians:
            _forces = forces(
                energies, coordinates, retain_graph=True, create_graph=return_hessians
            )
        if return_hessians:
            assert _forces is not None
            _hessians = hessians(_forces, coordinates)
        return energies, _forces, _hessians


custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, "compiled_custom_model.pt")
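For context, the force computation inside CustomModule is ordinary reverse-mode autodiff on the energies. A minimal sketch of what a helper like forces might do with plain torch.autograd.grad (a hypothetical stand-in, not the actual torchani implementation):

def forces_from_energies(
    energies: Tensor,
    coordinates: Tensor,
    retain_graph: bool = False,
    create_graph: bool = False,
) -> Tensor:
    # Forces are the negative gradient of the energy w.r.t. coordinates.
    grad = torch.autograd.grad(
        energies.sum(),
        coordinates,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )[0]
    return -grad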
loaded_compiled_custom_model = torch.jit.load("compiled_custom_model.pt")
energies_eager, forces_eager, hessians_eager = custom_model(
    species, coordinates, True, True
)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(
    species, coordinates, True, True
)
print("Energy, eager mode vs loaded jit:", energies_eager.item(), energies_jit.item())
print()
print(
    "Force, eager mode vs loaded jit:\n",
    forces_eager.squeeze(0),
    "\n",
    forces_jit.squeeze(0),
)
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print(
    "Hessian, eager mode vs loaded jit:\n",
    hessians_eager.squeeze(0),
    "\n",
    hessians_jit.squeeze(0),
)
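Because forces and hessians are optional, the same module also covers the cheap energies-only path, in which no autograd graph is built:

# With the default flags both optional outputs are None
energies_only, maybe_forces, maybe_hessians = loaded_compiled_custom_model(
    species, coordinates
)
assert maybe_forces is None and maybe_hessians is None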
# Let's delete the files we created for cleanup
Path("compiled_custom_model.pt").unlink()
Path("compiled_model.pt").unlink()
Energy, eager mode vs loaded jit: -40.45902126308002 -40.45902126308002
Force, eager mode vs loaded jit:
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
tensor([[ 0.031, -0.132, -0.053],
[-0.129, 0.164, -0.077],
[ 0.086, -0.043, 0.041],
[ 0.027, 0.006, 0.038],
[-0.014, 0.005, 0.051]], grad_fn=<SqueezeBackward1>)
Hessian, eager mode vs loaded jit:
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])
tensor([[ 3.149, -0.035, 0.557, -1.715, 0.685, -0.522, -0.588, -0.382, 0.082, -0.398, -0.305, -0.409, -0.448, 0.037, 0.292],
[ -0.035, 1.765, 0.287, 0.796, -0.430, 0.253, -0.494, -0.686, 0.095, -0.292, -0.459, -0.544, 0.025, -0.189, -0.092],
[ 0.557, 0.287, 2.097, -0.556, 0.210, -0.221, 0.104, 0.128, -0.289, -0.380, -0.525, -0.935, 0.275, -0.100, -0.652],
[ -1.715, 0.796, -0.556, 1.756, -0.815, 0.584, 0.065, -0.049, 0.014, -0.048, 0.030, -0.022, -0.058, 0.037, -0.020],
[ 0.685, -0.430, 0.210, -0.815, 0.624, -0.347, 0.191, -0.208, 0.146, -0.071, 0.018, -0.023, 0.009, -0.003, 0.015],
[ -0.522, 0.253, -0.221, 0.584, -0.347, 0.199, -0.065, 0.097, 0.032, -0.084, 0.040, -0.039, 0.087, -0.044, 0.029],
[ -0.588, -0.494, 0.104, 0.065, 0.191, -0.065, 0.555, 0.397, -0.084, -0.012, -0.034, 0.019, -0.020, -0.060, 0.027],
[ -0.382, -0.686, 0.128, -0.049, -0.208, 0.097, 0.397, 0.893, -0.234, 0.013, -0.012, 0.020, 0.021, 0.014, -0.011],
[ 0.082, 0.095, -0.289, 0.014, 0.146, 0.032, -0.084, -0.234, 0.223, -0.043, -0.070, 0.035, 0.031, 0.064, -0.001],
[ -0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012, 0.013, -0.043, 0.408, 0.316, 0.460, 0.051, 0.034, 0.048],
[ -0.305, -0.459, -0.525, 0.030, 0.018, 0.040, -0.034, -0.012, -0.070, 0.316, 0.454, 0.584, -0.007, -0.001, -0.030],
[ -0.409, -0.544, -0.935, -0.022, -0.023, -0.039, 0.019, 0.020, 0.035, 0.460, 0.584, 1.004, -0.048, -0.037, -0.065],
[ -0.448, 0.025, 0.275, -0.058, 0.009, 0.087, -0.020, 0.021, 0.031, 0.051, -0.007, -0.048, 0.475, -0.049, -0.346],
[ 0.037, -0.189, -0.100, 0.037, -0.003, -0.044, -0.060, 0.014, 0.064, 0.034, -0.001, -0.037, -0.049, 0.180, 0.117],
[ 0.292, -0.092, -0.652, -0.020, 0.015, 0.029, 0.027, -0.011, -0.001, 0.048, -0.030, -0.065, -0.346, 0.117, 0.689]])