import torch
from e3nn.o3._irreps import Irreps
from e3nn.o3._tensor_product._tensor_product import TensorProduct
from e3nn.util.jit import compile_mode
@compile_mode("trace")
class Norm(torch.nn.Module):
    r"""Norm of each irrep in a direct sum of irreps.

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input

    squared : bool, optional
        Whether to return the squared norm. ``False`` by default, i.e. the norm itself (sqrt of squared norm) is returned.

    Examples
    --------
    Compute the norms of 17 vectors.

    >>> norm = Norm("17x1o")
    >>> norm(torch.randn(17 * 3)).shape
    torch.Size([17])
    """

    squared: bool

    def __init__(self, irreps_in, squared: bool = False) -> None:
        super().__init__()
        simplified = Irreps(irreps_in).simplify()
        # One scalar (0e) output channel per multiplicity block of the input.
        scalar_irreps = Irreps([(mul, "0e") for mul, _ in simplified])
        # "uuu" = elementwise path: contract each irrep block with itself,
        # path weight ir.dim, producing the (squared) norm per irrep.
        instructions = [(k, k, k, "uuu", False, ir.dim) for k, (_, ir) in enumerate(simplified)]
        self.tp = TensorProduct(
            simplified, simplified, scalar_irreps, instructions, irrep_normalization="component"
        )
        self.irreps_in = simplified
        self.irreps_out = scalar_irreps.simplify()
        self.squared = squared

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.irreps_in})"

    def forward(self, features):
        """Compute norms of irreps in ``features``.

        Parameters
        ----------
        features : `torch.Tensor`
            tensor of shape ``(..., irreps_in.dim)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., irreps_out.dim)``
        """
        squared_norms = self.tp(features, features)
        if self.squared:
            return squared_norms
        # Clamp tiny negative rounding errors with ReLU; this also gives a
        # well-defined gradient for sqrt at zero.
        return squared_norms.relu().sqrt()