.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/nnp_training.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_nnp_training.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_nnp_training.py:


.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::

    TorchANI provides tools to run the NeuroChem training config file
    `inputtrain.ipt`. See: :ref:`neurochem-training`.

.. warning::

    The training setup used in this file is configured to reproduce the
    original research at `Less is more: Sampling chemical space with active
    learning`_ as closely as possible. That research was done on a different
    platform called NeuroChem, whose default options and technical details
    differ from PyTorch's in many ways. Some decisions made here (such as
    using NeuroChem's initialization instead of PyTorch's default
    initialization) were made not because they give better results, but
    solely to reproduce the original research. This file should not be
    interpreted as a suggestion to readers on how they should set up their
    own models.

.. _`Less is more: Sampling chemical space with active learning`:
    https://aip.scitation.org/doi/full/10.1063/1.5023802

.. GENERATED FROM PYTHON SOURCE LINES 34-35

To begin with, let's first import the modules and set up the device we will use:

.. GENERATED FROM PYTHON SOURCE LINES 35-50

.. code-block:: default


    import torch
    import torchani
    import os
    import math
    import torch.utils.tensorboard
    import tqdm
    import pickle

    # helper function to convert energy unit from Hartree to kcal/mol
    from torchani.units import hartree2kcalmol

    # device to run the training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


.. GENERATED FROM PYTHON SOURCE LINES 51-66

Now let's set up constants and construct an AEV computer. These numbers can be
found in `rHCNO-5.2R_16-3.5A_a4-8.params`_. The atomic self energies given in
`sae_linfit.dat`_ are computed from the ANI-1x dataset. These constants can be
calculated for any given dataset if ``None`` is provided as an argument to the
:class:`EnergyShifter` constructor.

.. note::

    Besides defining these hyperparameters programmatically,
    :mod:`torchani.neurochem` provides tools to read them from file.

.. _rHCNO-5.2R_16-3.5A_a4-8.params:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
.. _sae_linfit.dat:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
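For example, a minimal sketch of that file-based route (assuming the two files
linked above have been downloaded next to this script, and assuming the
:mod:`torchani.neurochem` helpers ``Constants`` and ``load_sae`` are available
in your TorchANI version) could look like this:

.. code-block:: python

    # Hypothetical alternative to the hard-coded hyperparameters below:
    # read the AEV constants and self energies from the NeuroChem resource files.
    from torchani import neurochem

    consts = neurochem.Constants('rHCNO-5.2R_16-3.5A_a4-8.params')   # assumed local path
    aev_computer_from_file = torchani.AEVComputer(**consts)
    energy_shifter_from_file = neurochem.load_sae('sae_linfit.dat')  # assumed local path

The rest of this example keeps the explicit, hard-coded values so that every
hyperparameter is visible in one place.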
.. GENERATED FROM PYTHON SOURCE LINES 66-80

.. code-block:: default


    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=device)
    ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
                         1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
                         3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
                         4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
    Zeta = torch.tensor([3.2000000e+01], device=device)
    ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
                         1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
    EtaA = torch.tensor([8.0000000e+00], device=device)
    ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
    species_order = ['H', 'C', 'N', 'O']
    num_species = len(species_order)
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
    energy_shifter = torchani.utils.EnergyShifter(None)


.. GENERATED FROM PYTHON SOURCE LINES 81-91

Now let's set up the datasets. These paths assume the user runs this script
under the ``examples`` directory of TorchANI's repository. If you download
this script, you should manually set the paths of these files on your system
before the script can run successfully.

Also note that we need to subtract the self energies of all atoms from the
energy of each molecule. This brings the energies into a reasonable range.
The ``species_order`` argument defines how to convert species from a list of
chemical symbols to a tensor of indices, that is, which symbol corresponds to
``0``, which corresponds to ``1``, and so on.
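To make that index convention concrete, here is a small illustration (not part
of the training script itself) of the mapping implied by ``species_order``
defined above:

.. code-block:: python

    # Illustration only: the symbol-to-index mapping implied by species_order.
    # The actual conversion is performed by .species_to_indices(species_order) below.
    species_order = ['H', 'C', 'N', 'O']  # same as above
    symbol_to_index = {symbol: index for index, symbol in enumerate(species_order)}
    print(symbol_to_index)  # {'H': 0, 'C': 1, 'N': 2, 'O': 3}

    # e.g. methane (CH4) becomes the index list [1, 0, 0, 0, 0]
    print([symbol_to_index[s] for s in ['C', 'H', 'H', 'H', 'H']])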
.. GENERATED FROM PYTHON SOURCE LINES 91-126

.. code-block:: default


    try:
        path = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        path = os.getcwd()
    dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
    batch_size = 2560

    pickled_dataset_path = 'dataset.pkl'

    # We pickle the dataset after loading to ensure we use the same validation set
    # each time we restart training, otherwise we risk mixing the validation and
    # training sets on each restart.
    if os.path.isfile(pickled_dataset_path):
        print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
        with open(pickled_dataset_path, 'rb') as f:
            dataset = pickle.load(f)
        training = dataset['training'].collate(batch_size).cache()
        validation = dataset['validation'].collate(batch_size).cache()
        energy_shifter.self_energies = dataset['self_energies'].to(device)
    else:
        print(f'Processing dataset in {dspath}')
        training, validation = torchani.data.load(dspath)\
                                            .subtract_self_energies(energy_shifter, species_order)\
                                            .species_to_indices(species_order)\
                                            .shuffle()\
                                            .split(0.8, None)
        with open(pickled_dataset_path, 'wb') as f:
            pickle.dump({'training': training,
                         'validation': validation,
                         'self_energies': energy_shifter.self_energies.cpu()}, f)
        training = training.collate(batch_size).cache()
        validation = validation.collate(batch_size).cache()
    print('Self atomic energies: ', energy_shifter.self_energies)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Processing dataset in /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
    => loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
    => loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
    Self atomic energies:  tensor([-16.140,  24.082,  -8.092, -44.095], dtype=torch.float64)


.. GENERATED FROM PYTHON SOURCE LINES 127-131

When iterating the dataset, we will get a dict of name->property mapping.

Now let's define atomic neural networks.

.. GENERATED FROM PYTHON SOURCE LINES 131-176

.. code-block:: default


    aev_dim = aev_computer.aev_length

    H_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 160),
        torch.nn.CELU(0.1),
        torch.nn.Linear(160, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    C_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 144),
        torch.nn.CELU(0.1),
        torch.nn.Linear(144, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    N_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    O_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
    print(nn)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ANIModel(
      (0): Sequential(
        (0): Linear(in_features=384, out_features=160, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=160, out_features=128, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=128, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=384, out_features=144, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=144, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
    )
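As a quick sanity check (this snippet is not part of the original example),
plain PyTorch can be used to count the trainable parameters of each element
network and of the combined model:

.. code-block:: python

    # Illustration only: count trainable parameters per element network and in total.
    for symbol, net in zip(species_order, [H_network, C_network, N_network, O_network]):
        print(symbol, sum(p.numel() for p in net.parameters()))
    print('total', sum(p.numel() for p in nn.parameters()))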
.. GENERATED FROM PYTHON SOURCE LINES 177-187

Initialize the weights and biases.

.. note::

    PyTorch's default initialization for the weights and biases in linear
    layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_. We initialize
    the weights similarly, but from the normal distribution instead. The
    biases are initialized to zero.

.. _TORCH.NN.MODULES.LINEAR:
    https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear

.. GENERATED FROM PYTHON SOURCE LINES 187-197

.. code-block:: default


    def init_params(m):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight, a=1.0)
            torch.nn.init.zeros_(m.bias)


    nn.apply(init_params)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ANIModel(
      (0): Sequential(
        (0): Linear(in_features=384, out_features=160, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=160, out_features=128, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=128, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=384, out_features=144, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=144, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
    )


.. GENERATED FROM PYTHON SOURCE LINES 198-199

Let's now create a pipeline of AEV Computer --> Neural Networks.

.. GENERATED FROM PYTHON SOURCE LINES 199-201

.. code-block:: default


    model = torchani.nn.Sequential(aev_computer, nn).to(device)


.. GENERATED FROM PYTHON SOURCE LINES 202-216

Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
decay to update the weights and Stochastic Gradient Descent (SGD) to update
the biases. Moreover, we need to specify different weight decay rates for
different layers.

.. note::

    The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually
    not L2 regularization. The confusion between L2 regularization and weight
    decay is a common mistake in deep learning. See: `Decoupled Weight Decay
    Regularization`_. Also note that in the training of ANI models the weight
    decay is applied only to the weights, not to the biases.

.. _Decoupled Weight Decay Regularization:
    https://arxiv.org/abs/1711.05101
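For reference, the distinction can be written out schematically (following
`Decoupled Weight Decay Regularization`_, simplified here by ignoring the
schedule multiplier): L2 regularization folds the penalty into the gradient
before Adam's adaptive rescaling, whereas decoupled weight decay shrinks the
weights directly. Writing :math:`\mathrm{Adam}(\cdot)` for Adam's rescaled
update direction :math:`\hat{m}/(\sqrt{\hat{v}} + \epsilon)`, :math:`\eta` for
the learning rate and :math:`\lambda` for the decay coefficient:

.. math::

    \theta_{t+1} = \theta_t - \eta\,\mathrm{Adam}\bigl(\nabla L(\theta_t) + \lambda\,\theta_t\bigr)
    \qquad\text{(L2 regularization)}

.. math::

    \theta_{t+1} = \theta_t - \eta\,\mathrm{Adam}\bigl(\nabla L(\theta_t)\bigr) - \eta\,\lambda\,\theta_t
    \qquad\text{(decoupled weight decay)}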
.. GENERATED FROM PYTHON SOURCE LINES 216-263

.. code-block:: default


    AdamW = torch.optim.AdamW([
        # H networks
        {'params': [H_network[0].weight]},
        {'params': [H_network[2].weight], 'weight_decay': 0.00001},
        {'params': [H_network[4].weight], 'weight_decay': 0.000001},
        {'params': [H_network[6].weight]},
        # C networks
        {'params': [C_network[0].weight]},
        {'params': [C_network[2].weight], 'weight_decay': 0.00001},
        {'params': [C_network[4].weight], 'weight_decay': 0.000001},
        {'params': [C_network[6].weight]},
        # N networks
        {'params': [N_network[0].weight]},
        {'params': [N_network[2].weight], 'weight_decay': 0.00001},
        {'params': [N_network[4].weight], 'weight_decay': 0.000001},
        {'params': [N_network[6].weight]},
        # O networks
        {'params': [O_network[0].weight]},
        {'params': [O_network[2].weight], 'weight_decay': 0.00001},
        {'params': [O_network[4].weight], 'weight_decay': 0.000001},
        {'params': [O_network[6].weight]},
    ])

    SGD = torch.optim.SGD([
        # H networks
        {'params': [H_network[0].bias]},
        {'params': [H_network[2].bias]},
        {'params': [H_network[4].bias]},
        {'params': [H_network[6].bias]},
        # C networks
        {'params': [C_network[0].bias]},
        {'params': [C_network[2].bias]},
        {'params': [C_network[4].bias]},
        {'params': [C_network[6].bias]},
        # N networks
        {'params': [N_network[0].bias]},
        {'params': [N_network[2].bias]},
        {'params': [N_network[4].bias]},
        {'params': [N_network[6].bias]},
        # O networks
        {'params': [O_network[0].bias]},
        {'params': [O_network[2].bias]},
        {'params': [O_network[4].bias]},
        {'params': [O_network[6].bias]},
    ], lr=1e-3)


.. GENERATED FROM PYTHON SOURCE LINES 264-265

Set up a learning rate scheduler to do learning rate decay:

.. GENERATED FROM PYTHON SOURCE LINES 265-268

.. code-block:: default


    AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
    SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)


.. GENERATED FROM PYTHON SOURCE LINES 269-275

We train the model by minimizing the MSE loss until the validation RMSE no
longer improves for a certain number of steps, then decay the learning rate
and repeat the same process, stopping once the learning rate becomes smaller
than a threshold.

We first read the checkpoint files to restart training. We use `latest.pt`
to store the current training state.

.. GENERATED FROM PYTHON SOURCE LINES 275-277

.. code-block:: default


    latest_checkpoint = 'latest.pt'


.. GENERATED FROM PYTHON SOURCE LINES 278-279

Resume training from previously saved checkpoints:

.. GENERATED FROM PYTHON SOURCE LINES 279-287

.. code-block:: default


    if os.path.isfile(latest_checkpoint):
        checkpoint = torch.load(latest_checkpoint)
        nn.load_state_dict(checkpoint['nn'])
        AdamW.load_state_dict(checkpoint['AdamW'])
        SGD.load_state_dict(checkpoint['SGD'])
        AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
        SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
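If the checkpoint may have been written on a different device from the one you
resume on (for example, saved on a GPU machine and reloaded on a CPU-only
machine), a small variation of the load above (not part of the original
script) is to pass ``map_location`` so the stored tensors are remapped onto
the current device:

.. code-block:: python

    # Variation on the resume logic above: remap tensors in the checkpoint onto
    # the current device, so a checkpoint saved on GPU can be resumed on CPU.
    if os.path.isfile(latest_checkpoint):
        checkpoint = torch.load(latest_checkpoint, map_location=device)
        nn.load_state_dict(checkpoint['nn'])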
.. GENERATED FROM PYTHON SOURCE LINES 288-290

During training, we need to evaluate on the validation set, and whenever the
validation error improves on the best value seen so far, save the new best
model to a checkpoint.

.. GENERATED FROM PYTHON SOURCE LINES 290-310

.. code-block:: default


    def validate():
        # run validation
        mse_sum = torch.nn.MSELoss(reduction='sum')
        total_mse = 0.0
        count = 0
        model.train(False)
        with torch.no_grad():
            for properties in validation:
                species = properties['species'].to(device)
                coordinates = properties['coordinates'].to(device).float()
                true_energies = properties['energies'].to(device).float()
                _, predicted_energies = model((species, coordinates))
                total_mse += mse_sum(predicted_energies, true_energies).item()
                count += predicted_energies.shape[0]
        model.train(True)
        return hartree2kcalmol(math.sqrt(total_mse / count))


.. GENERATED FROM PYTHON SOURCE LINES 311-312

We will also use TensorBoard to visualize our training process.

.. GENERATED FROM PYTHON SOURCE LINES 312-314

.. code-block:: default


    tensorboard = torch.utils.tensorboard.SummaryWriter()


.. GENERATED FROM PYTHON SOURCE LINES 315-320

Finally, we come to the training loop.

In this tutorial, we set the maximum number of epochs to a very small value,
only to make this demo terminate quickly. For serious training, this should be
set to a much larger value.

.. GENERATED FROM PYTHON SOURCE LINES 320-376

.. code-block:: default


    mse = torch.nn.MSELoss(reduction='none')

    print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
    max_epochs = 10
    early_stopping_learning_rate = 1.0E-5
    best_model_checkpoint = 'best.pt'

    for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
        rmse = validate()
        print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

        learning_rate = AdamW.param_groups[0]['lr']

        if learning_rate < early_stopping_learning_rate:
            break

        # checkpoint
        if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
            torch.save(nn.state_dict(), best_model_checkpoint)

        AdamW_scheduler.step(rmse)
        SGD_scheduler.step(rmse)

        tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
        tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
        tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

        for i, properties in tqdm.tqdm(
            enumerate(training),
            total=len(training),
            desc="epoch {}".format(AdamW_scheduler.last_epoch)
        ):
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            _, predicted_energies = model((species, coordinates))

            loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

            AdamW.zero_grad()
            SGD.zero_grad()
            loss.backward()
            AdamW.step()
            SGD.step()

            # write current batch loss to TensorBoard
            tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

        torch.save({
            'nn': nn.state_dict(),
            'AdamW': AdamW.state_dict(),
            'SGD': SGD.state_dict(),
            'AdamW_scheduler': AdamW_scheduler.state_dict(),
            'SGD_scheduler': SGD_scheduler.state_dict(),
        }, latest_checkpoint)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    training starting from epoch 1
    RMSE: 26.091515838518397 at epoch 1


.. _sphx_glr_download_examples_nnp_training.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: nnp_training.py <nnp_training.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: nnp_training.ipynb <nnp_training.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_