SO(3) Lie Group Operations
The module nnp.so3 contains tools to rotate point clouds in 3D space.
Let’s first import all the packages we will use:
import torch
from scipy.linalg import expm as scipy_expm
from torch import Tensor
The following function implements rotation about an axis passing through the origin. Rotation in 3D is less trivial than in 2D, so we start from the equation of motion. Let \(\vec{n}\) be the unit vector pointing along the axis. Then for an infinitesimal rotation by \(d\theta\) we have

\[\frac{d\vec{r}}{d\theta} = \vec{n} \times \vec{r}\]

That is, in components,

\[\frac{dr_i}{d\theta} = \epsilon_{ijk} n_j r_k\]
where \(\epsilon_{ijk}\) is the Levi-Civita symbol and repeated indices are summed over. Letting \(W_{ik}=\epsilon_{ijk} n_j\), the above equation becomes a matrix equation:

\[\frac{d\vec{r}}{d\theta} = W\vec{r}\]
\(W\) is skew-symmetric, because \(\epsilon_{ijk}\) changes sign when \(i\) and \(k\) are swapped: \(W_{ki} = \epsilon_{kji} n_j = -\epsilon_{ijk} n_j = -W_{ik}\). From this equation and standard results on matrix Lie groups and Lie algebras, the rotations about a fixed axis \(\vec{n}\) form a one-parameter Lie group (a one-parameter subgroup of SO(3)), while the skew-symmetric \(3\times 3\) matrices, together with the matrix commutator \([A, B] = AB - BA\), form a Lie algebra, \(\mathfrak{so}(3)\). The Lie group and the Lie algebra are connected by the exponential map; see the Wikipedia article Exponential map (Lie theory) for more details.
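These properties can be checked numerically. The sketch below is not part of nnp.so3; it assumes PyTorch 1.7 or later for `torch.matrix_exp` (the example itself uses SciPy instead), and the axis `n`, point `r`, and parameters `a`, `b` are arbitrary illustrative values. It verifies that \(W\) is skew-symmetric, that \(W\vec{r}\) equals \(\vec{n}\times\vec{r}\), and that rotations about the same axis compose additively, \(\exp(aW)\exp(bW)=\exp((a+b)W)\), as expected of a one-parameter group:

```python
import torch

# Levi-Civita symbol, built the same way as later in this example.
levi_civita = torch.zeros(3, 3, 3, dtype=torch.float64)
levi_civita[0, 1, 2] = levi_civita[1, 2, 0] = levi_civita[2, 0, 1] = 1
levi_civita[0, 2, 1] = levi_civita[2, 1, 0] = levi_civita[1, 0, 2] = -1

# Arbitrary unit axis n (|n| = 3/3 = 1) and an arbitrary test point r.
n = torch.tensor([1.0, 2.0, 2.0], dtype=torch.float64) / 3.0
r = torch.tensor([0.3, -1.2, 0.7], dtype=torch.float64)

# W_ik = eps_ijk n_j
W = torch.einsum('ijk,j->ik', levi_civita, n)

# W is skew-symmetric: W^T = -W
is_skew = torch.allclose(W.T, -W)

# W r reproduces the cross product n x r
matches_cross = torch.allclose(W @ r, torch.cross(n, r, dim=0))

# One-parameter group: aW and bW commute, so exp(aW) exp(bW) = exp((a + b) W)
a, b = 0.4, 1.1
composes = torch.allclose(
    torch.matrix_exp(a * W) @ torch.matrix_exp(b * W),
    torch.matrix_exp((a + b) * W),
)
print(is_skew, matches_cross, composes)
```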
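Solving \(\frac{d\vec{r}}{d\theta} = W\vec{r}\) with the initial condition \(\vec{r}\left(0\right)\) then gives

\[\vec{r}\left(\theta\right) = \exp\left(\theta W\right)\vec{r}\left(0\right)\]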
where \(\vec{r}\left(0\right)\) are the initial coordinates and \(\vec{r}\left(\theta\right)\) are the coordinates after rotating by the angle \(\theta\).
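As a worked instance of \(\vec{r}(\theta) = \exp(\theta W)\vec{r}(0)\), take \(\vec{n} = \hat{z}\) and \(\theta = \pi/2\): rotating \(\hat{x}\) should give \(\hat{y}\). The sketch below again assumes `torch.matrix_exp` (PyTorch 1.7+) and also compares the exponential with Rodrigues' rotation formula \(\exp(\theta W) = I + \sin\theta\, W + (1-\cos\theta)\, W^2\), a closed form valid for a unit axis that this example itself does not use:

```python
import math

import torch

# W_ik = eps_ijk n_j for the unit axis n = z-hat.
# For this n, W @ r equals n x r = (-r_y, r_x, 0).
W = torch.tensor([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 0.0]], dtype=torch.float64)

theta = math.pi / 2

# Rotation matrix from the exponential map.
R_exp = torch.matrix_exp(theta * W)

# Rodrigues' formula, valid because W comes from a unit axis.
I = torch.eye(3, dtype=torch.float64)
R_rodrigues = I + math.sin(theta) * W + (1 - math.cos(theta)) * (W @ W)

# Rotating x-hat by 90 degrees about z-hat should land on y-hat.
r0 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)
r_theta = R_exp @ r0
```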
To implement this rotation in PyTorch, let’s first define the Levi-Civita symbol:
levi_civita = torch.zeros(3, 3, 3)
levi_civita[0, 1, 2] = levi_civita[1, 2, 0] = levi_civita[2, 0, 1] = 1
levi_civita[0, 2, 1] = levi_civita[2, 1, 0] = levi_civita[1, 0, 2] = -1
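One way to sanity-check this tensor (a sketch, not part of nnp.so3) is to compare it with the closed form \(\epsilon_{ijk} = (i-j)(j-k)(k-i)/2\), which holds for indices in \(\{0,1,2\}\), and to confirm that contracting it with two vectors reproduces the cross product \((\vec{a}\times\vec{b})_i = \epsilon_{ijk} a_j b_k\). The vectors `a` and `b` are arbitrary:

```python
import itertools

import torch

levi_civita = torch.zeros(3, 3, 3)
levi_civita[0, 1, 2] = levi_civita[1, 2, 0] = levi_civita[2, 0, 1] = 1
levi_civita[0, 2, 1] = levi_civita[2, 1, 0] = levi_civita[1, 0, 2] = -1

# Closed form for indices in {0, 1, 2}: eps_ijk = (i - j)(j - k)(k - i) / 2
closed_form = torch.zeros(3, 3, 3)
for i, j, k in itertools.product(range(3), repeat=3):
    closed_form[i, j, k] = (i - j) * (j - k) * (k - i) / 2

same = torch.equal(levi_civita, closed_form)

# Contracting with two vectors gives the cross product: (a x b)_i = eps_ijk a_j b_k
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([-2.0, 0.5, 4.0])
cross_via_eps = torch.einsum('ijk,j,k->i', levi_civita, a, b)
```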
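PyTorch did not have a matrix exponential when this example was written, so we implement one here using SciPy. Note that the round trip through NumPy detaches the tensor, so no gradients flow through this function: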
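def expm(matrix: Tensor) -> Tensor:
    # TODO: remove this part when pytorch supports matrix_exp
    # Convert to a NumPy array on the CPU; .detach() drops the autograd graph.
    ndarray = matrix.detach().cpu().numpy()
    # .to(matrix) restores the input's dtype and device.
    return torch.from_numpy(scipy_expm(ndarray)).to(matrix)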
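PyTorch 1.7 and later provide `torch.matrix_exp`, which is what the TODO above anticipates. A minimal sketch of a possible drop-in replacement, assuming such a PyTorch version; unlike the SciPy wrapper it stays on the input's device and keeps the autograd graph. The generator `W` (rotation about \(\hat{z}\)) and the angle 0.7 are illustrative choices:

```python
import torch
from torch import Tensor


def expm(matrix: Tensor) -> Tensor:
    # Native matrix exponential (PyTorch >= 1.7); differentiable, no NumPy round trip.
    return torch.matrix_exp(matrix)


# Skew-symmetric generator of rotations about z-hat.
W = torch.tensor([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 0.0]], dtype=torch.float64)
theta = torch.tensor(0.7, dtype=torch.float64, requires_grad=True)
R = expm(theta * W)

# Gradients flow back to theta: R[0, 0] = cos(theta), so d/dtheta = -sin(theta).
R[0, 0].backward()
grad = theta.grad
```

The SciPy-based version cannot provide `grad` here, because `.detach()` cuts the graph before the exponential is computed.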
Now we are ready to implement \(\exp\left(\theta W\right)\):
def rotate_along(axis: Tensor) -> Tensor:
    r"""Compute the group element for a rotation about an axis through the origin.

    Arguments:
        axis: a vector (x, y, z) whose direction specifies the axis of the
            rotation and whose length specifies the angle (in radians) to
            rotate by. The rotation is counterclockwise about the direction
            of ``axis`` (right-hand rule); flipping its sign reverses it.

    Return:
        the rotation matrix :math:`\exp{\left(\theta W\right)}`.
    """
    # axis = theta * n, so this contraction yields theta * W directly.
    W = torch.einsum('ijk,j->ik', levi_civita.to(axis), axis)
    return expm(W)
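A usage sketch: the variant below, `rotate_along_torch` (a hypothetical name, not part of nnp.so3), mirrors `rotate_along` but calls `torch.matrix_exp` (PyTorch 1.7+) so it runs without SciPy. It rotates a random point cloud stored as an \((N, 3)\) tensor by \(2\pi/3\) about \((1,1,1)/\sqrt{3}\), a rotation that cyclically maps \(\hat{x}\to\hat{y}\to\hat{z}\to\hat{x}\). Because the rotation matrix acts on column vectors, points stored as rows are rotated with `points @ R.T`:

```python
import math

import torch
from torch import Tensor

levi_civita = torch.zeros(3, 3, 3)
levi_civita[0, 1, 2] = levi_civita[1, 2, 0] = levi_civita[2, 0, 1] = 1
levi_civita[0, 2, 1] = levi_civita[2, 1, 0] = levi_civita[1, 0, 2] = -1


def rotate_along_torch(axis: Tensor) -> Tensor:
    # Hypothetical SciPy-free variant of rotate_along; |axis| is the angle in radians.
    W = torch.einsum('ijk,j->ik', levi_civita.to(axis), axis)
    return torch.matrix_exp(W)


# Rotate by 2*pi/3 about the unit axis (1, 1, 1)/sqrt(3).
n = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float64) / math.sqrt(3)
R = rotate_along_torch((2 * math.pi / 3) * n)

# A small random point cloud: N = 5 points as rows of an (N, 3) tensor.
torch.manual_seed(0)
points = torch.randn(5, 3, dtype=torch.float64)

# R acts on column vectors, so row-vector points are rotated with points @ R.T.
rotated = points @ R.T
```

Rotation preserves lengths, so each row of `rotated` has the same norm as the corresponding row of `points`, and `R` is orthogonal with determinant 1.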