.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/jit.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_jit.py:


Using TorchScript to serialize and deploy models
================================================

Models in TorchANI's model zoo support TorchScript. TorchScript is a way to
create serializable and optimizable models from PyTorch code. It allows users
to save their models from a Python process and load them in a process that has
no Python dependency, such as a C++ application built on LibTorch.

To begin, let's import the modules we will use:

.. code-block:: default

    import torch
    import torchani
    from typing import Tuple, Optional
    from torch import Tensor

Scripting a built-in model directly
-----------------------------------

Let's now load the built-in ANI-1ccx model. ANI-1ccx is an ensemble of 8
networks, each trained with a different initialization.

.. code-block:: default

    model = torchani.models.ANI1ccx(periodic_table_index=True)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /opt/hostedtoolcache/Python/3.8.16/x64/lib/python3.8/site-packages/torchani-2.2.3-py3.8.egg/torchani/resources/

Compiling and saving the model with ``torch.jit`` is straightforward:

.. code-block:: default

    compiled_model = torch.jit.script(model)
    torch.jit.save(compiled_model, 'compiled_model.pt')

Besides the whole ensemble, you can also compile a single network from it:

.. code-block:: default

    compiled_model0 = torch.jit.script(model[0])
    torch.jit.save(compiled_model0, 'compiled_model0.pt')

For testing purposes, we now load the models we just saved and check whether
they produce the same output as the original model:

.. code-block:: default

    loaded_compiled_model = torch.jit.load('compiled_model.pt')
    loaded_compiled_model0 = torch.jit.load('compiled_model0.pt')

We test them on the methane molecule below:

.. code-block:: default

    coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                                 [-0.83140486, 0.39370209, -0.26395324],
                                 [-0.66518241, -0.84461308, 0.20759389],
                                 [0.45554739, 0.54289633, 0.81170881],
                                 [0.66091919, -0.16799635, -0.91037834]]])
    # With periodic_table_index=True, species are atomic numbers: C = 6 and H = 1
    species = torch.tensor([[6, 1, 1, 1, 1]])

And here are the results:

.. code-block:: default

    energies_ensemble = model((species, coordinates)).energies
    energies_single = model[0]((species, coordinates)).energies
    energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
    energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
    print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
    print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Ensemble energy, eager mode vs loaded jit: -40.42562057899495 -40.42562057899495
    Single network energy, eager mode vs loaded jit: -40.428783473392926 -40.428783473392926

Customize the model and script
------------------------------

You can also customize the model before exporting it.
For example, suppose we want a model that:

- uses ``double`` instead of ``float`` as its dtype,
- ignores periodic boundary conditions,
- optionally returns forces and Hessians in addition to energies, and
- indexes atom species by their atomic number in the periodic table instead
  of 0, 1, 2, 3, ...

We can achieve this by wrapping the model in a custom module. The module below
uses the ANI-1x ensemble; the commented-out lines show how to use a single
network or ANI-1ccx instead:

.. code-block:: default

    class CustomModule(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.model = torchani.models.ANI1x(periodic_table_index=True).double()
            # self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
            # self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()

        def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
                    return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
            if return_forces or return_hessians:
                coordinates.requires_grad_(True)

            energies = self.model((species, coordinates)).energies

            forces: Optional[Tensor] = None  # noqa: E701
            hessians: Optional[Tensor] = None
            if return_forces or return_hessians:
                # create_graph is needed only if we differentiate the forces again
                grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
                assert grad is not None
                forces = -grad
                if return_hessians:
                    hessians = torchani.utils.hessian(coordinates, forces=forces)
            return energies, forces, hessians


    custom_model = CustomModule()
    compiled_custom_model = torch.jit.script(custom_model)
    torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
    loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
    energies, forces, hessians = custom_model(species, coordinates, True, True)
    energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)

    print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
    print()
    print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
    print()
    torch.set_printoptions(sci_mode=False, linewidth=1000)
    print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /opt/hostedtoolcache/Python/3.8.16/x64/lib/python3.8/site-packages/torchani-2.2.3-py3.8.egg/torchani/resources/
    Energy, eager mode vs loaded jit: -40.4590220910793 -40.4590220910793

    Force, eager mode vs loaded jit:
    tensor([[ 0.031, -0.132, -0.053],
            [-0.129,  0.164, -0.077],
            [ 0.086, -0.043,  0.041],
            [ 0.027,  0.006,  0.038],
            [-0.014,  0.005,  0.051]], grad_fn=)
    tensor([[ 0.031, -0.132, -0.053],
            [-0.129,  0.164, -0.077],
            [ 0.086, -0.043,  0.041],
            [ 0.027,  0.006,  0.038],
            [-0.014,  0.005,  0.051]], grad_fn=)

    Hessian, eager mode vs loaded jit:
    tensor([[ 3.149, -0.035,  0.557, -1.715,  0.685, -0.522, -0.588, -0.382,  0.082, -0.398, -0.305, -0.409, -0.448,  0.037,  0.292],
            [-0.035,  1.765,  0.287,  0.796, -0.430,  0.253, -0.494, -0.686,  0.095, -0.292, -0.459, -0.544,  0.025, -0.189, -0.092],
            [ 0.557,  0.287,  2.097, -0.556,  0.210, -0.221,  0.104,  0.128, -0.289, -0.380, -0.525, -0.935,  0.275, -0.100, -0.652],
            [-1.715,  0.796, -0.556,  1.756, -0.815,  0.584,  0.065, -0.049,  0.014, -0.048,  0.030, -0.022, -0.058,  0.037, -0.020],
            [ 0.685, -0.430,  0.210, -0.815,  0.624, -0.347,  0.191, -0.208,  0.146, -0.071,  0.018, -0.023,  0.009, -0.003,  0.015],
            [-0.522,  0.253, -0.221,  0.584, -0.347,  0.199, -0.065,  0.097,  0.032, -0.084,  0.040, -0.039,  0.087, -0.044,  0.029],
            [-0.588, -0.494,  0.104,  0.065,  0.191, -0.065,  0.555,  0.397, -0.084, -0.012, -0.034,  0.019, -0.020, -0.060,  0.027],
            [-0.382, -0.686,  0.128, -0.049, -0.208,  0.097,  0.397,  0.893, -0.234,  0.013, -0.012,  0.020,  0.021,  0.014, -0.011],
            [ 0.082,  0.095, -0.289,  0.014,  0.146,  0.032, -0.084, -0.234,  0.223, -0.043, -0.070,  0.035,  0.031,  0.064, -0.001],
            [-0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012,  0.013, -0.043,  0.408,  0.316,  0.460,  0.051,  0.034,  0.048],
            [-0.305, -0.459, -0.525,  0.030,  0.018,  0.040, -0.034, -0.012, -0.070,  0.316,  0.454,  0.584, -0.007, -0.001, -0.030],
            [-0.409, -0.544, -0.935, -0.022, -0.023, -0.039,  0.019,  0.020,  0.035,  0.460,  0.584,  1.004, -0.048, -0.037, -0.065],
            [-0.448,  0.025,  0.275, -0.058,  0.009,  0.087, -0.020,  0.021,  0.031,  0.051, -0.007, -0.048,  0.475, -0.049, -0.346],
            [ 0.037, -0.189, -0.100,  0.037, -0.003, -0.044, -0.060,  0.014,  0.064,  0.034, -0.001, -0.037, -0.049,  0.180,  0.117],
            [ 0.292, -0.092, -0.652, -0.020,  0.015,  0.029,  0.027, -0.011, -0.001,  0.048, -0.030, -0.065, -0.346,  0.117,  0.689]])
    tensor([[ 3.149, -0.035,  0.557, -1.715,  0.685, -0.522, -0.588, -0.382,  0.082, -0.398, -0.305, -0.409, -0.448,  0.037,  0.292],
            [-0.035,  1.765,  0.287,  0.796, -0.430,  0.253, -0.494, -0.686,  0.095, -0.292, -0.459, -0.544,  0.025, -0.189, -0.092],
            [ 0.557,  0.287,  2.097, -0.556,  0.210, -0.221,  0.104,  0.128, -0.289, -0.380, -0.525, -0.935,  0.275, -0.100, -0.652],
            [-1.715,  0.796, -0.556,  1.756, -0.815,  0.584,  0.065, -0.049,  0.014, -0.048,  0.030, -0.022, -0.058,  0.037, -0.020],
            [ 0.685, -0.430,  0.210, -0.815,  0.624, -0.347,  0.191, -0.208,  0.146, -0.071,  0.018, -0.023,  0.009, -0.003,  0.015],
            [-0.522,  0.253, -0.221,  0.584, -0.347,  0.199, -0.065,  0.097,  0.032, -0.084,  0.040, -0.039,  0.087, -0.044,  0.029],
            [-0.588, -0.494,  0.104,  0.065,  0.191, -0.065,  0.555,  0.397, -0.084, -0.012, -0.034,  0.019, -0.020, -0.060,  0.027],
            [-0.382, -0.686,  0.128, -0.049, -0.208,  0.097,  0.397,  0.893, -0.234,  0.013, -0.012,  0.020,  0.021,  0.014, -0.011],
            [ 0.082,  0.095, -0.289,  0.014,  0.146,  0.032, -0.084, -0.234,  0.223, -0.043, -0.070,  0.035,  0.031,  0.064, -0.001],
            [-0.398, -0.292, -0.380, -0.048, -0.071, -0.084, -0.012,  0.013, -0.043,  0.408,  0.316,  0.460,  0.051,  0.034,  0.048],
            [-0.305, -0.459, -0.525,  0.030,  0.018,  0.040, -0.034, -0.012, -0.070,  0.316,  0.454,  0.584, -0.007, -0.001, -0.030],
            [-0.409, -0.544, -0.935, -0.022, -0.023, -0.039,  0.019,  0.020,  0.035,  0.460,  0.584,  1.004, -0.048, -0.037, -0.065],
            [-0.448,  0.025,  0.275, -0.058,  0.009,  0.087, -0.020,  0.021,  0.031,  0.051, -0.007, -0.048,  0.475, -0.049, -0.346],
            [ 0.037, -0.189, -0.100,  0.037, -0.003, -0.044, -0.060,  0.014,  0.064,  0.034, -0.001, -0.037, -0.049,  0.180,  0.117],
            [ 0.292, -0.092, -0.652, -0.020,  0.015,  0.029,  0.027, -0.011, -0.001,  0.048, -0.030, -0.065, -0.346,  0.117,  0.689]])
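
The eager-versus-jit comparisons in this example are checked by reading the printed numbers side by side. A minimal sketch of an automated check is shown below, assuming only PyTorch. ``ToyEnergyModule`` is a hypothetical stand-in for an ANI model (a harmonic energy :math:`E = k \sum x^2` with ``k = 0.5``, so the exact forces are simply ``-x``), and ``roundtrip`` is a hypothetical helper that scripts a module, saves it to an in-memory ``io.BytesIO`` buffer instead of a file, and loads it back. Outputs are then compared with ``torch.allclose`` rather than by eye.

.. code-block:: default

    import io
    from typing import Optional, Tuple

    import torch
    from torch import Tensor


    class ToyEnergyModule(torch.nn.Module):
        """Hypothetical stand-in for an ANI model: E = k * sum(x ** 2) per molecule."""

        def __init__(self, k: float = 0.5):
            super().__init__()
            self.k = k

        def forward(self, coordinates: Tensor,
                    return_forces: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
            if return_forces:
                coordinates.requires_grad_(True)
            # Sum over xyz components and atoms, keeping one energy per molecule.
            energies = (coordinates ** 2).sum(dim=-1).sum(dim=-1) * self.k
            forces: Optional[Tensor] = None
            if return_forces:
                grad = torch.autograd.grad([energies.sum()], [coordinates])[0]
                assert grad is not None  # refines Optional[Tensor] for TorchScript
                forces = -grad
            return energies, forces


    def roundtrip(module: torch.nn.Module) -> torch.jit.ScriptModule:
        """Hypothetical helper: script a module, save it to memory, load it back."""
        buffer = io.BytesIO()
        torch.jit.save(torch.jit.script(module), buffer)
        buffer.seek(0)  # rewind so torch.jit.load reads from the start
        return torch.jit.load(buffer)


    eager = ToyEnergyModule()
    loaded = roundtrip(eager)

    coords = torch.tensor([[[0.1, 0.2, 0.3], [-0.4, 0.5, -0.6]]], dtype=torch.float64)
    # Pass fresh copies: forward() calls requires_grad_ on its input in place.
    e_eager, f_eager = eager(coords.clone(), True)
    e_jit, f_jit = loaded(coords.clone(), True)
    assert f_eager is not None and f_jit is not None
    print(torch.allclose(e_eager, e_jit), torch.allclose(f_eager, f_jit))  # True True

For ``CustomModule``, the same idea would amount to calls such as ``torch.allclose(forces, forces_jit)`` and ``torch.allclose(hessians, hessians_jit)``, which give a pass/fail answer instead of two 15 x 15 matrices to compare visually.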