Training an ANI network using a custom script

This example shows how to use TorchANI to train a neural network potential.

import math
from pathlib import Path

import torch
import torch.utils.tensorboard
from tqdm import tqdm

import torchani
from torchani.arch import ANI, simple_ani
from torchani.datasets import ANIDataset, ANIBatchedDataset, BatchedDataset
from torchani.units import hartree2kcalpermol
from torchani.grad import forces_for_training

Device and dataset to run the training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ds = ANIDataset("../dataset/ani-1x/sample.h5")
Verifying format correctness: 36it [00:00, 3180.58it/s]

We prebatch the dataset so that training is memory efficient while keeping good performance.

batched_dataset_path = Path("./batched_dataset").resolve()
if not batched_dataset_path.exists():
    torchani.datasets.create_batched_dataset(
        ds,
        dest_path=batched_dataset_path,
        batch_size=2560,
        splits={"training": 0.8, "validation": 0.2},
    )

train_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="training")
valid_ds: BatchedDataset = ANIBatchedDataset(batched_dataset_path, split="validation")

We use the PyTorch DataLoader with multiprocessing to load the batches while we train.

For more information about the DataLoader and multiprocessing, see https://pytorch.org/docs/stable/data.html

CACHE keeps all batches in memory. This is very memory intensive but faster. When caching, ANIBatchedDataset pins the memory automatically, so pin_memory should be left set to False in the DataLoader.

CACHE: bool = True
if CACHE:
    train_ds = train_ds.cache()
    valid_ds = valid_ds.cache()

training = train_ds.as_dataloader(num_workers=0)
validation = valid_ds.as_dataloader(num_workers=0)
Cacheing training, Warning: this may use a lot of RAM!:   0%|          | 0/4 [00:00<?, ?it/s]

Pinning memory ...

Cacheing validation, Warning: this may use a lot of RAM!:   0%|          | 0/1 [00:00<?, ?it/s]

Pinning memory ...

We can use the transforms module to modify the batches. The API for transforms is very similar to torchvision's, with the difference that the transforms are always applied to both the inputs and the targets.

A transform can be passed to the "transform" argument of ANIBatchedDataset to be applied on-the-fly on the CPU (slow if CACHE is not used).

Transforms can also be applied directly on the GPU during training.

Transforms can also be applied to a dataset while batching it, using the inplace_transform argument of create_batched_dataset (be careful, this may be error prone).

In this case we won't apply any transform.
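
As an illustration only (this example does not use it), here is a minimal sketch of what an on-the-fly transform might look like. It assumes a transform is simply a callable that receives the batch dictionary of properties and returns a modified dictionary; the cast_to_float name is hypothetical, and passing it through the "transform" argument follows the description above.

def cast_to_float(batch):
    # Hypothetical transform: cast floating-point properties to float32,
    # leave integer properties (e.g. species) untouched.
    return {
        k: (v.float() if v.is_floating_point() else v) for k, v in batch.items()
    }

# Assumed usage, mirroring the "transform" argument mentioned above:
# train_ds = ANIBatchedDataset(batched_dataset_path, split="training", transform=cast_to_float)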

Let's generate a model from scratch. For simplicity, we use PyTorch's default random initialization for the weights.

model = simple_ani(("H", "C", "N", "O"), lot="wb97x-631gd", repulsion=True)

Set up the optimizer and the learning-rate scheduler.

optimizer = torch.optim.AdamW(
    params=model.neural_networks.parameters(),
    lr=0.5e-3,
    weight_decay=1e-6,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=100,
    threshold=0,
)

We first read the checkpoint files to restart training. We use latest_training.pt to store the current training state.
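
The checkpoint-handling code itself is not shown in this section; the following is a minimal sketch of what it might look like. The file names are assumptions (only latest_training.pt is named in the prose), and the restore logic simply mirrors the dictionary that the training loop below saves.

# Assumed checkpoint paths; "best_model.pt" is an illustrative name.
best_model_state_checkpoint_path = Path("best_model.pt")
latest_training_state_checkpoint_path = Path("latest_training.pt")

model = model.to(device)  # the model must live on the same device as the batches

# Restore model, optimizer and scheduler state if a previous run left a checkpoint.
if latest_training_state_checkpoint_path.is_file():
    checkpoint = torch.load(latest_training_state_checkpoint_path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])

print(model)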

/home/ipickering/Repos/ani/examples/training.py:93: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file.
  checkpoint = torch.load(latest_training_state_checkpoint_path)

ANI(
  (neighborlist): AllPairs()
  (energy_shifter): SelfEnergy()
  (species_converter): SpeciesConverter()
  (potentials): ModuleDict(
    (repulsion_xtb): RepulsionXTB(
      (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
    )
    (nnp): NNPotential(
      (aev_computer): AEVComputer(
        #  out_dim=384
        #  radial_len=64 (16.67% of feats)
        #  angular_len=320 (83.33% of feats)
        num_species=4,
        strategy=pyaev,
        (radial): ANIRadial(
          #  num_feats=16
          #  num_shifts=16
          eta=19.7000,
          shifts=[0.9000, 1.1688, 1.4375, 1.7062, 1.9750, 2.2438, 2.5125, 2.7812, 3.0500, 3.3187, 3.5875, 3.8563, 4.1250, 4.3938, 4.6625, 4.9313],
          cutoff=5.2000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (angular): ANIAngular(
          #  num_feats=32
          #  num_shifts=8
          #  num_sections=4
          eta=12.5000,
          zeta=14.1000,
          shifts=[0.9000, 1.2250, 1.5500, 1.8750, 2.2000, 2.5250, 2.8500, 3.1750],
          sections=[0.3927, 1.1781, 1.9635, 2.7489],
          cutoff=3.5000,
          (cutoff_fn): CutoffSmooth(order=2, eps=1.0e-10)
        )
        (neighborlist): AllPairs()
      )
      (neural_networks): ANINetworks(
        (atomics): ModuleDict(
          (H): AtomicNetwork(
            layer_dims=(384, 256, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=256, bias=False)
              (1): Linear(in_features=256, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (C): AtomicNetwork(
            layer_dims=(384, 224, 192, 160, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=224, bias=False)
              (1): Linear(in_features=224, out_features=192, bias=False)
              (2): Linear(in_features=192, out_features=160, bias=False)
            )
            (final_layer): Linear(in_features=160, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (N): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
          (O): AtomicNetwork(
            layer_dims=(384, 192, 160, 128, 1),
            activation=GELU(approximate='none'),
            bias=False,
            (layers): ModuleList(
              (0): Linear(in_features=384, out_features=192, bias=False)
              (1): Linear(in_features=192, out_features=160, bias=False)
              (2): Linear(in_features=160, out_features=128, bias=False)
            )
            (final_layer): Linear(in_features=128, out_features=1, bias=False)
            (activation): GELU(approximate='none')
          )
        )
      )
    )
  )
)

During training we validate on the validation set and, if the validation error is better than the best so far, save the new best model to a checkpoint.

def validate(model: ANI, validation: torch.utils.data.DataLoader) -> float:
    squared_error = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            properties = {
                k: v.to(device, non_blocking=True) for k, v in properties.items()
            }
            species = properties["species"]
            coordinates = properties["coordinates"].float()
            target_energies = properties["energies"].float()
            output = model((species, coordinates))
            predicted_energies = output.energies
            squared_error += (predicted_energies - target_energies).pow(2).sum().item()
            count += predicted_energies.shape[0]
    model.train(True)
    rmse = math.sqrt(squared_error / count)
    return hartree2kcalpermol(rmse)

We will also use TensorBoard to visualize the training process.
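
The writer itself is not defined in the snippets above; a minimal sketch, assuming logs are written to a local runs directory (the directory name is illustrative):

# Assumed TensorBoard writer used by the add_scalar calls in the training loop.
tensorboard = torch.utils.tensorboard.SummaryWriter(log_dir="./runs")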

Criteria for stopping training
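
These thresholds are not defined in the snippets above; the values below are illustrative assumptions for the loop further down, which stops after max_epochs epochs or once the learning rate drops below min_learning_rate.

# Illustrative stopping criteria; tune these for a real training run.
max_epochs = 1000
min_learning_rate = 1.0e-5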

Epoch 0 is right before training starts.

if scheduler.last_epoch == 0:
    rmse = validate(model, validation)
    print(f"Before training starts: Validation RMSE (kcal/mol) {rmse}")
    scheduler.step(rmse)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )

Finally, we come to the training loop.

mse = torch.nn.MSELoss(reduction="none")
force_training = False
force_coefficient = 0.1
for epoch in range(scheduler.last_epoch, max_epochs + 1):
    # Stop training if the lr is below a given threshold
    if optimizer.param_groups[0]["lr"] < min_learning_rate:
        break

    # Loop over batches
    for batch in tqdm(
        training,
        total=len(training),
        desc=f"Epoch {epoch}",
    ):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        species = batch["species"]
        coordinates = batch["coordinates"].float()
        target_energies = batch["energies"].float()
        num_atoms = (species >= 0).sum(dim=1, dtype=target_energies.dtype)
        output = model((species, coordinates))
        predicted_energies = output.energies
        if force_training:
            target_forces = batch["forces"].float()
            predicted_forces = forces_for_training(predicted_energies, coordinates)
            energy_loss = (
                mse(predicted_energies, target_energies) / num_atoms.sqrt()
            ).mean()
            force_loss = (
                mse(predicted_forces, target_forces).sum(dim=(1, 2)) / num_atoms
            ).mean()
            loss = energy_loss + force_coefficient * force_loss
        else:
            loss = (mse(predicted_energies, target_energies) / num_atoms.sqrt()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    rmse = validate(model, validation)
    print(f"After epoch {epoch}: Validation RMSE (kcal/mol) {rmse}")

    # Checkpoint the model if the RMSE improved
    if scheduler.is_better(rmse, scheduler.best):
        torch.save(model.state_dict(), best_model_state_checkpoint_path)

    # Step the epoch-scheduler
    scheduler.step(rmse)

    # Checkpoint the training state
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        latest_training_state_checkpoint_path,
    )

    # Log scalars
    tensorboard.add_scalar("validation_rmse_kcalpermol", rmse, epoch)
    tensorboard.add_scalar("best_validation_rmse_kcalpermol", scheduler.best, epoch)
    tensorboard.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch)
    tensorboard.add_scalar("epoch_loss_square_ha", loss, epoch)