Train Your Own Neural Network Potential
This example shows how to use TorchANI to train a neural network potential with a setup identical to NeuroChem. We will use the same configuration as specified in inputtrain.ipt.
Note
TorchANI provides tools to run the NeuroChem training configuration file inputtrain.ipt. See: Train Neural Network Potential From NeuroChem Input File.
Warning
The training setup used in this file is configured to reproduce the original research in Less is more: Sampling chemical space with active learning as closely as possible. That research was done on a different platform called NeuroChem, which has many default options and technical details that differ from PyTorch. Some decisions made here (such as using NeuroChem’s initialization instead of PyTorch’s default initialization) were made not because they give better results, but solely to reproduce the original research. This file should not be interpreted as a recommendation on how readers should set up their models.
To begin with, let’s import the modules and set up the device we will use:
import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm
import pickle
# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Now let’s set up the constants and construct an AEV computer. These numbers can be found in rHCNO-5.2R_16-3.5A_a4-8.params.
The atomic self energies given in sae_linfit.dat are computed from the ANI-1x dataset. These constants can be calculated for any given dataset if None is passed as the argument when constructing the EnergyShifter class.
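Here is a concrete picture of what "calculating the self energies from a dataset" involves. The pure-Python sketch below fits molecular energies against per-element atom counts by linear least squares, with no intercept. We assume this is roughly what EnergyShifter(None) does when the dataset is processed. The function fit_self_energies and the toy energies are hypothetical. The sketch handles only two elements (H and C) so that the normal equations can be solved by hand.

```python
# Hypothetical sketch: fit per-element self energies by linear least squares,
# E_molecule ~ n_H * e_H + n_C * e_C (no intercept term).
def fit_self_energies(counts, energies):
    """counts: list of (n_H, n_C) tuples; energies: molecular energies.
    Solves the 2x2 normal equations (A^T A) x = A^T y with Cramer's rule."""
    a11 = sum(h * h for h, c in counts)
    a12 = sum(h * c for h, c in counts)
    a22 = sum(c * c for h, c in counts)
    b1 = sum(h * e for (h, c), e in zip(counts, energies))
    b2 = sum(c * e for (h, c), e in zip(counts, energies))
    det = a11 * a22 - a12 * a12
    e_h = (b1 * a22 - a12 * b2) / det
    e_c = (a11 * b2 - a12 * b1) / det
    return e_h, e_c


# Made-up energies (Hartree) constructed from e_H = -0.5 and e_C = -38.0
counts = [(2, 0), (4, 1), (6, 2)]   # H2, CH4, C2H6
energies = [-1.0, -40.0, -79.0]
e_h, e_c = fit_self_energies(counts, energies)
print(e_h, e_c)
```

With real, noisy data the fit is only approximate. On a very small dataset, the fitted values need not resemble physical atomic energies.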
Note
Besides defining these hyperparameters programmatically, you can use torchani.neurochem, which provides tools to read them from files.
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
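As a sanity check on these constants, you can work out the length of the AEV by hand. The sketch below assumes the usual ANI layout:

- one radial block of len(ShfR) terms per neighbor species;
- one angular block of len(ShfA) × len(ShfZ) terms per unordered pair of neighbor species.

The helper aev_length is hypothetical and not part of TorchANI. With 4 species, it should agree with aev_computer.aev_length, which appears as in_features=384 in the network printout further down. The sketch also checks that ShfR and ShfZ are evenly spaced grids.

```python
import math


def aev_length(num_species, num_shfr, num_shfa, num_shfz):
    """Hypothetical helper: AEV size under the standard ANI radial + angular layout."""
    radial = num_species * num_shfr                   # one radial block per neighbor species
    num_pairs = num_species * (num_species + 1) // 2  # unordered pairs, including same-species pairs
    angular = num_pairs * num_shfa * num_shfz         # one angular block per pair
    return radial, angular, radial + angular


# 4 species (H, C, N, O), 16 ShfR, 4 ShfA and 8 ShfZ values, as defined above
print(aev_length(4, 16, 4, 8))  # (64, 320, 384)

# ShfR: 16 radial shifts starting at 0.9 with a constant step of 0.26875
shfr = [0.9 + 0.26875 * k for k in range(16)]
# ShfZ: 8 angle shifts at odd multiples of pi/16
shfz = [(2 * k + 1) * math.pi / 16 for k in range(8)]
print(round(shfr[-1], 5), round(shfz[0], 8))  # 4.93125 0.19634954
```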
Now let’s set up the datasets. These paths assume that you run this script from the examples directory of TorchANI’s repository. If you downloaded this script on its own, set the paths of these files on your system manually before running it.
Also note that we need to subtract the self energies of all atoms from each molecule’s energy. This brings the energies into a reasonable range. The species_to_indices(species_order) step converts each molecule’s chemical symbols into integers according to their position in species_order: H corresponds to 0, C to 1, N to 2, and O to 3.
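You can illustrate the two preprocessing steps described above without TorchANI. The pure-Python sketch below does two things:

- it maps chemical symbols to indices by their position in species_order;
- it subtracts the summed self energies from a molecular energy.

The helper names and the self-energy values are made up for illustration. They are not the values or functions TorchANI uses.

```python
species_order = ['H', 'C', 'N', 'O']
index_of = {symbol: i for i, symbol in enumerate(species_order)}


def symbols_to_indices(symbols):
    """Hypothetical helper: map each symbol to its position in species_order."""
    return [index_of[s] for s in symbols]


def remove_self_energies(energy, symbols, self_energies):
    """Hypothetical helper: subtract the summed per-element self energies."""
    return energy - sum(self_energies[s] for s in symbols)


# Made-up self energies and molecular energy (Hartree), for illustration only
self_energies = {'H': -0.5, 'C': -38.0, 'N': -54.5, 'O': -75.0}
methane = ['C', 'H', 'H', 'H', 'H']
print(symbols_to_indices(methane))                         # [1, 0, 0, 0, 0]
print(remove_self_energies(-40.5, methane, self_energies))  # -0.5
```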
try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

pickled_dataset_path = 'dataset.pkl'

# We pickle the dataset after loading to ensure we use the same validation set
# each time we restart training, otherwise we risk mixing the validation and
# training sets on each restart.
if os.path.isfile(pickled_dataset_path):
    print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
    with open(pickled_dataset_path, 'rb') as f:
        dataset = pickle.load(f)
    training = dataset['training'].collate(batch_size).cache()
    validation = dataset['validation'].collate(batch_size).cache()
    energy_shifter.self_energies = dataset['self_energies'].to(device)
else:
    print(f'Processing dataset in {dspath}')
    training, validation = torchani.data.load(dspath)\
        .subtract_self_energies(energy_shifter, species_order)\
        .species_to_indices(species_order)\
        .shuffle()\
        .split(0.8, None)
    with open(pickled_dataset_path, 'wb') as f:
        pickle.dump({'training': training,
                     'validation': validation,
                     'self_energies': energy_shifter.self_energies.cpu()}, f)
    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)
Processing dataset in /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
3/1 [==========================================================================================] - 0.1s
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1
3/1 [==========================================================================================] - 0.1s
Self atomic energies: tensor([-16.140, 24.082, -8.092, -44.095], dtype=torch.float64)
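The code above pickles the dataset so that a restarted run sees exactly the same training/validation split. Below is a minimal sketch of the same idea that caches only the shuffled indices instead of the whole dataset. load_or_create_split is a hypothetical helper. The default 0.8 training fraction mirrors the one passed to split(0.8, None).

```python
import os
import pickle
import random
import tempfile


def load_or_create_split(cache_path, num_items, train_fraction=0.8, seed=None):
    """Hypothetical helper: return (train_indices, val_indices), creating and
    pickling a shuffled split on the first call and reloading it afterwards."""
    if os.path.isfile(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)   # shuffle only on the first run
    cut = int(train_fraction * num_items)
    split = (indices[:cut], indices[cut:])
    with open(cache_path, 'wb') as f:
        pickle.dump(split, f)
    return split


with tempfile.TemporaryDirectory() as tmp:
    cache = os.path.join(tmp, 'split.pkl')
    first = load_or_create_split(cache, 10)   # first run: shuffle and cache
    second = load_or_create_split(cache, 10)  # "restart": reuse the cached split
print(first == second)                        # True
```

Caching only indices keeps the cache file small, but the raw data must then be re-read and re-processed on every restart. The tutorial’s approach of pickling the processed datasets avoids that.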
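When iterating over the dataset, each batch is a dict that maps property names (such as 'species', 'coordinates', and 'energies') to tensors.
Now let’s define the atomic neural networks, one per element: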
aev_dim = aev_computer.aev_length
H_network = torch.nn.Sequential(
torch.nn.Linear(aev_dim, 160),
torch.nn.CELU(0.1),
torch.nn.Linear(160, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
C_network = torch.nn.Sequential(
torch.nn.Linear(aev_dim, 144),
torch.nn.CELU(0.1),
torch.nn.Linear(144, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
N_network = torch.nn.Sequential(
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
O_network = torch.nn.Sequential(
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
ANIModel(
(0): Sequential(
(0): Linear(in_features=384, out_features=160, bias=True)
(1): CELU(alpha=0.1)
(2): Linear(in_features=160, out_features=128, bias=True)
(3): CELU(alpha=0.1)
(4): Linear(in_features=128, out_features=96, bias=True)
(5): CELU(alpha=0.1)
(6): Linear(in_features=96, out_features=1, bias=True)
)
(1): Sequential(
(0): Linear(in_features=384, out_features=144, bias=True)
(1): CELU(alpha=0.1)
(2): Linear(in_features=144, out_features=112, bias=True)
(3): CELU(alpha=0.1)
(4): Linear(in_features=112, out_features=96, bias=True)
(5): CELU(alpha=0.1)
(6): Linear(in_features=96, out_features=1, bias=True)
)
(2): Sequential(
(0): Linear(in_features=384, out_features=128, bias=True)
(1): CELU(alpha=0.1)
(2): Linear(in_features=128, out_features=112, bias=True)
(3): CELU(alpha=0.1)
(4): Linear(in_features=112, out_features=96, bias=True)
(5): CELU(alpha=0.1)
(6): Linear(in_features=96, out_features=1, bias=True)
)
(3): Sequential(
(0): Linear(in_features=384, out_features=128, bias=True)
(1): CELU(alpha=0.1)
(2): Linear(in_features=128, out_features=112, bias=True)
(3): CELU(alpha=0.1)
(4): Linear(in_features=112, out_features=96, bias=True)
(5): CELU(alpha=0.1)
(6): Linear(in_features=96, out_features=1, bias=True)
)
)
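The printout shows one sub-network per element, in the order of species_order. Conceptually, ANIModel does three things:

- it sends each atom’s AEV to the network of its element, with network i handling species index i;
- it skips padding atoms, which are marked with species index -1;
- it sums the atomic outputs into one energy per molecule.

The pure-Python sketch below illustrates that bookkeeping with toy scalar features and made-up linear "networks". It is not TorchANI’s implementation.

```python
# Hypothetical sketch of combining per-element networks into molecular energies.
def molecular_energies(species_batch, features_batch, networks):
    energies = []
    for species, features in zip(species_batch, features_batch):
        total = 0.0
        for s, x in zip(species, features):
            if s < 0:                 # padding slot: contributes nothing
                continue
            total += networks[s](x)   # network chosen by species index
        energies.append(total)
    return energies


# Toy per-element "networks" for H, C, N, O (made-up linear functions)
networks = [lambda x: 1.0 * x, lambda x: 2.0 * x, lambda x: 3.0 * x, lambda x: 4.0 * x]
species_batch = [[1, 0, 0, 0, 0], [3, 0, 0, -1, -1]]   # CH4, and H2O padded to 5 slots
features_batch = [[1.0, 0.5, 0.5, 0.5, 0.5], [1.0, 0.5, 0.5, 9.9, 9.9]]
print(molecular_energies(species_batch, features_batch, networks))  # [4.0, 5.0]
```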
Initialize the weights and biases.
Note
PyTorch’s default initialization for the weights of linear layers is Kaiming uniform. See: TORCH.NN.MODULES.LINEAR. We initialize the weights similarly, but from a normal distribution. The biases are initialized to zero.
def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)
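Here is what the initialization in init_params means numerically. With the default mode='fan_in' and nonlinearity='leaky_relu', torch.nn.init.kaiming_normal_ draws from a normal distribution with standard deviation gain / sqrt(fan_in), where gain = sqrt(2 / (1 + a**2)). With a=1.0 the gain is 1, so the standard deviation is 1 / sqrt(fan_in).

For comparison, PyTorch’s default for nn.Linear (kaiming_uniform_ with a=sqrt(5)) samples uniformly from (-1/sqrt(fan_in), 1/sqrt(fan_in)). The standard deviation of that distribution is smaller by a factor of sqrt(3).

The sketch below only does this arithmetic. The helper names are hypothetical.

```python
import math


def kaiming_normal_std(fan_in, a=1.0):
    """Std used by kaiming_normal_(w, a=a) with mode='fan_in', nonlinearity='leaky_relu'."""
    gain = math.sqrt(2.0 / (1.0 + a ** 2))
    return gain / math.sqrt(fan_in)


def default_linear_std(fan_in):
    """Std of U(-1/sqrt(fan_in), 1/sqrt(fan_in)), PyTorch's default nn.Linear weight init."""
    bound = 1.0 / math.sqrt(fan_in)
    return bound / math.sqrt(3.0)   # the std of U(-b, b) is b / sqrt(3)


# fan_in of each layer in the H network printed above
for fan_in in (384, 160, 128, 96):
    print(fan_in, round(kaiming_normal_std(fan_in), 4), round(default_linear_std(fan_in), 4))
```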
Let’s now chain the AEV computer and the neural networks into a single pipeline:
model = torchani.nn.Sequential(aev_computer, nn).to(device)
Now let’s set up the optimizers. NeuroChem uses Adam with decoupled weight decay to update the weights and Stochastic Gradient Descent (SGD) to update the biases. Moreover, we need to specify different weight decay rates for different layers.
Note
The weight decay in inputtrain.ipt is named “l2”, but it is actually not L2 regularization. Confusing L2 regularization with weight decay is a common mistake in deep learning. See: Decoupled Weight Decay Regularization. Also note that in the training of ANI models, weight decay applies only to the weights, not to the biases.
AdamW = torch.optim.AdamW([
# H networks
{'params': [H_network[0].weight]},
{'params': [H_network[2].weight], 'weight_decay': 0.00001},
{'params': [H_network[4].weight], 'weight_decay': 0.000001},
{'params': [H_network[6].weight]},
# C networks
{'params': [C_network[0].weight]},
{'params': [C_network[2].weight], 'weight_decay': 0.00001},
{'params': [C_network[4].weight], 'weight_decay': 0.000001},
{'params': [C_network[6].weight]},
# N networks
{'params': [N_network[0].weight]},
{'params': [N_network[2].weight], 'weight_decay': 0.00001},
{'params': [N_network[4].weight], 'weight_decay': 0.000001},
{'params': [N_network[6].weight]},
# O networks
{'params': [O_network[0].weight]},
{'params': [O_network[2].weight], 'weight_decay': 0.00001},
{'params': [O_network[4].weight], 'weight_decay': 0.000001},
{'params': [O_network[6].weight]},
])
SGD = torch.optim.SGD([
# H networks
{'params': [H_network[0].bias]},
{'params': [H_network[2].bias]},
{'params': [H_network[4].bias]},
{'params': [H_network[6].bias]},
# C networks
{'params': [C_network[0].bias]},
{'params': [C_network[2].bias]},
{'params': [C_network[4].bias]},
{'params': [C_network[6].bias]},
# N networks
{'params': [N_network[0].bias]},
{'params': [N_network[2].bias]},
{'params': [N_network[4].bias]},
{'params': [N_network[6].bias]},
# O networks
{'params': [O_network[0].bias]},
{'params': [O_network[2].bias]},
{'params': [O_network[4].bias]},
{'params': [O_network[6].bias]},
], lr=1e-3)
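Here is why decoupling matters with Adam. When the decay is folded into the gradient (L2 style), Adam’s per-parameter normalization largely cancels its magnitude. AdamW instead shrinks the weight by lr × weight_decay × w before the Adam update.

The scalar, pure-Python sketch below takes a single Adam step from a fresh state with a zero loss gradient. adam_step is a hypothetical helper, not PyTorch’s code.

Note also that the parameter groups above that do not set 'weight_decay' inherit the optimizer-wide default. For torch.optim.AdamW, that default is 0.01 unless it is passed explicitly.

```python
import math


def adam_step(p, grad, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, l2=0.0, decoupled_wd=0.0):
    """One Adam step on a scalar parameter from a fresh state (m = v = 0, t = 1).
    Use either l2 (decay added to the gradient) or decoupled_wd (AdamW style)."""
    g = grad + l2 * p                                    # L2 style: decay enters the gradient
    p = p * (1 - lr * decoupled_wd)                      # AdamW style: shrink the weight directly
    m_hat = ((1 - betas[0]) * g) / (1 - betas[0])        # bias-corrected first moment
    v_hat = ((1 - betas[1]) * g * g) / (1 - betas[1])    # bias-corrected second moment
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


# Weight at 1.0 with a zero loss gradient: only the decay moves the weight
for wd in (0.01, 0.1):
    shrink_l2 = 1.0 - adam_step(1.0, 0.0, l2=wd)
    shrink_adamw = 1.0 - adam_step(1.0, 0.0, decoupled_wd=wd)
    print(wd, shrink_l2, shrink_adamw)
```

The two styles behave differently:

- The L2-style shrinkage stays close to lr = 1e-3 for both decay values, because Adam normalizes it away.
- The AdamW-style shrinkage equals lr × weight_decay, so it scales with the decay rate.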
Set up learning rate schedulers that decay the learning rate when the validation RMSE stops improving:
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
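You can sketch the plateau rule these schedulers apply in a few lines. The class below is a hypothetical, simplified re-implementation of ReduceLROnPlateau in its default mode='min', with threshold=0 and no cooldown or min_lr. The rule works as follows:

- Any strict decrease of the monitored RMSE resets a counter.
- Once the counter exceeds patience, the learning rate is multiplied by factor and the counter restarts.

With patience=100, this means 101 consecutive epochs without improvement before each decay.

```python
class PlateauHalver:
    """Hypothetical, simplified version of the rule used by
    ReduceLROnPlateau(mode='min', factor=0.5, threshold=0, cooldown=0)."""

    def __init__(self, lr, factor=0.5, patience=100):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best:          # threshold=0: any strict improvement counts
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor      # decay, then start counting again
            self.num_bad_epochs = 0
        return self.lr


# Small patience so the effect is visible after a few steps
sched = PlateauHalver(lr=1.0, patience=2)
for rmse in [5.0, 4.0, 4.0, 4.0, 4.0]:
    lr = sched.step(rmse)
print(lr)  # 0.5
```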
Train the model by minimizing the MSE loss until the validation RMSE stops improving for a certain number of epochs. Then decay the learning rate and repeat the same process, and stop once the learning rate falls below a threshold.
We first read the checkpoint file, if one exists, to restart training. We use latest.pt to store the current training state.
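Here is a back-of-the-envelope check of how long this procedure runs. AdamW is created above without an explicit lr, so it starts from PyTorch’s default of 1e-3, which is also the lr given to SGD. Each plateau halves the learning rate. Training stops once it falls below early_stopping_learning_rate = 1.0E-5, which is set in the training loop. The loop below counts the halvings. The variable names are illustrative.

```python
initial_lr = 1e-3    # torch.optim.AdamW default lr (and the lr passed to SGD)
factor = 0.5         # ReduceLROnPlateau(factor=0.5)
patience = 100       # ReduceLROnPlateau(patience=100)
stop_below = 1.0e-5  # early_stopping_learning_rate

lr, num_decays = initial_lr, 0
while lr >= stop_below:   # training continues while lr is not below the threshold
    lr *= factor
    num_decays += 1

# Each decay needs more than `patience` consecutive non-improving epochs,
# so this is only a lower bound on the number of epochs before stopping.
min_epochs = num_decays * (patience + 1)
print(num_decays, lr, min_epochs)
```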
latest_checkpoint = 'latest.pt'
Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
During training, we evaluate the model on the validation set. If the validation error is better than the best so far, we save the new best model to a checkpoint.
def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    model.train(True)
    return hartree2kcalmol(math.sqrt(total_mse / count))
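Note that validate() sums the squared errors over all validation molecules and divides by the total count only at the end. This weights batches of different sizes correctly. The Hartree-to-kcal/mol conversion can be applied after the square root because it is a constant scale factor.

The pure-Python sketch below reproduces that arithmetic. The conversion factor of about 627.5095 kcal/mol per Hartree is written out by hand, as an assumption standing in for torchani.units.hartree2kcalmol. The energies are made up.

```python
import math

HARTREE_TO_KCALMOL = 627.5095  # approximate; stands in for torchani.units.hartree2kcalmol


def pooled_rmse_kcalmol(batches):
    """batches: iterable of (predicted, true) energy lists in Hartree."""
    total_se, count = 0.0, 0
    for predicted, true in batches:
        total_se += sum((p - t) ** 2 for p, t in zip(predicted, true))  # like reduction='sum'
        count += len(predicted)
    return HARTREE_TO_KCALMOL * math.sqrt(total_se / count)


def mean_of_batch_rmse_kcalmol(batches):
    """For contrast: averaging per-batch RMSEs gives every batch equal weight."""
    rmses = [math.sqrt(sum((p - t) ** 2 for p, t in zip(pr, tr)) / len(pr))
             for pr, tr in batches]
    return HARTREE_TO_KCALMOL * sum(rmses) / len(rmses)


# Made-up errors: +0.001 and -0.001 Hartree in batch 1, +0.002 Hartree in batch 2
batches = [([0.001, -0.001], [0.0, 0.0]), ([0.002], [0.0])]
print(pooled_rmse_kcalmol(batches), mean_of_batch_rmse_kcalmol(batches))
```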
We will also use TensorBoard to visualize our training process:
tensorboard = torch.utils.tensorboard.SummaryWriter()
Finally, we come to the training loop.
In this tutorial, we set the maximum number of epochs to a very small value only to make this demo terminate quickly. For serious training, it should be set to a much larger value.
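In the training loop, each molecule’s squared energy error is divided by the square root of its atom count before averaging. Atoms are counted as species entries >= 0, so padding (-1) is ignored. The pure-Python sketch below mirrors the loop’s expression (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean() on made-up numbers. weighted_mse_loss is a hypothetical helper.

```python
import math


def weighted_mse_loss(predicted, true, species):
    """Hypothetical helper: mean over molecules of error**2 / sqrt(num_atoms),
    where padding atoms (species index -1) are not counted."""
    terms = []
    for p, t, s in zip(predicted, true, species):
        num_atoms = sum(1 for x in s if x >= 0)   # count real atoms only
        terms.append((p - t) ** 2 / math.sqrt(num_atoms))
    return sum(terms) / len(terms)


species = [[1, 0, 0, 0, 0], [3, 0, 0, -1, -1]]   # CH4 (5 atoms), padded H2O (3 atoms)
predicted = [0.02, 0.01]                         # made-up predictions (Hartree)
true = [0.0, 0.0]
loss = weighted_mse_loss(predicted, true, species)
print(loss)
```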
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
training starting from epoch 1
RMSE: 26.091515838518397 at epoch 1
epoch 1: 100%|██████████| 4/4 [00:00<00:00, 6.42it/s]
RMSE: 261.64168730312787 at epoch 2
epoch 2: 100%|██████████| 4/4 [00:00<00:00, 5.97it/s]
RMSE: 95.37320111166214 at epoch 3
epoch 3: 100%|██████████| 4/4 [00:00<00:00, 6.66it/s]
RMSE: 69.39060823018943 at epoch 4
epoch 4: 100%|██████████| 4/4 [00:00<00:00, 6.34it/s]
RMSE: 90.6584685355766 at epoch 5
epoch 5: 100%|██████████| 4/4 [00:00<00:00, 6.25it/s]
RMSE: 25.321216619653427 at epoch 6
epoch 6: 100%|██████████| 4/4 [00:00<00:00, 6.46it/s]
RMSE: 55.29533544083192 at epoch 7
epoch 7: 100%|██████████| 4/4 [00:00<00:00, 6.60it/s]
RMSE: 13.915169713989139 at epoch 8
epoch 8: 100%|██████████| 4/4 [00:00<00:00, 6.62it/s]
RMSE: 36.929820413521185 at epoch 9
epoch 9: 100%|██████████| 4/4 [00:00<00:00, 6.66it/s]
Total running time of the script: (0 minutes 7.397 seconds)