Training an ANI network using a custom script

This example shows how to use TorchANI to train a neural network potential.

import math
from pathlib import Path

import torch
import torch.utils.tensorboard
from tqdm import tqdm

import torchani
from torchani.arch import ANI, simple_ani
from torchani.datasets import ANIDataset, ANIBatchedDataset, BatchedDataset
from torchani.units import hartree2kcalpermol
from torchani.grad import forces_for_training

Device and dataset to run the training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ds = ANIDataset("../dev-data/hf-data/dataset/ani-1x/sample.h5")
Verifying format correctness: 0it [00:00, ?it/s]
Verifying format correctness: 36it [00:00, 1726.19it/s]

We pre-batch the dataset so that training is memory efficient while keeping good performance.

batched_dataset_path = Path("./batched_dataset").resolve()
if not batched_dataset_path.exists():
    torchani.datasets.create_batched_dataset(
        ds,
        dest_path=batched_dataset_path,
        batch_size=2560,
        splits={"training": 0.8, "validation": 0.2},
    )

train_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="training")
valid_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="validation")
Dividing dataset in splits with fractions {'training': 0.8, 'validation': 0.2}
Divisions will have sizes:
    training: 9253
    validation: 2313

training: Collecting packet 1/1:   0%|          | 0/6 [00:00<?, ?it/s]


training: Saving packet 1/1:   0%|          | 0/4 [00:00<?, ?it/s]


validation: Collecting packet 1/1:   0%|          | 0/6 [00:00<?, ?it/s]


validation: Saving packet 1/1:   0%|          | 0/1 [00:00<?, ?it/s]

We use the PyTorch DataLoader with multiprocessing to load the batches while we train.

For more information about the DataLoader and multiprocessing, see https://pytorch.org/docs/stable/data.html

CACHE keeps all data in memory. It is very memory intensive but faster. Also, pin_memory is handled automatically by ANIBatchedDataset in the CACHE case, so it should be set to False for the DataLoader.
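
For reference, if the dataset is not cached, multiprocess loading with pinned memory could be enabled as in the sketch below. This assumes as_dataloader forwards num_workers and pin_memory to the underlying torch.utils.data.DataLoader; the worker count is illustrative.

# Sketch only (not used in this tutorial): loaders for a non-cached dataset,
# assuming as_dataloader forwards these keyword args to torch.utils.data.DataLoader
# training = train_ds.as_dataloader(num_workers=2, pin_memory=True)
# validation = valid_ds.as_dataloader(num_workers=2, pin_memory=True)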

CACHE: bool = True
if CACHE:
    train_ds = train_ds.cache()
    valid_ds = valid_ds.cache()

training = train_ds.as_dataloader(num_workers=0)
validation = valid_ds.as_dataloader(num_workers=0)
Cacheing training, Warning: this may use a lot of RAM!:   0%|          | 0/4 [00:00<?, ?it/s]


Cacheing validation, Warning: this may use a lot of RAM!:   0%|          | 0/1 [00:00<?, ?it/s]

We can use the transforms module to modify the batches. The API for transforms is very similar to torchvision's, with the difference that the transforms are always applied to both the targets and the inputs.

Transforms can be passed to the "transform" argument of ANIBatchedDataset to be performed on-the-fly on the CPU (slow if CACHE is not used).

Transforms can also be applied directly on the GPU during training.

Transforms can also be applied to a dataset when batching it, using the inplace_transform argument of create_batched_dataset (be careful, this may be error prone).

In this case we won't apply any transform.
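
As an illustration only, a transform can be thought of as a callable that takes the dict of batch properties and returns a modified dict. The cast_to_float32 function below is hypothetical and not part of the TorchANI API:

def cast_to_float32(properties):
    # Hypothetical transform: cast floating point tensors (coordinates, energies, ...)
    # to float32 and leave integer tensors (species) untouched
    return {k: v.float() if v.is_floating_point() else v for k, v in properties.items()}

# It could then be applied on-the-fly when loading batches, e.g.:
# ANIBatchedDataset(batched_dataset_path, split="training", transform=cast_to_float32)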

Let's generate a model from scratch. For simplicity we use PyTorch's default random initialization for the weights.

model = simple_ani(("H", "C", "N", "O"), lot="wb97x-631gd", repulsion=True)

Set up the optimizer and the learning rate scheduler.

optimizer = torch.optim.AdamW(
    params=model.neural_networks.parameters(),
    lr=0.5e-3,
    weight_decay=1e-6,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=100,
    threshold=0,
)

We first read the checkpoint files to restart training. We use latest_training_state.pt to store the current training state.

latest_training_state_checkpoint_path = Path("./latest_training_state.pt").resolve()
best_model_state_checkpoint_path = Path("./best_model_state.pt").resolve()
if latest_training_state_checkpoint_path.exists():
    checkpoint = torch.load(latest_training_state_checkpoint_path, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    optimizer.load_state_dict(checkpoint["optimizer"])

model.to(dtype=torch.float32, device=device)
ANI(
  (neighborlist): AllPairs()
  (energy_shifter): SelfEnergy()
  (species_converter): SpeciesConverter()
  (potentials): ModuleDict(
    (repulsion_xtb): RepulsionXTB(
      (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
    )
    (nnp): NNPotential(
      (aev_computer): AEVComputer(
        #  out_dim=384
        #  radial_len=64 (16.67% of feats)
        #  angular_len=320 (83.33% of feats)
        num_species=4,
        strategy=pyaev,
        (radial): ANIRadial(
          #  num_feats=16
          #  num_shifts=16
          eta=19.7000,
          shifts=[0.9000, 1.1688, 1.4375, 1.7062, 1.9750, 2.2438, 2.5125, 2.7812, 3.0500, 3.3187, 3.5875, 3.8563, 4.1250, 4.3938, 4.6625, 4.9313],
          cutoff=5.2000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (angular): ANIAngular(
          #  num_feats=32
          #  num_shifts=8
          #  num_sections=4
          eta=12.5000,
          zeta=14.1000,
          shifts=[0.9000, 1.2250, 1.5500, 1.8750, 2.2000, 2.5250, 2.8500, 3.1750],
          sections=[0.3927, 1.1781, 1.9635, 2.7489],
          cutoff=3.5000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (neighborlist): AllPairs()
      )
      (neural_networks): ANINetworks(
        (atomics): ModuleDict(
          (H): AtomicNetwork(
            layer_dims=(384, 256, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=256, bias=False)
              (1): Linear(in_features=256, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (C): AtomicNetwork(
            layer_dims=(384, 224, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=224, bias=False)
              (1): Linear(in_features=224, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (N): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (O): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
        )
      )
    )
  )
  (_dummy): Potential()
)

During training we validate on the validation set and, if the validation error improves on the best so far, save the new best model to a checkpoint.

def validate(model: ANI, validation: torch.utils.data.DataLoader) -> float:
    squared_error = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            properties = {
                k: v.to(device, non_blocking=True) for k, v in properties.items()
            }
            species = properties["species"]
            coordinates = properties["coordinates"].float()
            target_energies = properties["energies"].float()
            output = model((species, coordinates))
            predicted_energies = output.energies
            squared_error += (predicted_energies - target_energies).pow(2).sum().item()
            count += predicted_energies.shape[0]
    model.train(True)
    rmse = math.sqrt(squared_error / count)
    return hartree2kcalpermol(rmse)

We will also use TensorBoard to visualize our training process
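
The SummaryWriter below is a minimal setup; the log directory name is arbitrary.

tensorboard = torch.utils.tensorboard.SummaryWriter(log_dir="./runs")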

Criteria for stopping training
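
The exact stopping thresholds are not shown in this snippet; the values below are illustrative.

max_epochs = 5  # illustrative value
min_learning_rate = 1.0e-5  # stop once the learning rate decays below this (illustrative)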

Epoch 0 is right before training starts

if scheduler.last_epoch == 0:
    rmse = validate(model, validation)
    print(f"Before training starts: Validation RMSE (kcal/mol) {rmse}")
    scheduler.step(rmse)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )
Before training starts: Validation RMSE (kcal/mol) 710.9953818821597

Finally, we come to the training loop.

mse = torch.nn.MSELoss(reduction="none")
force_training = False
force_coefficient = 0.1
for epoch in range(scheduler.last_epoch, max_epochs + 1):
    # Stop training if the lr is below a given threshold
    if optimizer.param_groups[0]["lr"] < min_learning_rate:
        break

    # Loop over batches
    for batch in tqdm(
        training,
        total=len(training),
        desc=f"Epoch {epoch}",
    ):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        species = batch["species"]
        coordinates = batch["coordinates"].float()
        target_energies = batch["energies"].float()
        num_atoms = (species >= 0).sum(dim=1, dtype=target_energies.dtype)
        output = model((species, coordinates))
        predicted_energies = output.energies
        if force_training:
            target_forces = batch["forces"].float()
            predicted_forces = forces_for_training(predicted_energies, coordinates)
            energy_loss = (
                mse(predicted_energies, target_energies) / num_atoms.sqrt()
            ).mean()
            force_loss = (
                mse(predicted_forces, target_forces).sum(dim=(1, 2)) / num_atoms
            ).mean()
            loss = energy_loss + force_coefficient * force_loss
        else:
            loss = (mse(predicted_energies, target_energies) / num_atoms.sqrt()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    rmse = validate(model, validation)
    print(f"After epoch {epoch}: Validation RMSE (kcal/mol) {rmse}")

    # Checkpoint the model if the RMSE improved
    if scheduler.is_better(rmse, scheduler.best):
        torch.save(model.state_dict(), best_model_state_checkpoint_path)

    # Step the epoch-scheduler
    scheduler.step(rmse)

    # Checkpoint the training state
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )

    # Log scalars
    tensorboard.add_scalar("validation_rmse_kcalpermol", rmse, epoch)
    tensorboard.add_scalar("best_validation_rmse_kcalpermol", scheduler.best, epoch)
    tensorboard.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch)
    tensorboard.add_scalar("epoch_loss_square_ha", loss, epoch)
Epoch 1:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 1:  25%|██▌       | 1/4 [00:00<00:00,  6.52it/s]
Epoch 1:  50%|█████     | 2/4 [00:00<00:00,  6.91it/s]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00,  7.90it/s]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00,  7.64it/s]
After epoch 1: Validation RMSE (kcal/mol) 585.5371636219849

Epoch 2:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 2:  25%|██▌       | 1/4 [00:00<00:00,  7.15it/s]
Epoch 2:  50%|█████     | 2/4 [00:00<00:00,  7.11it/s]
Epoch 2:  75%|███████▌  | 3/4 [00:00<00:00,  7.10it/s]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00,  7.82it/s]
After epoch 2: Validation RMSE (kcal/mol) 359.7906486781786

Epoch 3:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 3:  25%|██▌       | 1/4 [00:00<00:00,  6.74it/s]
Epoch 3:  50%|█████     | 2/4 [00:00<00:00,  6.96it/s]
Epoch 3:  75%|███████▌  | 3/4 [00:00<00:00,  6.99it/s]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00,  7.69it/s]
After epoch 3: Validation RMSE (kcal/mol) 113.57901940424718

Epoch 4:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 4:  25%|██▌       | 1/4 [00:00<00:00,  7.00it/s]
Epoch 4:  50%|█████     | 2/4 [00:00<00:00,  7.08it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00,  7.99it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00,  7.78it/s]
After epoch 4: Validation RMSE (kcal/mol) 269.32655357216015

Epoch 5:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 5:  25%|██▌       | 1/4 [00:00<00:00,  7.14it/s]
Epoch 5:  50%|█████     | 2/4 [00:00<00:00,  7.07it/s]
Epoch 5:  75%|███████▌  | 3/4 [00:00<00:00,  7.12it/s]
Epoch 5: 100%|██████████| 4/4 [00:00<00:00,  7.90it/s]
After epoch 5: Validation RMSE (kcal/mol) 75.81136508301796