Train Your Own Neural Network Potential

This example shows how to use TorchANI to train a neural network potential with a setup identical to NeuroChem's. We will use the same configuration as specified in inputtrain.ipt.

Note

TorchANI also provides tools to run the NeuroChem training configuration file inputtrain.ipt directly. See: Train Neural Network Potential From NeuroChem Input File.

Warning

The training setup used in this file is configured to reproduce, as closely as possible, the original research in Less is more: Sampling chemical space with active learning. That research was done on a different platform called NeuroChem, which has many default options and technical details that differ from PyTorch. Some decisions made here (such as using NeuroChem’s initialization instead of PyTorch’s default initialization) were not made because they give better results, but solely to reproduce the original research. This file should not be interpreted as a suggestion to readers on how they should set up their models.

To begin with, let’s import the modules and set up the device we will use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm
import pickle

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Now let’s set up the constants and construct an AEV computer. These numbers can be found in rHCNO-5.2R_16-3.5A_a4-8.params. The atomic self energies given in sae_linfit.dat were computed from the ANI-1x dataset; for any other dataset, they can be computed from the data by passing None to the EnergyShifter constructor, which is what we do here.

Note

Besides defining these hyperparameters programmatically, torchani.neurochem provides tools to read them from a file.

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
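The shift values above follow a regular pattern, which is a useful sanity check when editing them: ShfR and ShfA are evenly spaced from 0.9 up to, but not including, the radial and angular cutoffs, and ShfZ places 8 angle centres at odd multiples of π/16. The plain-Python sketch below regenerates them and computes the resulting AEV length, which matches the in_features=384 of the networks printed further down. The helper shifts_up_to is hypothetical, and the length formula assumes TorchANI's layout of one radial block per neighbour species and one angular block per unordered pair of neighbour species.

```python
import math

# Hyperparameters copied from the tutorial (distances in Angstrom).
Rcr, Rca = 5.2, 3.5
num_species = 4  # H, C, N, O


def shifts_up_to(start, cutoff, n):
    """Hypothetical helper: n evenly spaced shifts from `start`, one step short of `cutoff`."""
    step = (cutoff - start) / n
    return [start + k * step for k in range(n)]


ShfR = shifts_up_to(0.9, Rcr, 16)                      # radial shifts
ShfA = shifts_up_to(0.9, Rca, 4)                       # radial shifts of the angular terms
ShfZ = [math.pi / 16 * (2 * k + 1) for k in range(8)]  # angle centres
EtaR, EtaA, Zeta = [16.0], [8.0], [32.0]

# Assumed layout: one radial block per neighbour species ...
radial_length = num_species * len(EtaR) * len(ShfR)
# ... and one angular block per unordered pair of neighbour species.
num_pairs = num_species * (num_species + 1) // 2
angular_length = num_pairs * len(EtaA) * len(Zeta) * len(ShfA) * len(ShfZ)
aev_length = radial_length + angular_length

print([round(s, 5) for s in ShfR])
print(radial_length, angular_length, aev_length)  # 64 320 384
```

If you change the number of shifts or species, the first Linear layer of every atomic network must change its in_features accordingly.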

Now let’s set up the datasets. These paths assume that the user runs this script from the examples directory of TorchANI’s repository. If you download this script, you need to set the paths of these files on your system manually before it can run successfully.

Also note that we need to subtract the self energies of all atoms from the energy of each molecule. This brings the energies into a reasonable range. The species_order argument defines how species, given as a list of chemical symbols, are converted to a tensor of indices: for the supported elements, H corresponds to 0, C to 1, N to 2 and O to 3.

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

pickled_dataset_path = 'dataset.pkl'

# We pickle the dataset after loading to ensure we use the same validation set
# each time we restart training, otherwise we risk mixing the validation and
# training sets on each restart.
if os.path.isfile(pickled_dataset_path):
    print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
    with open(pickled_dataset_path, 'rb') as f:
        dataset = pickle.load(f)
    training = dataset['training'].collate(batch_size).cache()
    validation = dataset['validation'].collate(batch_size).cache()
    energy_shifter.self_energies = dataset['self_energies'].to(device)
else:
    print(f'Processing dataset in {dspath}')
    training, validation = torchani.data.load(dspath)\
                                        .subtract_self_energies(energy_shifter, species_order)\
                                        .species_to_indices(species_order)\
                                        .shuffle()\
                                        .split(0.8, None)
    with open(pickled_dataset_path, 'wb') as f:
        pickle.dump({'training': training,
                     'validation': validation,
                     'self_energies': energy_shifter.self_energies.cpu()}, f)
    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)
Processing dataset in /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
Self atomic energies:  tensor([-16.140,  24.082,  -8.092, -44.095], dtype=torch.float64)
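Passing None to EnergyShifter means the self energies are fitted from the data. The idea can be sketched as a linear least-squares problem: each molecule's energy is approximated by a sum of one self energy per atom, and the fitted self energies are then subtracted. The plain-Python sketch below uses a hypothetical toy dataset whose energies are exactly additive in made-up self energies, so the fit recovers them and the shifted energies are zero; TorchANI's actual implementation works on tensors and may differ in details (for example, an optional intercept). It also shows the species_order mapping from symbols to indices.

```python
# Hypothetical toy data: energies built to be exactly additive in made-up
# per-element self energies (Hartree), so the fit should recover them.
species_order = ['H', 'C', 'N', 'O']
index = {symbol: i for i, symbol in enumerate(species_order)}  # H->0, C->1, N->2, O->3
true_self_energies = [-0.5, -38.0, -54.5, -75.0]

molecules = [['H', 'H'], ['C', 'H', 'H', 'H', 'H'], ['N', 'H', 'H', 'H'],
             ['O', 'H', 'H'], ['C', 'O', 'O']]
species = [[index[s] for s in mol] for mol in molecules]


def atom_counts(indices, n=len(species_order)):
    counts = [0.0] * n
    for i in indices:
        counts[i] += 1.0
    return counts


A = [atom_counts(idx) for idx in species]  # design matrix: atom counts per molecule
energies = [sum(c * e for c, e in zip(row, true_self_energies)) for row in A]


def solve(matrix, rhs):
    """Solve a small dense linear system by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    m = [row[:] + [rhs[i]] for i, row in enumerate(matrix)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


# Least squares via the normal equations (A^T A) e = A^T E.
n = len(species_order)
ata = [[sum(row[i] * row[j] for row in A) for j in range(n)] for i in range(n)]
atb = [sum(row[i] * E for row, E in zip(A, energies)) for i in range(n)]
fitted = solve(ata, atb)

# Subtract the fitted self energies from every molecule's energy.
shifted = [E - sum(c * e for c, e in zip(row, fitted)) for row, E in zip(A, energies)]
print(species[1])  # [1, 0, 0, 0, 0]
print([round(e, 6) for e in fitted])
```

With real data the energies are not exactly additive, so the shifted energies are small but nonzero; they are what the networks are trained to predict.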

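When iterating over the dataset, each batch is a dict that maps property names (such as 'species', 'coordinates' and 'energies') to tensors.

Next, we define one atomic neural network per element, H_network, C_network, N_network and O_network, and combine them into a torchani.ANIModel named nn. Printing nn shows the architecture: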

ANIModel(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=160, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=160, out_features=128, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=128, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=384, out_features=144, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=144, out_features=112, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=112, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
  (2): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=128, out_features=112, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=112, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
  (3): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=128, out_features=112, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=112, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
)
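Conceptually, the ANIModel printed above applies network i to every atom whose species index is i and sums the atomic outputs into one energy per molecule; padding atoms (species index -1) contribute nothing. The plain-Python sketch below mimics that bookkeeping with hypothetical stand-in "networks" that return a constant per element; it illustrates the idea and is not TorchANI's implementation.

```python
# Hypothetical stand-ins for H_network, C_network, N_network and O_network:
# each maps an atom's AEV to an atomic energy (here simply a constant).
networks = [lambda aev: -0.6, lambda aev: -38.1, lambda aev: -54.6, lambda aev: -75.1]


def molecular_energies(species_batch, aev_batch):
    """Sum per-atom network outputs; species index -1 marks padding atoms."""
    energies = []
    for species, aevs in zip(species_batch, aev_batch):
        total = 0.0
        for s, aev in zip(species, aevs):
            if s < 0:  # padding atom: skipped
                continue
            total += networks[s](aev)
        energies.append(total)
    return energies


# Two molecules padded to 5 atoms: CH4 and H2O (AEVs omitted as None).
species_batch = [[1, 0, 0, 0, 0], [3, 0, 0, -1, -1]]
aev_batch = [[None] * 5, [None] * 5]
print(molecular_energies(species_batch, aev_batch))
```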

Initialize the weights and biases.

Note

PyTorch’s default initialization for the weights of linear layers is Kaiming uniform (the biases are drawn from a uniform distribution as well). See: TORCH.NN.MODULES.LINEAR. We initialize the weights similarly, but from a normal distribution, and initialize the biases to zero.

def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)
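With a=1.0, kaiming_normal_ uses the leaky-ReLU gain sqrt(2 / (1 + a²)) = 1 and, in its default fan_in mode, draws each weight from a normal distribution with standard deviation gain / sqrt(fan_in), where fan_in is the layer's in_features. The plain-Python sketch below computes this standard deviation for the H network's layers and checks it against samples from Python's random module; the helper names are hypothetical.

```python
import math
import random


def leaky_relu_gain(a):
    # Gain used by Kaiming initialization for leaky ReLU with negative slope a.
    return math.sqrt(2.0 / (1.0 + a * a))


def kaiming_normal_std(fan_in, a=1.0):
    # Standard deviation of kaiming_normal_ in fan_in mode (hypothetical helper).
    return leaky_relu_gain(a) / math.sqrt(fan_in)


# fan_in values (in_features) of the H network's four Linear layers.
for fan_in in (384, 160, 128, 96):
    print(fan_in, round(kaiming_normal_std(fan_in), 4))

# Empirical check: draw many weights for a layer with fan_in = 384.
rng = random.Random(0)
std = kaiming_normal_std(384)
samples = [rng.gauss(0.0, std) for _ in range(100_000)]
mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((x - mean) ** 2 for x in samples) / len(samples))
print(sample_std / std)  # ratio should be close to 1
```

So with a=1.0 the weights have standard deviation 1/sqrt(in_features), which shrinks as layers get wider.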

Let’s now create a pipeline that feeds the output of the AEV computer into the neural networks:

model = torchani.nn.Sequential(aev_computer, nn).to(device)

Now let’s set up the optimizers. NeuroChem uses Adam with decoupled weight decay to update the weights and stochastic gradient descent (SGD) to update the biases. Moreover, we need to specify different weight decay rates for different layers.

Note

The weight decay in inputtrain.ipt is named “l2”, but it is actually not L2 regularization. Confusing L2 regularization with weight decay is a common mistake in deep learning. See: Decoupled Weight Decay Regularization. Also note that in the training of ANI models, weight decay applies only to the weights, not to the biases.

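# Parameter groups without an explicit 'weight_decay' use AdamW's default
# (weight_decay=0.01); all groups use AdamW's default learning rate of 1e-3.
AdamW = torch.optim.AdamW([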
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

# Biases are updated with plain SGD (no momentum and no weight decay by default).
SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)
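To see why the note above distinguishes weight decay from L2 regularization, consider a single weight w = 1 with zero loss gradient, so that only the regularizer acts. With Adam plus an L2 penalty, the penalty gradient wd·w is divided by Adam's running estimate of its own magnitude, so the first step is about lr no matter what wd is; with decoupled weight decay (PyTorch's AdamW multiplies the weight by 1 − lr·wd before the Adam update), the shrinkage scales with wd. The plain-Python sketch below implements one step of each rule with PyTorch's default Adam betas and eps; it is a simplified scalar illustration, not the torch.optim code, and the function names are hypothetical.

```python
import math

BETA1, BETA2, EPS = 0.9, 0.999, 1e-8  # PyTorch's default Adam/AdamW betas and eps


def first_adam_step(grad, lr):
    """Parameter change of Adam's first step, starting from zero moment estimates."""
    m = (1 - BETA1) * grad
    v = (1 - BETA2) * grad * grad
    m_hat = m / (1 - BETA1)  # bias correction for step 1
    v_hat = v / (1 - BETA2)
    return lr * m_hat / (math.sqrt(v_hat) + EPS)


def adam_with_l2(w, loss_grad, lr, wd):
    # L2 regularization: the penalty's gradient is added to the loss gradient
    # and then rescaled by Adam like any other gradient.
    return w - first_adam_step(loss_grad + wd * w, lr)


def adamw(w, loss_grad, lr, wd):
    # Decoupled weight decay: shrink the weight directly, then apply Adam
    # to the loss gradient alone.
    w = w * (1 - lr * wd)
    return w - first_adam_step(loss_grad, lr)


lr = 1e-3
for wd in (0.01, 0.1):
    print(wd, adam_with_l2(1.0, 0.0, lr, wd), adamw(1.0, 0.0, lr, wd))
```

Under Adam plus L2, increasing wd tenfold barely changes the step, while under AdamW it changes the shrinkage tenfold; this is why the per-layer values above are meaningful only as decoupled weight decay.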

Set up learning rate schedulers that decay the learning rate when the validation RMSE stops improving:

AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
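ReduceLROnPlateau (in its default 'min' mode) multiplies the learning rate by factor once the monitored metric has failed to improve for more than patience consecutive steps; with threshold=0, any strict decrease counts as an improvement. The class below is a simplified, hypothetical plain-Python re-implementation for illustration only (it omits cooldown, min_lr and the other options of torch.optim.lr_scheduler.ReduceLROnPlateau), run with a small patience so the effect is visible. It also counts how many halvings take AdamW's default initial learning rate of 1e-3 below the early-stopping threshold of 1e-5 used in the training loop.

```python
class PlateauHalver:
    """Hypothetical, simplified sketch of ReduceLROnPlateau ('min' mode, threshold=0)."""

    def __init__(self, lr, factor=0.5, patience=100):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best:  # strict improvement
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauHalver(lr=1.0, patience=2)
history = [sched.step(rmse) for rmse in [10, 9, 9, 9, 9, 8, 8, 8, 8]]
print(history)  # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.25]

# Halvings needed before 1e-3 drops below the early-stopping threshold 1e-5.
lr, halvings = 1e-3, 0
while lr >= 1e-5:
    lr *= 0.5
    halvings += 1
print(halvings)  # 7
```

With the tutorial's patience=100, each halving requires at least 101 consecutive epochs without improvement, so reaching the early-stopping threshold takes more than 700 epochs, far more than max_epochs = 10 in this demo.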

We train the model by minimizing the MSE loss. Whenever the validation RMSE has not improved for a certain number of epochs, the learning rate is decayed and the process repeats; training stops once the learning rate falls below a threshold.
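The loss minimized in the training loop is not a plain MSE: each molecule's squared error is divided by the square root of its number of atoms (padding atoms, with species index -1, are not counted) before averaging over the batch, which down-weights larger molecules relative to a plain MSE. A plain-Python sketch of that computation, using a hypothetical function name and made-up energies:

```python
import math


def scaled_mse_loss(predicted, true, species):
    """Mean over molecules of (pred - true)^2 / sqrt(number of real atoms)."""
    terms = []
    for p, t, s in zip(predicted, true, species):
        num_atoms = sum(1 for index in s if index >= 0)  # ignore padding (-1)
        terms.append((p - t) ** 2 / math.sqrt(num_atoms))
    return sum(terms) / len(terms)


# Hypothetical batch: CH4 (5 atoms) and H2O padded to 5 slots (3 real atoms).
species = [[1, 0, 0, 0, 0], [3, 0, 0, -1, -1]]
predicted = [-40.3, -76.1]
true = [-40.5, -76.4]
print(scaled_mse_loss(predicted, true, species))
```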

To be able to restart training, we store the current training state in the checkpoint file latest.pt:

latest_checkpoint = 'latest.pt'

Resume training from previously saved checkpoints:

if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
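The restart logic relies on the scheduler's last_epoch being part of the saved state: after reloading, the training loop resumes at last_epoch + 1 instead of starting over. The plain-Python sketch below mimics this with pickle and a temporary file; the state dict and the train_epochs helper are hypothetical stand-ins for the real model, optimizer and scheduler state dicts.

```python
import os
import pickle
import tempfile


def train_epochs(state, max_epochs, stop_after=None):
    """Hypothetical stand-in: run epochs last_epoch+1 .. max_epochs-1."""
    ran = []
    for epoch in range(state['last_epoch'] + 1, max_epochs):
        if stop_after is not None and len(ran) == stop_after:
            break  # simulate an interruption
        ran.append(epoch)
        state['last_epoch'] = epoch  # the scheduler step advances this
    return ran


path = os.path.join(tempfile.mkdtemp(), 'latest.pkl')
state = {'last_epoch': 0}
first_run = train_epochs(state, max_epochs=10, stop_after=4)
with open(path, 'wb') as f:  # save the checkpoint (the real loop does this every epoch)
    pickle.dump(state, f)

with open(path, 'rb') as f:  # restart: reload the state and continue
    resumed = pickle.load(f)
second_run = train_epochs(resumed, max_epochs=10)
print(first_run, second_run)  # [1, 2, 3, 4] [5, 6, 7, 8, 9]
```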

During training, we evaluate the model on the validation set; whenever the validation error is better than the best so far, the model is saved to a checkpoint as the new best model.

def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    model.train(True)
    return hartree2kcalmol(math.sqrt(total_mse / count))
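validate() accumulates the summed squared error and the molecule count over all batches and takes a single square root at the end, so the result is the RMSE over the whole validation set regardless of how it is batched; averaging per-batch RMSEs would generally give a different number when batches differ in size or error. The final value is converted from Hartree to kcal/mol. The plain-Python sketch below shows both, using hypothetical errors and assuming the approximate factor 1 Hartree ≈ 627.5095 kcal/mol (TorchANI's hartree2kcalmol uses its own constant).

```python
import math

HARTREE_TO_KCALMOL = 627.5095  # assumed approximate conversion factor

# Hypothetical per-molecule errors (Hartree) in two unevenly sized batches.
batches = [[0.001, -0.001, 0.001], [0.003]]

# Pooled RMSE, as in validate(): sum squared errors and counts, then one sqrt.
total_se = sum(e * e for batch in batches for e in batch)
count = sum(len(batch) for batch in batches)
pooled_rmse = math.sqrt(total_se / count)

# Naive alternative: average of per-batch RMSEs (weights every batch equally).
batch_rmses = [math.sqrt(sum(e * e for e in b) / len(b)) for b in batches]
mean_of_batch_rmse = sum(batch_rmses) / len(batch_rmses)

print(pooled_rmse * HARTREE_TO_KCALMOL)         # about 1.087 kcal/mol
print(mean_of_batch_rmse * HARTREE_TO_KCALMOL)  # about 1.255 kcal/mol
```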

We will also use TensorBoard to visualize the training process:

tensorboard = torch.utils.tensorboard.SummaryWriter()

Finally, we come to the training loop.

In this tutorial, we set the maximum number of epochs to a very small value only so that the demo finishes quickly. For serious training, it should be set to a much larger value.

mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        # squared error per molecule, divided by sqrt(number of atoms), averaged over the batch
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
training starting from epoch 1
RMSE: 131.69915654000263 at epoch 1
epoch 1: 100%|##########| 4/4 [00:00<00:00,  8.12it/s]
RMSE: 133.3053020199091 at epoch 2
epoch 2: 100%|##########| 4/4 [00:00<00:00,  8.33it/s]
RMSE: 65.24803736824568 at epoch 3
epoch 3: 100%|##########| 4/4 [00:00<00:00,  8.32it/s]
RMSE: 25.67967103952448 at epoch 4
epoch 4: 100%|##########| 4/4 [00:00<00:00,  8.20it/s]
RMSE: 59.19369350701647 at epoch 5
epoch 5: 100%|##########| 4/4 [00:00<00:00,  8.27it/s]
RMSE: 35.32657262349891 at epoch 6
epoch 6: 100%|##########| 4/4 [00:00<00:00,  8.24it/s]
RMSE: 18.11028976287304 at epoch 7
epoch 7: 100%|##########| 4/4 [00:00<00:00,  8.19it/s]
RMSE: 30.805631221574284 at epoch 8
epoch 8: 100%|##########| 4/4 [00:00<00:00,  8.23it/s]
RMSE: 10.526330425054946 at epoch 9
epoch 9: 100%|##########| 4/4 [00:00<00:00,  8.20it/s]

Total running time of the script: ( 0 minutes 5.761 seconds)
