File: _spherical_harmonics_generator.py

package info (click to toggle)
python-e3nn 0.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,700 kB
  • sloc: python: 13,368; makefile: 23
file content (67 lines) | stat: -rw-r--r-- 2,165 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import sympy
from sympy.printing.pycode import pycode

from e3nn import o3


def _generate_spherical_harmonics(lmax, device=None) -> None:  # pragma: no cover
    r"""Print Python source that evaluates the spherical harmonics up to degree ``lmax``.

    This is an offline code generator: it writes to stdout and returns ``None``.
    The printed snippet expects tensors ``x``, ``y``, ``z`` and an integer
    ``lmax`` to be in scope. For every degree ``l`` it defines the variables
    ``sh_{l}_0`` ... ``sh_{l}_{2l}`` and then emits an
    ``if lmax == l: return torch.stack([...], dim=-1)`` that stacks all
    degrees ``0..l``.

    Degree 1 is ``sqrt(3) * (x, y, z)``. Degree ``l + 1`` is obtained from
    degree ``l`` by contracting with ``(x, y, z)`` through
    ``o3.wigner_3j(l + 1, 1, l)``, after converting each Wigner 3j coefficient
    to an exact ``±sqrt(p/q)`` value. Each new degree is then rescaled so that
    the sum of the squares of its components evaluated at
    ``(x, y, z) = (0, 1, 0)`` equals ``2 * (l + 1) + 1``, and simplified with
    ``sympy``.

    Parameters
    ----------
    lmax : int
        Highest degree to generate. For ``lmax <= 0`` only the ``l = 0`` part
        is printed.
    device : optional
        Forwarded to ``o3.wigner_3j``.

    Notes
    -----
    The torch default dtype is switched to ``float64`` while generating and is
    restored to its previous value afterwards, even if generation raises.
    """
    from fractions import Fraction

    def to_frac(x: float):
        """Return ``x`` as an exact sympy value ``±sqrt(p/q)``.

        ``x**2`` is replaced by the closest fraction whose denominator is at
        most 1_000_000 (the ``Fraction.limit_denominator`` default). A
        coefficient whose square is such a rational is therefore recovered
        exactly from its floating-point value.
        """
        sign = 1 if x >= 0 else -1
        squared = Fraction(x**2).limit_denominator()
        return sympy.simplify(sign * sympy.sqrt(squared))

    # Work in float64 so the coefficients rationalised by `to_frac` are as
    # accurate as possible. This assumes `o3.wigner_3j` follows the torch
    # default dtype when none is given (TODO confirm). The caller's default is
    # restored afterwards instead of leaking a process-wide setting.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        print("sh_0_0 = torch.ones_like(x)")
        print("if lmax == 0:")
        print("    return torch.stack([")
        print("        sh_0_0,")
        print("    ], dim=-1)")
        print()

        x_var, y_var, z_var = sympy.symbols("x y z")
        polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]

        def sub_z1(p, names, polynormz):
            """Evaluate ``p`` at ``(x, y, z) = (0, 1, 0)``, then replace each symbol
            in ``names`` with the corresponding value in ``polynormz``."""
            p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
            for n, c in zip(names, polynormz):
                p = p.subs(n, c)
            return p

        # Values of the current degree's polynomials at (0, 1, 0), used for normalisation.
        poly_evalz = [sub_z1(p, [], []) for p in polynomials]

        for l in range(1, lmax + 1):
            sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))

            for n, p in zip(sh_variables, polynomials):
                print(f"{n} = {pycode(p)}")

            print(f"if lmax == {l}:")
            u = ",\n        ".join(", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(l + 1))
            print(f"    return torch.stack([\n        {u}\n    ], dim=-1)")
            print()

            if l == lmax:
                break

            # Degree l + 1 component k: sum over i, m of C[k, i, m] * (x, y, z)[i] * sh_l_m,
            # with C = wigner_3j(l + 1, 1, l) and each coefficient made exact by `to_frac`.
            polynomials = [
                sum(
                    to_frac(c.item()) * v * sh
                    for cj, v in zip(cij, [x_var, y_var, z_var])
                    for c, sh in zip(cj, sh_variables)
                )
                for cij in o3.wigner_3j(l + 1, 1, l, device=device)
            ]

            # Rescale so the squared values at (0, 1, 0) sum to 2 * (l + 1) + 1.
            poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
            norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
            polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
            poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]

            polynomials = [sympy.simplify(p, full=True) for p in polynomials]
    finally:
        torch.set_default_dtype(previous_dtype)