.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/nnp_training_force.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_nnp_training_force.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_nnp_training_force.py:


.. _force-training-example:

Train Neural Network Potential To Both Energies and Forces
============================================================

We have seen how to train a neural network potential by manually writing the
training loop in :ref:`training-example`. This tutorial shows how to modify
that script so that the model is trained to reproduce forces as well as
energies.

.. GENERATED FROM PYTHON SOURCE LINES 14-17

Most parts of the script are the same as in :ref:`training-example`, so the
comments for those parts are omitted here. Please refer to
:ref:`training-example` for more information.

.. GENERATED FROM PYTHON SOURCE LINES 17-62

.. code-block:: default


    import torch
    import torchani
    import os
    import math
    import torch.utils.tensorboard
    import tqdm

    # helper function to convert energy unit from Hartree to kcal/mol
    from torchani.units import hartree2kcalmol

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=device)
    ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
    Zeta = torch.tensor([3.2000000e+01], device=device)
    ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
    EtaA = torch.tensor([8.0000000e+00], device=device)
    ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
    species_order = ['H', 'C', 'N', 'O']
    num_species = len(species_order)
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
    energy_shifter = torchani.utils.EnergyShifter(None)

    try:
        path = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        path = os.getcwd()
    dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
    batch_size = 2560

    training, validation = torchani.data.load(
        dspath,
        additional_properties=('forces',)
    ).subtract_self_energies(energy_shifter, species_order).species_to_indices(species_order).shuffle().split(0.8, None)

    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()
    print('Self atomic energies: ', energy_shifter.self_energies)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    => loading /home/runner/work/torchani/torchani/examples/../dataset/ani-1x/sample.h5, total molecules: 6
    6/6 [==============================] - 0.1s
    => loading /home/runner/work/torchani/torchani/examples/../dataset/ani-1x/sample.h5, total molecules: 6
    6/6 [==============================] - 0.1s
    Self atomic energies:  tensor([-19.354, -19.354, -54.712, -75.163], dtype=torch.float64)
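
Since ``forces`` were requested as an additional property, every batch
produced by this pipeline carries a force tensor alongside species,
coordinates and energies. The short sketch below is not part of the original
script; it simply peeks at the first cached training batch (the name
``first_batch`` is introduced here only for illustration) to confirm the
shapes that the training loop later relies on.

.. code-block:: default


    # A minimal sketch, assuming the dataset pipeline above has already run:
    # inspect one cached training batch to confirm that forces were loaded.
    # Padded atoms are marked with species index -1.
    first_batch = next(iter(training))
    print('species:     ', first_batch['species'].shape)      # (molecules, atoms)
    print('coordinates: ', first_batch['coordinates'].shape)  # (molecules, atoms, 3)
    print('forces:      ', first_batch['forces'].shape)       # (molecules, atoms, 3)
    print('energies:    ', first_batch['energies'].shape)     # (molecules,)

If no ``forces`` key shows up here, double-check the
``additional_properties=('forces',)`` argument passed to
``torchani.data.load`` above.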

.. GENERATED FROM PYTHON SOURCE LINES 63-64

The code that defines the networks and optimizers is mostly the same.

.. GENERATED FROM PYTHON SOURCE LINES 64-109

.. code-block:: default


    aev_dim = aev_computer.aev_length

    H_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 160),
        torch.nn.CELU(0.1),
        torch.nn.Linear(160, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    C_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 144),
        torch.nn.CELU(0.1),
        torch.nn.Linear(144, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    N_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    O_network = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
    print(nn)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ANIModel(
      (0): Sequential(
        (0): Linear(in_features=384, out_features=160, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=160, out_features=128, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=128, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=384, out_features=144, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=144, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 110-120

Initialize the weights and biases.

.. note::
    PyTorch's default initialization for the weights and biases in linear
    layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
    Here we initialize the weights similarly, but from the normal
    distribution, and set the biases to zero.

    .. _TORCH.NN.MODULES.LINEAR: https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear

.. GENERATED FROM PYTHON SOURCE LINES 120-130

.. code-block:: default


    def init_params(m):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight, a=1.0)
            torch.nn.init.zeros_(m.bias)


    nn.apply(init_params)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ANIModel(
      (0): Sequential(
        (0): Linear(in_features=384, out_features=160, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=160, out_features=128, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=128, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=384, out_features=144, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=144, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=128, out_features=112, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=112, out_features=96, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=96, out_features=1, bias=True)
      )
    )
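
As a quick sanity check, one can confirm that ``init_params`` reached every
linear layer. The loop below is a sketch that is not part of the original
script; it only uses the ``nn`` model defined above and verifies that all
biases are now exactly zero, as described in the note.

.. code-block:: default


    # A minimal sketch: after nn.apply(init_params), every Linear layer
    # should have a zero bias vector.
    for name, module in nn.named_modules():
        if isinstance(module, torch.nn.Linear):
            assert bool((module.bias == 0).all()), f'bias of layer {name} is not zero'
    print('all linear biases are initialized to zero')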

.. GENERATED FROM PYTHON SOURCE LINES 131-132

Let's now create a pipeline of AEV Computer --> Neural Networks.

.. GENERATED FROM PYTHON SOURCE LINES 132-134

.. code-block:: default


    model = torchani.nn.Sequential(aev_computer, nn).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 135-137

Here we will use Adam with weight decay for the weights and Stochastic
Gradient Descent for the biases.

.. GENERATED FROM PYTHON SOURCE LINES 137-187

.. code-block:: default


    AdamW = torch.optim.AdamW([
        # H networks
        {'params': [H_network[0].weight]},
        {'params': [H_network[2].weight], 'weight_decay': 0.00001},
        {'params': [H_network[4].weight], 'weight_decay': 0.000001},
        {'params': [H_network[6].weight]},
        # C networks
        {'params': [C_network[0].weight]},
        {'params': [C_network[2].weight], 'weight_decay': 0.00001},
        {'params': [C_network[4].weight], 'weight_decay': 0.000001},
        {'params': [C_network[6].weight]},
        # N networks
        {'params': [N_network[0].weight]},
        {'params': [N_network[2].weight], 'weight_decay': 0.00001},
        {'params': [N_network[4].weight], 'weight_decay': 0.000001},
        {'params': [N_network[6].weight]},
        # O networks
        {'params': [O_network[0].weight]},
        {'params': [O_network[2].weight], 'weight_decay': 0.00001},
        {'params': [O_network[4].weight], 'weight_decay': 0.000001},
        {'params': [O_network[6].weight]},
    ])

    SGD = torch.optim.SGD([
        # H networks
        {'params': [H_network[0].bias]},
        {'params': [H_network[2].bias]},
        {'params': [H_network[4].bias]},
        {'params': [H_network[6].bias]},
        # C networks
        {'params': [C_network[0].bias]},
        {'params': [C_network[2].bias]},
        {'params': [C_network[4].bias]},
        {'params': [C_network[6].bias]},
        # N networks
        {'params': [N_network[0].bias]},
        {'params': [N_network[2].bias]},
        {'params': [N_network[4].bias]},
        {'params': [N_network[6].bias]},
        # O networks
        {'params': [O_network[0].bias]},
        {'params': [O_network[2].bias]},
        {'params': [O_network[4].bias]},
        {'params': [O_network[6].bias]},
    ], lr=1e-3)

    AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
    SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

.. GENERATED FROM PYTHON SOURCE LINES 188-189

This part of the code is also the same.

.. GENERATED FROM PYTHON SOURCE LINES 189-191

.. code-block:: default


    latest_checkpoint = 'force-training-latest.pt'

.. GENERATED FROM PYTHON SOURCE LINES 192-193

Resume training from previously saved checkpoints:

.. GENERATED FROM PYTHON SOURCE LINES 193-201

.. code-block:: default


    if os.path.isfile(latest_checkpoint):
        checkpoint = torch.load(latest_checkpoint)
        nn.load_state_dict(checkpoint['nn'])
        AdamW.load_state_dict(checkpoint['AdamW'])
        SGD.load_state_dict(checkpoint['SGD'])
        AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
        SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

.. GENERATED FROM PYTHON SOURCE LINES 202-204

During training, we need to validate on the validation set, and if the
validation error is better than the best so far, save the new best model to
a checkpoint.

.. GENERATED FROM PYTHON SOURCE LINES 204-224

.. code-block:: default


    def validate():
        # run validation
        mse_sum = torch.nn.MSELoss(reduction='sum')
        total_mse = 0.0
        count = 0
        model.train(False)
        with torch.no_grad():
            for properties in validation:
                species = properties['species'].to(device)
                coordinates = properties['coordinates'].to(device).float()
                true_energies = properties['energies'].to(device).float()
                _, predicted_energies = model((species, coordinates))
                total_mse += mse_sum(predicted_energies, true_energies).item()
                count += predicted_energies.shape[0]
        model.train(True)
        return hartree2kcalmol(math.sqrt(total_mse / count))
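
``validate`` above only measures the energy RMSE. Since this tutorial also
fits forces, it can be useful to track a force RMSE on the validation set as
well. The function below is a sketch, not part of the original script: the
name ``validate_forces`` is introduced here for illustration, it mirrors
``validate`` but differentiates the predicted energies with respect to the
coordinates, so it cannot run under ``torch.no_grad()``, and it counts only
real (non-padded) atoms.

.. code-block:: default


    # A minimal sketch of a force RMSE over the validation set,
    # reported in kcal/mol/Angstrom.
    def validate_forces():
        total_sq_error = 0.0
        count = 0
        model.train(False)
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float().requires_grad_(True)
            true_forces = properties['forces'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            # gradients w.r.t. coordinates are needed, so no torch.no_grad() here
            forces = -torch.autograd.grad(predicted_energies.sum(), coordinates)[0]
            total_sq_error += (forces - true_forces).pow(2).sum().item()
            count += 3 * int((species >= 0).sum())  # 3 force components per real atom
        model.train(True)
        return hartree2kcalmol(math.sqrt(total_sq_error / count))

Calling ``validate_forces()`` next to ``validate()`` once per epoch adds one
extra pass over the validation set, which is cheap for this small sample
dataset.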

.. GENERATED FROM PYTHON SOURCE LINES 225-226

We will also use TensorBoard to visualize our training process.

.. GENERATED FROM PYTHON SOURCE LINES 226-228

.. code-block:: default


    tensorboard = torch.utils.tensorboard.SummaryWriter()

.. GENERATED FROM PYTHON SOURCE LINES 229-230

In the training loop, we need to compute the forces and add a loss term for
the forces.

.. GENERATED FROM PYTHON SOURCE LINES 230-302

.. code-block:: default


    mse = torch.nn.MSELoss(reduction='none')

    print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
    # We only train 3 epochs here so that the documentation builds quickly.
    # Real training should run for many more epochs.
    max_epochs = 3
    early_stopping_learning_rate = 1.0E-5
    force_coefficient = 0.1  # controls the importance of energy loss vs force loss
    best_model_checkpoint = 'force-training-best.pt'

    for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
        rmse = validate()
        print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

        learning_rate = AdamW.param_groups[0]['lr']

        if learning_rate < early_stopping_learning_rate:
            break

        # checkpoint
        if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
            torch.save(nn.state_dict(), best_model_checkpoint)

        AdamW_scheduler.step(rmse)
        SGD_scheduler.step(rmse)

        tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
        tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
        tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

        # Each batch is a dict of properties; we take species, coordinates,
        # energies and forces from it.
        for i, properties in tqdm.tqdm(
            enumerate(training),
            total=len(training),
            desc="epoch {}".format(AdamW_scheduler.last_epoch)
        ):
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float().requires_grad_(True)
            true_energies = properties['energies'].to(device).float()
            true_forces = properties['forces'].to(device).float()
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            _, predicted_energies = model((species, coordinates))

            # We can use torch.autograd.grad to compute force. Remember to
            # create graph so that the loss of the force can contribute to
            # the gradient of parameters, and also to retain graph so that
            # we can backward through it a second time when computing gradient
            # w.r.t. parameters.
            forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]

            # Now the total loss has two parts, energy loss and force loss
            energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
            force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) / num_atoms).mean()
            loss = energy_loss + force_coefficient * force_loss

            AdamW.zero_grad()
            SGD.zero_grad()
            loss.backward()
            AdamW.step()
            SGD.step()

            # write current batch loss to TensorBoard
            tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    training starting from epoch 1
    RMSE: 86.10589445003251 at epoch 1
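
After training, the same autograd trick used inside the loop gives forces at
inference time. The snippet below is a sketch rather than part of the
original script: the tetrahedral methane-like geometry is made up for
illustration, species are given as indices into ``species_order`` (H=0, C=1,
N=2, O=3) because the dataset was converted with ``species_to_indices``, and
the predicted energy is on the shifted scale (self atomic energies
subtracted) that the model was trained on.

.. code-block:: default


    # A minimal sketch: predict the energy and forces of a single, made-up
    # methane-like geometry with the freshly trained model.
    molecule_species = torch.tensor([[1, 0, 0, 0, 0]], device=device)  # C H H H H
    molecule_coordinates = torch.tensor([[[ 0.00,  0.00,  0.00],
                                          [ 0.63,  0.63,  0.63],
                                          [-0.63, -0.63,  0.63],
                                          [-0.63,  0.63, -0.63],
                                          [ 0.63, -0.63, -0.63]]],
                                        device=device, requires_grad=True)
    _, molecule_energy = model((molecule_species, molecule_coordinates))
    molecule_forces = -torch.autograd.grad(molecule_energy.sum(), molecule_coordinates)[0]
    print('Energy (Hartree, self energies subtracted):', molecule_energy.item())
    print('Forces (Hartree/Angstrom):')
    print(molecule_forces.squeeze(0))

Loading the weights saved in ``best_model_checkpoint`` into ``nn`` beforehand
would give predictions from the best validated model rather than from the
last epoch.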

.. _sphx_glr_download_examples_nnp_training_force.py:

.. only:: html

  .. container:: sphx-glr-download sphx-glr-download-jupyter

    :download:`Download Jupyter notebook: nnp_training_force.ipynb <nnp_training_force.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_