Training an ANI network using a custom script
This example shows how to use TorchANI to train a neural network potential.
import math
from pathlib import Path
import torch
import torch.utils.tensorboard
from tqdm import tqdm
import torchani
from torchani.arch import ANI, simple_ani
from torchani.datasets import ANIDataset, ANIBatchedDataset, BatchedDataset
from torchani.units import hartree2kcalpermol
from torchani.grad import forces_for_training
Device and dataset to run the training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ds = ANIDataset("../dev-data/hf-data/dataset/ani-1x/sample.h5")
Verifying format correctness: 0it [00:00, ?it/s]
Verifying format correctness: 36it [00:00, 1726.19it/s]
We pre-batch the dataset so that training is memory efficient while keeping good performance.
batched_dataset_path = Path("./batched_dataset").resolve()
if not batched_dataset_path.exists():
    torchani.datasets.create_batched_dataset(
        ds,
        dest_path=batched_dataset_path,
        batch_size=2560,
        splits={"training": 0.8, "validation": 0.2},
    )
train_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="training")
valid_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="validation")
Dividing dataset in splits with fractions {'training': 0.8, 'validation': 0.2}
Divisions will have sizes:
training: 9253
validation: 2313
training: Collecting packet 1/1: 0%| | 0/6 [00:00<?, ?it/s]
training: Saving packet 1/1: 0%| | 0/4 [00:00<?, ?it/s]
validation: Collecting packet 1/1: 0%| | 0/6 [00:00<?, ?it/s]
validation: Saving packet 1/1: 0%| | 0/1 [00:00<?, ?it/s]
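Each batch on disk is a dict of padded tensors. As a quick sanity check (a sketch that is not part of the original script; it assumes the batched dataset supports indexing), we could peek at the first training batch and print the shapes of its properties:
first_batch = train_ds[0]
for name, tensor in first_batch.items():
    # e.g. "species", "coordinates", "energies", padded over the molecules in the batch
    print(name, tuple(tensor.shape), tensor.dtype)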
We use the PyTorch DataLoader with multiprocessing to load the batches while we train.
For more information about the DataLoader and multiprocessing, see https://pytorch.org/docs/stable/data.html
CACHE keeps all data in memory. It is very memory intensive but faster. Also, memory pinning is performed automatically by ANIBatchedDataset in the CACHE case, so pin_memory should be set to False for the DataLoader.
CACHE: bool = True
if CACHE:
    train_ds = train_ds.cache()
    valid_ds = valid_ds.cache()
training = train_ds.as_dataloader(num_workers=0)
validation = valid_ds.as_dataloader(num_workers=0)
Cacheing training, Warning: this may use a lot of RAM!: 0%| | 0/4 [00:00<?, ?it/s]
Cacheing validation, Warning: this may use a lot of RAM!: 0%| | 0/1 [00:00<?, ?it/s]
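If CACHE were disabled, one could instead rely on DataLoader workers and pinned memory, as the note above suggests. This is a hypothetical alternative configuration (it assumes as_dataloader forwards these keyword arguments to torch.utils.data.DataLoader):
# Hypothetical non-CACHE configuration:
# training = train_ds.as_dataloader(num_workers=2, pin_memory=True)
# validation = valid_ds.as_dataloader(num_workers=2, pin_memory=True)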
We can use the transforms module to modify the batches. The API for transforms is very similar to torchvision's, with the difference that the transforms are always applied to both targets and inputs.
Transforms can be passed to the "transform" argument of ANIBatchedDataset to be performed on-the-fly on the CPU (slow if CACHE is not used).
Transforms can also be applied directly when training, on the GPU.
Transforms can also be applied to a dataset when batching it, by using the inplace_transform argument of create_batched_dataset (be careful, this may be error prone).
In this case we won't apply any transform, but a minimal sketch of what one could look like is shown below.
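As an illustration only (a hedged sketch, not part of this example; it assumes a transform is simply a callable that takes the batch dict of tensors and returns a modified dict, as described above), a trivial transform could look like this:
def cast_to_float32(batch: dict) -> dict:
    # Hypothetical transform: cast floating-point properties (coordinates, energies, ...)
    # to float32, leaving integer tensors such as "species" untouched
    return {k: (v.float() if v.is_floating_point() else v) for k, v in batch.items()}
# Hypothetical usage via the "transform" argument mentioned above:
# train_ds = ANIBatchedDataset(batched_dataset_path, split="training", transform=cast_to_float32)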
Let's generate a model from scratch. For simplicity, we use PyTorch's default random initialization for the weights.
model = simple_ani(("H", "C", "N", "O"), lot="wb97x-631gd", repulsion=True)
Set up the optimizer and the lr scheduler
optimizer = torch.optim.AdamW(
    params=model.neural_networks.parameters(),
    lr=0.5e-3,
    weight_decay=1e-6,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=100,
    threshold=0,
)
We first read the checkpoint files to restart training if they exist. We use latest_training_state.pt
to store the current training state.
latest_training_state_checkpoint_path = Path("./latest_training_state.pt").resolve()
best_model_state_checkpoint_path = Path("./best_model_state.pt").resolve()
if latest_training_state_checkpoint_path.exists():
    checkpoint = torch.load(latest_training_state_checkpoint_path, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    optimizer.load_state_dict(checkpoint["optimizer"])
model.to(dtype=torch.float32, device=device)
ANI(
  (neighborlist): AllPairs()
  (energy_shifter): SelfEnergy()
  (species_converter): SpeciesConverter()
  (potentials): ModuleDict(
    (repulsion_xtb): RepulsionXTB(
      (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
    )
    (nnp): NNPotential(
      (aev_computer): AEVComputer(
        # out_dim=384
        # radial_len=64 (16.67% of feats)
        # angular_len=320 (83.33% of feats)
        num_species=4,
        strategy=pyaev,
        (radial): ANIRadial(
          # num_feats=16
          # num_shifts=16
          eta=19.7000,
          shifts=[0.9000, 1.1688, 1.4375, 1.7062, 1.9750, 2.2438, 2.5125, 2.7812, 3.0500, 3.3187, 3.5875, 3.8563, 4.1250, 4.3938, 4.6625, 4.9313],
          cutoff=5.2000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (angular): ANIAngular(
          # num_feats=32
          # num_shifts=8
          # num_sections=4
          eta=12.5000,
          zeta=14.1000,
          shifts=[0.9000, 1.2250, 1.5500, 1.8750, 2.2000, 2.5250, 2.8500, 3.1750],
          sections=[0.3927, 1.1781, 1.9635, 2.7489],
          cutoff=3.5000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (neighborlist): AllPairs()
      )
      (neural_networks): ANINetworks(
        (atomics): ModuleDict(
          (H): AtomicNetwork(
            layer_dims=(384, 256, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=256, bias=False)
              (1): Linear(in_features=256, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (C): AtomicNetwork(
            layer_dims=(384, 224, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=224, bias=False)
              (1): Linear(in_features=224, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (N): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (O): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
        )
      )
    )
  )
  (_dummy): Potential()
)
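As a quick sanity check (a sketch that is not part of the original script; it assumes the model accepts atomic numbers, which the SpeciesConverter in the printout suggests), the freshly initialized model can already be evaluated on a single molecule:
# A single water-like molecule: atomic numbers and coordinates in Angstrom
species = torch.tensor([[8, 1, 1]], device=device)
coordinates = torch.tensor(
    [[[0.00, 0.00, 0.12], [0.00, 0.76, -0.48], [0.00, -0.76, -0.48]]],
    dtype=torch.float32,
    device=device,
)
print(model((species, coordinates)).energies)  # untrained prediction; energies in Hartree, like the dataset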
During training we validate on the validation set and, if the validation error improves on the best so far, save the new best model to a checkpoint.
def validate(model: ANI, validation: torch.utils.data.DataLoader) -> float:
    squared_error = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            properties = {
                k: v.to(device, non_blocking=True) for k, v in properties.items()
            }
            species = properties["species"]
            coordinates = properties["coordinates"].float()
            target_energies = properties["energies"].float()
            output = model((species, coordinates))
            predicted_energies = output.energies
            squared_error += (predicted_energies - target_energies).pow(2).sum().item()
            count += predicted_energies.shape[0]
    model.train(True)
    rmse = math.sqrt(squared_error / count)
    return hartree2kcalpermol(rmse)
We will also use TensorBoard to visualize the training process.
tensorboard = torch.utils.tensorboard.SummaryWriter()
Criteria for stopping training
max_epochs = 5
min_learning_rate = 1.0e-10
Epoch 0 is right before training starts
if scheduler.last_epoch == 0:
    rmse = validate(model, validation)
    print(f"Before training starts: Validation RMSE (kcal/mol) {rmse}")
    scheduler.step(rmse)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )
Before training starts: Validation RMSE (kcal/mol) 710.9953818821597
Finally, we come to the training loop.
mse = torch.nn.MSELoss(reduction="none")
force_training = False
force_coefficient = 0.1
for epoch in range(scheduler.last_epoch, max_epochs + 1):
    # Stop training if the lr is below a given threshold
    if optimizer.param_groups[0]["lr"] < min_learning_rate:
        break
    # Loop over batches
    for batch in tqdm(
        training,
        total=len(training),
        desc=f"Epoch {epoch}",
    ):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        species = batch["species"]
        coordinates = batch["coordinates"].float()
        target_energies = batch["energies"].float()
        num_atoms = (species >= 0).sum(dim=1, dtype=target_energies.dtype)
        output = model((species, coordinates))
        predicted_energies = output.energies
        if force_training:
            target_forces = batch["forces"].float()
            predicted_forces = forces_for_training(predicted_energies, coordinates)
            energy_loss = (
                mse(predicted_energies, target_energies) / num_atoms.sqrt()
            ).mean()
            force_loss = (
                mse(predicted_forces, target_forces).sum(dim=(1, 2)) / num_atoms
            ).mean()
            loss = energy_loss + force_coefficient * force_loss
        else:
            loss = (mse(predicted_energies, target_energies) / num_atoms.sqrt()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Validate
    rmse = validate(model, validation)
    print(f"After epoch {epoch}: Validation RMSE (kcal/mol) {rmse}")
    # Checkpoint the model if the RMSE improved
    if scheduler.is_better(rmse, scheduler.best):
        torch.save(model.state_dict(), best_model_state_checkpoint_path)
    # Step the epoch-scheduler
    scheduler.step(rmse)
    # Checkpoint the training state
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )
    # Log scalars
    tensorboard.add_scalar("validation_rmse_kcalpermol", rmse, epoch)
    tensorboard.add_scalar("best_validation_rmse_kcalpermol", scheduler.best, epoch)
    tensorboard.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch)
    tensorboard.add_scalar("epoch_loss_square_ha", loss, epoch)
Epoch 1: 0%| | 0/4 [00:00<?, ?it/s]
Epoch 1: 25%|██▌ | 1/4 [00:00<00:00, 6.52it/s]
Epoch 1: 50%|█████ | 2/4 [00:00<00:00, 6.91it/s]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 7.90it/s]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 7.64it/s]
After epoch 1: Validation RMSE (kcal/mol) 585.5371636219849
Epoch 2: 0%| | 0/4 [00:00<?, ?it/s]
Epoch 2: 25%|██▌ | 1/4 [00:00<00:00, 7.15it/s]
Epoch 2: 50%|█████ | 2/4 [00:00<00:00, 7.11it/s]
Epoch 2: 75%|███████▌ | 3/4 [00:00<00:00, 7.10it/s]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 7.82it/s]
After epoch 2: Validation RMSE (kcal/mol) 359.7906486781786
Epoch 3: 0%| | 0/4 [00:00<?, ?it/s]
Epoch 3: 25%|██▌ | 1/4 [00:00<00:00, 6.74it/s]
Epoch 3: 50%|█████ | 2/4 [00:00<00:00, 6.96it/s]
Epoch 3: 75%|███████▌ | 3/4 [00:00<00:00, 6.99it/s]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 7.69it/s]
After epoch 3: Validation RMSE (kcal/mol) 113.57901940424718
Epoch 4: 0%| | 0/4 [00:00<?, ?it/s]
Epoch 4: 25%|██▌ | 1/4 [00:00<00:00, 7.00it/s]
Epoch 4: 50%|█████ | 2/4 [00:00<00:00, 7.08it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00, 7.99it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00, 7.78it/s]
After epoch 4: Validation RMSE (kcal/mol) 269.32655357216015
Epoch 5: 0%| | 0/4 [00:00<?, ?it/s]
Epoch 5: 25%|██▌ | 1/4 [00:00<00:00, 7.14it/s]
Epoch 5: 50%|█████ | 2/4 [00:00<00:00, 7.07it/s]
Epoch 5: 75%|███████▌ | 3/4 [00:00<00:00, 7.12it/s]
Epoch 5: 100%|██████████| 4/4 [00:00<00:00, 7.90it/s]
After epoch 5: Validation RMSE (kcal/mol) 75.81136508301796
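After training finishes (or is interrupted), the best weights saved during the run can be restored for inference. This is a minimal sketch, not part of the original script:
# Load the best checkpoint written during training and switch to evaluation mode
best_state = torch.load(best_model_state_checkpoint_path, weights_only=True)
model.load_state_dict(best_state)
model.train(False)
# The logged scalars can be inspected with `tensorboard --logdir runs`
# (assuming the default SummaryWriter log directory).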