torchani.sae_estimation

Functions to calculate self atomic energies (SAEs) via linear regression.

For new models, it is recommended to use GSAEs (ground state atomic energies) instead, so that the models predict atomization energies.
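In our own notation (not torchani's), the linear regression mentioned above fits each molecule's total energy E_m to its per-element atom counts n_{m,Z}, with one SAE ε_Z per element Z and an intercept b that is fixed to 0 unless fit_intercept=True:

```latex
E_m \approx \sum_{Z} n_{m,Z}\,\varepsilon_Z + b,
\qquad
(\hat{\varepsilon}, \hat{b}) = \operatorname*{arg\,min}_{\varepsilon,\, b}
\sum_{m} \Big( E_m - \sum_{Z} n_{m,Z}\,\varepsilon_Z - b \Big)^2
```

In the usual ANI setup, the sum of per-element energies is subtracted from E_m and the network learns the remainder. With GSAEs, the subtracted per-element values are ground state atomic energies rather than fitted ones, so the remainder equals the atomization energy up to sign.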

Functions

  • exact_saes – Calculate SAEs of a dataset.

  • approx_saes – Calculate SAEs of a dataset approximately, using stochastic gradient descent (SGD).

torchani.sae_estimation.exact_saes(dataset, symbols, fraction=1.0, fit_intercept=False, device=None)

Calculate SAEs of a dataset

Given a torchani.datasets.BatchedDataset instance, this function calculates the associated SAEs by linear regression.

Parameters:
  • dataset (BatchedDataset) – Batched dataset to use

  • symbols (Sequence[str]) – A tuple or list of valid chemical symbols (case sensitive).

  • fraction (float) – Fraction of the dataset to use.

  • fit_intercept (bool) – Whether to fit an intercept term, i.e. allow the multilinear regression not to pass through zero.

  • device (device | str | int | None) – Device to use for tensors
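The regression behind exact_saes can be illustrated with a standalone NumPy sketch. It does not use torchani: molecules are plain lists of chemical symbols rather than a BatchedDataset, and the helper names count_matrix and exact_saes_sketch are hypothetical. Taking the leading fraction of molecules (instead of, say, a random subset) and solving with numpy.linalg.lstsq are assumptions, not necessarily what torchani does internally.

```python
import numpy as np


def count_matrix(molecules, symbols):
    """Hypothetical helper: (n_molecules, n_symbols) matrix of atom counts per element.

    Symbol lookup is case sensitive; an unknown symbol raises KeyError.
    """
    index = {s: i for i, s in enumerate(symbols)}
    counts = np.zeros((len(molecules), len(symbols)))
    for row, mol in enumerate(molecules):
        for sym in mol:
            counts[row, index[sym]] += 1
    return counts


def exact_saes_sketch(molecules, energies, symbols, fraction=1.0, fit_intercept=False):
    """Hypothetical sketch: least-squares SAEs from total energies and element counts."""
    # Assumption: use the leading `fraction` of the molecules (at least one).
    n = max(1, int(len(molecules) * fraction))
    X = count_matrix(molecules[:n], symbols)
    y = np.asarray(energies[:n], dtype=float)
    if fit_intercept:
        # A column of ones lets the fit absorb a constant offset b.
        X = np.hstack([X, np.ones((n, 1))])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    if fit_intercept:
        return coef[:-1], coef[-1]
    return coef, None
```

With fit_intercept=True, the extra column of ones absorbs a constant energy offset, so the fitted SAEs do not have to account for it.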

torchani.sae_estimation.approx_saes(dataset, symbols, fraction=1.0, fit_intercept=False, device=None, max_epochs=1, lr=0.01)

Calculate SAEs of a dataset in an approximate manner, using SGD

Given a torchani.datasets.BatchedDataset instance, this function calculates the associated SAEs using stochastic gradient descent.

Parameters:
  • dataset (BatchedDataset) – Batched dataset to use

  • symbols (Sequence[str]) – A tuple or list of valid chemical symbols (case sensitive).

  • fraction (float) – Fraction of the dataset to use.

  • fit_intercept (bool) – Whether to fit an intercept term, i.e. allow the multilinear regression not to pass through zero.

  • device (device | str | int | None) – Device to use for tensors

  • max_epochs (int) – Maximum number of epochs

  • lr (float) – Learning rate

  • verbose – Whether to print detailed info to stdout.
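For approx_saes, the same linear model of energy versus element counts can instead be fit by mini-batch stochastic gradient descent on the mean squared error. The sketch below is a hypothetical NumPy illustration, not torchani's implementation: it takes a precomputed count matrix, and the batch_size and seed parameters, the zero initialization, and the choice of loss are assumptions. Only max_epochs and lr mirror the documented parameters.

```python
import numpy as np


def approx_saes_sketch(counts, energies, fit_intercept=False, max_epochs=1, lr=0.01,
                       batch_size=2, seed=0):
    """Hypothetical sketch: fit SAEs by mini-batch SGD on the mean squared error."""
    rng = np.random.default_rng(seed)
    X = np.asarray(counts, dtype=float)
    y = np.asarray(energies, dtype=float)
    n, k = X.shape
    w = np.zeros(k)  # one SAE per element, initialized at zero (assumption)
    b = 0.0          # intercept, only updated when fit_intercept=True
    for _ in range(max_epochs):
        order = rng.permutation(n)  # reshuffle molecules every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = X[idx] @ w + b - y[idx]  # residuals on this mini-batch
            # Gradients of mean((X w + b - y) ** 2) with respect to w and b.
            grad_w = 2.0 * (X[idx].T @ err) / len(idx)
            w -= lr * grad_w
            if fit_intercept:
                b -= lr * 2.0 * err.mean()
    return w, (b if fit_intercept else None)
```

Unlike the exact solve, the result depends on lr, max_epochs and the shuffling order, so it only approximates the least-squares solution unless it runs long enough to converge.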