torchani.sae_estimation
Functions to calculate self atomic energies (SAEs) via linear regression.
It is recommended to use GSAEs (Ground State Atomic Energies) for new models instead, so that models predict atomization energies.
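As a sketch of the linear model that SAE regression conventionally fits, each conformation's energy is written as a sum of per-element contributions. This page does not spell out the exact objective, so the least-squares formulation below is an assumption:

```latex
E_m \approx \sum_{e \in \mathcal{S}} n_{m,e}\,\varepsilon_e + b,
\qquad
(\hat{\varepsilon}, \hat{b}) = \arg\min_{\varepsilon,\, b} \sum_{m} \Big( E_m - \sum_{e \in \mathcal{S}} n_{m,e}\,\varepsilon_e - b \Big)^2
```

Here $E_m$ is the energy of conformation $m$, $\mathcal{S}$ is the set of requested chemical symbols, $n_{m,e}$ is the number of atoms of element $e$ in conformation $m$, $\varepsilon_e$ is the SAE of element $e$, and $b$ is an intercept that is fixed to $0$ unless `fit_intercept=True`.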
Functions
exact_saes | Calculate SAEs of a dataset
approx_saes | Calculate SAEs of a dataset in an approximate manner, using SGD
- torchani.sae_estimation.exact_saes(dataset, symbols, fraction=1.0, fit_intercept=False, device=None)
Calculate SAEs of a dataset
Given a torchani.datasets.BatchedDataset, this function calculates the associated SAEs with a linear least-squares fit of the energies to the per-element atom counts.
- Parameters:
dataset (BatchedDataset) – Batched dataset to use
symbols (Sequence[str]) – A tuple or list of strings that are valid chemical symbols (case sensitive).
fraction (float) – Fraction of the dataset to use.
fit_intercept (bool) – Whether to fit an intercept term, so that the multilinear regression is not forced through zero.
device (device | str | int | None) – Device to use for tensors
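The pure-Python sketch below illustrates the kind of multilinear regression that exact_saes performs, using plain lists of element counts and energies instead of a BatchedDataset. The function name fit_saes_exact, its arguments, and the helper _solve are hypothetical and not part of torchani; solving the normal equations directly is adequate for a handful of elements but is not meant to reflect torchani's internal implementation.

```python
from typing import Dict, List, Sequence, Tuple


def _solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve the square system a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix [a | b]
    for col in range(n):
        # Swap in the row with the largest pivot to limit round-off error
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("singular system: element counts are linearly dependent")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):  # back substitution
        x[i] = (m[i][n] - sum(m[i][j] * x[j] for j in range(i + 1, n))) / m[i][i]
    return x


def fit_saes_exact(
    counts: Sequence[Sequence[float]],
    energies: Sequence[float],
    symbols: Sequence[str],
    fit_intercept: bool = False,
) -> Tuple[Dict[str, float], float]:
    """Hypothetical helper (not torchani's API): least-squares fit of
    E_m ~ sum_e counts[m][e] * sae[e] (+ intercept if fit_intercept)."""
    # Design matrix X: one column per element, plus a column of ones for the intercept
    rows = [[float(n) for n in r] + ([1.0] if fit_intercept else []) for r in counts]
    k = len(rows[0])
    # Normal equations (X^T X) w = X^T E
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(k)] for i in range(k)]
    xte = [sum(r[i] * e for r, e in zip(rows, energies)) for i in range(k)]
    w = _solve(xtx, xte)
    saes = dict(zip(symbols, w[: len(symbols)]))
    intercept = w[-1] if fit_intercept else 0.0
    return saes, intercept


# Synthetic, noise-free example: columns are counts of H, C, O
symbols = ("H", "C", "O")
counts = [(2, 0, 0), (4, 1, 0), (2, 0, 1), (0, 1, 2), (4, 1, 1)]  # H2, CH4, H2O, CO2, CH3OH
true_saes = {"H": -0.5, "C": -37.8, "O": -75.0}  # made-up values for illustration
energies = [sum(n * true_saes[s] for n, s in zip(r, symbols)) for r in counts]
saes, intercept = fit_saes_exact(counts, energies, symbols)
print(saes, intercept)
```

With noise-free synthetic energies the fit recovers the SAEs used to generate them. With fit_intercept=True, a constant offset added to every energy is absorbed by the intercept instead of being spread across the SAEs.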
- torchani.sae_estimation.approx_saes(dataset, symbols, fraction=1.0, fit_intercept=False, device=None, max_epochs=1, lr=0.01)
Calculate SAEs of a dataset in an approximate manner, using SGD
Given a torchani.datasets.BatchedDataset, this function estimates the associated SAEs using stochastic gradient descent (SGD), which approximates the linear regression solution instead of computing it exactly.
- Parameters:
dataset (BatchedDataset) – Batched dataset to use
symbols (Sequence[str]) – A tuple or list of strings that are valid chemical symbols (case sensitive).
fraction (float) – Fraction of the dataset to use.
fit_intercept (bool) – Whether to fit an intercept term, so that the multilinear regression is not forced through zero.
device (device | str | int | None) – Device to use for tensors
max_epochs (int) – Maximum number of epochs
lr (float) – Learning rate
verbose – Whether to print detailed info to stdout.
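A minimal sketch of the SGD approach, assuming the same linear model as an exact fit and working on plain lists rather than a BatchedDataset. The function fit_saes_sgd and its seed argument are hypothetical and not torchani's API. This sketch takes per-sample gradient steps on the squared residual, whereas torchani presumably iterates over the batches of a BatchedDataset; max_epochs and lr play the same roles as in approx_saes.

```python
import random
from typing import Dict, Sequence, Tuple


def fit_saes_sgd(
    counts: Sequence[Sequence[float]],
    energies: Sequence[float],
    symbols: Sequence[str],
    fit_intercept: bool = False,
    max_epochs: int = 1,
    lr: float = 0.01,
    seed: int = 0,
) -> Tuple[Dict[str, float], float]:
    """Hypothetical helper (not torchani's API): approximate SAEs by minimizing
    0.5 * (x_m . w - E_m)**2 with per-sample stochastic gradient descent."""
    # Feature vector per conformation: element counts, plus 1.0 for the intercept
    rows = [[float(n) for n in r] + ([1.0] if fit_intercept else []) for r in counts]
    w = [0.0] * len(rows[0])  # start all SAEs (and the intercept) at zero
    order = list(range(len(rows)))
    rng = random.Random(seed)  # fixed seed so runs are reproducible
    for _ in range(max_epochs):
        rng.shuffle(order)  # visit conformations in a new random order each epoch
        for m in order:
            x = rows[m]
            residual = sum(wi * xi for wi, xi in zip(w, x)) - energies[m]
            # Gradient of 0.5 * residual**2 with respect to w is residual * x
            for i, xi in enumerate(x):
                w[i] -= lr * residual * xi
    saes = dict(zip(symbols, w[: len(symbols)]))
    intercept = w[-1] if fit_intercept else 0.0
    return saes, intercept


# Synthetic, noise-free example: columns are counts of H, C, O
symbols = ("H", "C", "O")
counts = [(2, 0, 0), (4, 1, 0), (2, 0, 1), (0, 1, 2), (4, 1, 1)]  # H2, CH4, H2O, CO2, CH3OH
true_saes = {"H": -0.5, "C": -37.8, "O": -75.0}  # made-up values for illustration
energies = [sum(n * true_saes[s] for n, s in zip(r, symbols)) for r in counts]
saes, intercept = fit_saes_sgd(counts, energies, symbols, max_epochs=5000, lr=0.02)
print(saes, intercept)
```

In this sketch the number of epochs and the learning rate control how close the result gets to the least-squares solution: on the small synthetic dataset above, a single epoch at lr=0.01 leaves the SAEs far from their true values, while 5000 epochs at lr=0.02 recover them closely. Too large a learning rate makes the per-sample updates diverge.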