Train Neural Network Potential From NeuroChem Input File

This example shows how to use TorchANI’s NeuroChem trainer to read a NeuroChem training config file and use it to train a neural network potential.

To begin with, let’s first import the modules we will use:

import torchani
import torch
import os
import sys
import tqdm

Now let’s set up the paths for the dataset and the NeuroChem input file. Note that these paths assume the script is run from the examples directory of TorchANI’s repository. If you downloaded this script on its own, you must set these paths manually for your system before the script can run successfully. Also note that, for demo purposes, both the training set and the validation set are ani_gdb_s01.h5 from TorchANI’s repository. Because that dataset is very small, the program finishes quickly. However, validating on the training data is wrong and should be avoided in any serious training.

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
cfg_path = os.path.join(path, '../tests/test_data/inputtrain.ipt')
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')  # noqa: E501
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')  # noqa: E501
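Because the paths above depend on where the script is run, it can help to check them before constructing the trainer. The following is a minimal sketch using only the standard library; the helper name `missing_files` is hypothetical and is not part of TorchANI.

```python
import os


def missing_files(*paths):
    """Return the subset of paths that do not point to an existing file."""
    return [p for p in paths if not os.path.isfile(p)]


# Hypothetical usage with the variables defined above:
#     missing = missing_files(cfg_path, training_path, validation_path)
#     if missing:
#         raise FileNotFoundError('Please fix these paths: ' + ', '.join(missing))
```

Failing early with the list of missing files is usually easier to diagnose than an error raised later from inside the data loader.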

We also need to set the device to run the training:

device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_str)


trainer = torchani.neurochem.Trainer(cfg_path, device, True, 'runs')
trainer.load_data(training_path, validation_path)
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1

1/1  [==============================] - 0.0s

2/1  [============================================================] - 0.0s
3/1  [==========================================================================================] - 0.1s
=> loading /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5, total molecules: 1

1/1  [==============================] - 0.0s

2/1  [============================================================] - 0.0s
3/1  [==========================================================================================] - 0.1s

Once everything is set up, running the NeuroChem trainer is very easy: we simply need to call trainer.run(). But here, so that sphinx-gallery can capture the output of tqdm, let’s first do some hacking to make tqdm print its progress bar to stdout.

def my_tqdm(*args, **kwargs):
    return tqdm.tqdm(*args, **kwargs, file=sys.stdout)


trainer.tqdm = my_tqdm
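One caveat of the wrapper above: because it always passes `file=sys.stdout`, a caller that also supplies `file=` would trigger a `TypeError` for a repeated keyword argument. A slightly more defensive variant, sketched below, fills in `file` only when the caller has not set it. The helper name `with_stdout_default` is hypothetical, and the sketch works with any callable that accepts a `file` keyword, not only tqdm.

```python
import sys


def with_stdout_default(progress_factory):
    """Wrap a progress-bar factory so that `file` defaults to sys.stdout."""
    def wrapper(*args, **kwargs):
        # sys.stdout is looked up at call time, so a later redirection of
        # stdout is honored; an explicit `file=` from the caller wins.
        kwargs.setdefault('file', sys.stdout)
        return progress_factory(*args, **kwargs)
    return wrapper


# Hypothetical usage in this example:
#     trainer.tqdm = with_stdout_default(tqdm.tqdm)
```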

Now, let’s go!

trainer.run()
epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 1:  20%|██        | 1/5 [00:00<00:00,  5.42it/s]
epoch 1:  40%|████      | 2/5 [00:00<00:00,  5.45it/s]
epoch 1:  60%|██████    | 3/5 [00:00<00:00,  5.41it/s]
epoch 1:  80%|████████  | 4/5 [00:00<00:00,  5.09it/s]
epoch 1: 100%|██████████| 5/5 [00:00<00:00,  6.13it/s]

epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 2:  20%|██        | 1/5 [00:00<00:00,  5.61it/s]
epoch 2:  40%|████      | 2/5 [00:00<00:00,  5.67it/s]
epoch 2:  60%|██████    | 3/5 [00:00<00:00,  5.53it/s]
epoch 2:  80%|████████  | 4/5 [00:00<00:00,  5.53it/s]
epoch 2: 100%|██████████| 5/5 [00:00<00:00,  6.50it/s]

epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 3:  20%|██        | 1/5 [00:00<00:00,  5.46it/s]
epoch 3:  40%|████      | 2/5 [00:00<00:00,  5.28it/s]
epoch 3:  60%|██████    | 3/5 [00:00<00:00,  5.42it/s]
epoch 3:  80%|████████  | 4/5 [00:00<00:00,  5.36it/s]
epoch 3: 100%|██████████| 5/5 [00:00<00:00,  6.30it/s]

epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 4:  20%|██        | 1/5 [00:00<00:00,  5.09it/s]
epoch 4:  40%|████      | 2/5 [00:00<00:00,  5.22it/s]
epoch 4:  60%|██████    | 3/5 [00:00<00:00,  5.41it/s]
epoch 4:  80%|████████  | 4/5 [00:00<00:00,  5.47it/s]
epoch 4: 100%|██████████| 5/5 [00:00<00:00,  6.33it/s]

epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 5:  20%|██        | 1/5 [00:00<00:00,  5.45it/s]
epoch 5:  40%|████      | 2/5 [00:00<00:00,  5.37it/s]
epoch 5:  60%|██████    | 3/5 [00:00<00:00,  5.38it/s]
epoch 5:  80%|████████  | 4/5 [00:00<00:00,  5.50it/s]
epoch 5: 100%|██████████| 5/5 [00:00<00:00,  6.40it/s]

epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 6:  20%|██        | 1/5 [00:00<00:00,  4.01it/s]
epoch 6:  40%|████      | 2/5 [00:00<00:00,  4.53it/s]
epoch 6:  60%|██████    | 3/5 [00:00<00:00,  4.92it/s]
epoch 6:  80%|████████  | 4/5 [00:00<00:00,  5.18it/s]
epoch 6: 100%|██████████| 5/5 [00:00<00:00,  5.82it/s]

epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 7:  20%|██        | 1/5 [00:00<00:00,  5.57it/s]
epoch 7:  40%|████      | 2/5 [00:00<00:00,  5.62it/s]
epoch 7:  60%|██████    | 3/5 [00:00<00:00,  5.59it/s]
epoch 7:  80%|████████  | 4/5 [00:00<00:00,  5.51it/s]
epoch 7: 100%|██████████| 5/5 [00:00<00:00,  6.49it/s]

epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 8:  20%|██        | 1/5 [00:00<00:00,  5.42it/s]
epoch 8:  40%|████      | 2/5 [00:00<00:00,  5.41it/s]
epoch 8:  60%|██████    | 3/5 [00:00<00:00,  5.45it/s]
epoch 8:  80%|████████  | 4/5 [00:00<00:00,  5.44it/s]
epoch 8: 100%|██████████| 5/5 [00:00<00:00,  6.35it/s]

epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 9:  20%|██        | 1/5 [00:00<00:00,  5.52it/s]
epoch 9:  40%|████      | 2/5 [00:00<00:00,  5.52it/s]
epoch 9:  60%|██████    | 3/5 [00:00<00:00,  5.54it/s]
epoch 9:  80%|████████  | 4/5 [00:00<00:00,  5.57it/s]
epoch 9: 100%|██████████| 5/5 [00:00<00:00,  6.50it/s]

epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 10:  20%|██        | 1/5 [00:00<00:00,  5.55it/s]
epoch 10:  40%|████      | 2/5 [00:00<00:00,  5.45it/s]
epoch 10:  60%|██████    | 3/5 [00:00<00:00,  5.48it/s]
epoch 10:  80%|████████  | 4/5 [00:00<00:00,  5.57it/s]
epoch 10: 100%|██████████| 5/5 [00:00<00:00,  6.47it/s]

epoch 11:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 11:  20%|██        | 1/5 [00:00<00:00,  4.89it/s]
epoch 11:  40%|████      | 2/5 [00:00<00:00,  5.24it/s]
epoch 11:  60%|██████    | 3/5 [00:00<00:00,  5.36it/s]
epoch 11:  80%|████████  | 4/5 [00:00<00:00,  5.31it/s]
epoch 11: 100%|██████████| 5/5 [00:00<00:00,  6.20it/s]

Alternatively, you can run the NeuroChem trainer directly from the command line, with no programming needed. Run python -m torchani.neurochem.trainer -h for usage help. For this demo, the equivalent command is:

cmd = ['python', '-m', 'torchani.neurochem.trainer', '-d', device_str,
       '--tqdm', '--tensorboard', 'runs', cfg_path, training_path,
       validation_path]
print(' '.join(cmd))
python -m torchani.neurochem.trainer -d cpu --tqdm --tensorboard runs /home/runner/work/torchani/torchani/examples/../tests/test_data/inputtrain.ipt /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 /home/runner/work/torchani/torchani/examples/../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
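Two details of the command above may matter outside this demo: `'python'` resolves to whichever interpreter comes first on `PATH`, and joining with plain spaces produces a line that cannot be pasted into a shell if a path contains spaces. The sketch below, which assumes Python 3.8 or newer for `shlex.join`, uses `sys.executable` and shell-quotes the printed form. The function name `build_trainer_command` and the placeholder paths are hypothetical; the flags are the ones shown above.

```python
import shlex
import sys


def build_trainer_command(device_str, cfg_path, training_path, validation_path):
    """Build the argument list for the command-line trainer used above."""
    return [sys.executable, '-m', 'torchani.neurochem.trainer',
            '-d', device_str, '--tqdm', '--tensorboard', 'runs',
            cfg_path, training_path, validation_path]


# The paths here are placeholders for illustration only.
example_cmd = build_trainer_command('cpu', 'my config/inputtrain.ipt',
                                    'train.h5', 'valid.h5')
print(shlex.join(example_cmd))  # arguments containing spaces are quoted
```

Using `sys.executable` keeps the subprocess in the same environment as the calling script, which avoids picking up a different Python installation that may not have TorchANI installed.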

Now let’s invoke this command to see what we get. Again, we capture stderr and print it to stdout simply so that sphinx-gallery can capture it when generating this document:

from subprocess import Popen, PIPE  # noqa: E402
print(Popen(cmd, stderr=PIPE).stderr.read().decode('utf-8'))
/opt/hostedtoolcache/Python/3.8.17/x64/lib/python3.8/site-packages/torchani-2.2.3-py3.8.egg/torchani/aev.py:16: UserWarning: cuaev not installed
  warnings.warn("cuaev not installed")

epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 1:  20%|██        | 1/5 [00:00<00:00,  5.47it/s]
epoch 1:  40%|████      | 2/5 [00:00<00:00,  5.43it/s]
epoch 1:  60%|██████    | 3/5 [00:00<00:00,  5.60it/s]
epoch 1:  80%|████████  | 4/5 [00:00<00:00,  5.61it/s]
epoch 1: 100%|██████████| 5/5 [00:00<00:00,  6.52it/s]

epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 2:  20%|██        | 1/5 [00:00<00:00,  5.76it/s]
epoch 2:  40%|████      | 2/5 [00:00<00:00,  5.56it/s]
epoch 2:  60%|██████    | 3/5 [00:00<00:00,  5.61it/s]
epoch 2:  80%|████████  | 4/5 [00:00<00:00,  5.66it/s]
epoch 2: 100%|██████████| 5/5 [00:00<00:00,  6.61it/s]

epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 3:  20%|██        | 1/5 [00:00<00:00,  5.18it/s]
epoch 3:  40%|████      | 2/5 [00:00<00:00,  5.46it/s]
epoch 3:  60%|██████    | 3/5 [00:00<00:00,  5.51it/s]
epoch 3:  80%|████████  | 4/5 [00:00<00:00,  5.43it/s]
epoch 3: 100%|██████████| 5/5 [00:00<00:00,  6.37it/s]

epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 4:  20%|██        | 1/5 [00:00<00:00,  5.13it/s]
epoch 4:  40%|████      | 2/5 [00:00<00:00,  5.35it/s]
epoch 4:  60%|██████    | 3/5 [00:00<00:00,  5.46it/s]
epoch 4:  80%|████████  | 4/5 [00:00<00:00,  5.49it/s]
epoch 4: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]

epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 5:  20%|██        | 1/5 [00:00<00:00,  5.51it/s]
epoch 5:  40%|████      | 2/5 [00:00<00:00,  5.65it/s]
epoch 5:  60%|██████    | 3/5 [00:00<00:00,  5.62it/s]
epoch 5:  80%|████████  | 4/5 [00:00<00:00,  5.68it/s]
epoch 5: 100%|██████████| 5/5 [00:00<00:00,  6.61it/s]

epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 6:  20%|██        | 1/5 [00:00<00:00,  5.76it/s]
epoch 6:  40%|████      | 2/5 [00:00<00:00,  5.64it/s]
epoch 6:  60%|██████    | 3/5 [00:00<00:00,  5.23it/s]
epoch 6:  80%|████████  | 4/5 [00:00<00:00,  5.09it/s]
epoch 6: 100%|██████████| 5/5 [00:00<00:00,  6.15it/s]

epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 7:  20%|██        | 1/5 [00:00<00:00,  5.76it/s]
epoch 7:  40%|████      | 2/5 [00:00<00:00,  5.60it/s]
epoch 7:  60%|██████    | 3/5 [00:00<00:00,  5.69it/s]
epoch 7:  80%|████████  | 4/5 [00:00<00:00,  5.66it/s]
epoch 7: 100%|██████████| 5/5 [00:00<00:00,  6.63it/s]

epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 8:  20%|██        | 1/5 [00:00<00:00,  5.44it/s]
epoch 8:  40%|████      | 2/5 [00:00<00:00,  5.65it/s]
epoch 8:  60%|██████    | 3/5 [00:00<00:00,  5.67it/s]
epoch 8:  80%|████████  | 4/5 [00:00<00:00,  5.61it/s]
epoch 8: 100%|██████████| 5/5 [00:00<00:00,  6.57it/s]

epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 9:  20%|██        | 1/5 [00:00<00:00,  5.73it/s]
epoch 9:  40%|████      | 2/5 [00:00<00:00,  5.77it/s]
epoch 9:  60%|██████    | 3/5 [00:00<00:00,  5.73it/s]
epoch 9:  80%|████████  | 4/5 [00:00<00:00,  5.74it/s]
epoch 9: 100%|██████████| 5/5 [00:00<00:00,  6.59it/s]

epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]
epoch 10:  20%|██        | 1/5 [00:00<00:00,  5.62it/s]
epoch 10:  40%|████      | 2/5 [00:00<00:00,  5.59it/s]
epoch 10:  60%|██████    | 3/5 [00:00<00:00,  5.53it/s]
epoch 10:  80%|████████  | 4/5 [00:00<00:00,  5.41it/s]
epoch 10: 100%|██████████| 5/5 [00:00<00:00,  6.42it/s]
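The `Popen` call above reads only the child's stderr, which is where tqdm writes by default. If both streams should end up in one captured string, one option is to merge stderr into stdout in the child process. The following sketch uses `subprocess.run` from the standard library; the helper name `run_and_capture` is hypothetical, and the child command is a stand-in rather than the trainer.

```python
import subprocess
import sys


def run_and_capture(cmd):
    """Run cmd and return its stdout and stderr merged into one string."""
    result = subprocess.run(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, text=True)
    return result.stdout


# Stand-in child process that writes one line to each stream.
demo_cmd = [sys.executable, '-c',
            "import sys; print('to stdout'); print('to stderr', file=sys.stderr)"]
output = run_and_capture(demo_cmd)
print(output)
```

Note that the relative order of the two streams in the merged output depends on the child's buffering, so it should not be relied upon.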

Total running time of the script: (0 minutes 28.653 seconds)
