File: rho_j_q_numba.py

package info (click to toggle)
python-dynasor 2.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 22,008 kB
  • sloc: python: 5,263; sh: 20; makefile: 3
file content (127 lines) | stat: -rw-r--r-- 3,703 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""This module replaces the original c implementation of the reciprocal
densities and currents in dynasor with numba.

Numba is as of 2023 an ongoing project to create a JIT compiler frontend for
python code using the LLVM project as backend. Due to current limitations and
quirks of numba the code is not always straightforward. Typically the code
needs to be refactored in a trial and error process to get the expected
performance but should in the end be on the level of c.

In particular, numba makes very pessimistic assumptions about aliasing, but this
is expected to change in the future. Also, in theory, compilation flags should
be passable to LLVM via the llvmlite interface.
"""

import numpy as np
import numba


# This is often faster than calling np.dot for small arrays
# Calling this instead of manually inlining it actually incurs a small
# performance hit (<10%) with current numba (2023). It increases readability
# though and will probably sort itself out with later numba versions
@numba.njit(fastmath=True, nogil=True)
def dot(a, b):
    """Return the scalar product of the first three components of a and b

    Only indices 0, 1 and 2 of each argument are read; the three products are
    summed from left to right.
    """
    prod_0 = a[0] * b[0]
    prod_1 = a[1] * b[1]
    prod_2 = a[2] * b[2]
    return prod_0 + prod_1 + prod_2


# fastmath=True makes the summation fast and also speeds up exponentiation
# nogil releases the Python GIL, probably not necessary here
@numba.njit(fastmath=True, nogil=True)
def rho_q_single(x: np.ndarray,
                 q: np.ndarray) -> complex:
    """Sums the phase factors exp(i q.x) over all positions for one q-point

    Parameters
    ----------
    x
        positions as an array with shape (``Nx``, 3)
    q
        a single q-point as an array with shape (3,)

    Returns
    -------
    rho
        complex sum of ``exp(1j * dot(x[i], q))`` over all rows ``i`` of
        ``x``; ``0j`` when ``x`` has no rows
    """
    n_pos = len(x)

    # Shape checks; a mismatch raises AssertionError
    assert x.shape == (n_pos, 3)
    assert q.shape == (3,)

    total = 0.0j
    for k in range(n_pos):
        phase = dot(x[k], q)
        # the complex exponential dominates the cost of this loop
        total += np.exp(1j * phase)
    return total


# parallel enables the numba.prange directive
@numba.njit(fastmath=True, nogil=True, parallel=True)
def rho_q(x: np.ndarray, q: np.ndarray, rho: np.ndarray):
    """Calculates the fourier transformed density for a set of q-points

    Each q-point is handled by one call to ``rho_q_single`` and the loop over
    q-points is a ``numba.prange`` loop. Nothing is returned; entry ``i`` of
    ``rho`` is overwritten with the density at ``q[i]``.

    Parameters
    ----------
    x
        the positions as a float array with shape (``Nx``, 3)
    q
        the q points as a float array with shape (``Nq``, 3)
    rho
        output array of length ``Nq`` receiving the complex densities
    """

    n_pos = len(x)
    n_q = len(q)

    # Shape checks; a mismatch raises AssertionError
    assert x.shape == (n_pos, 3)
    assert q.shape == (n_q, 3)
    assert rho.shape == (n_q,)

    # Iterations are independent: each one writes only its own rho[iq]
    for iq in numba.prange(n_q):
        rho[iq] = rho_q_single(x, q[iq])


@numba.njit(fastmath=True, parallel=True, nogil=True)
def rho_j_q(x: np.ndarray, v: np.ndarray, q: np.ndarray,
            rho: np.ndarray, j_q: np.ndarray):
    """Calculates the fourier transformed density and current.

    For every q-point the phase factor ``exp(1j * dot(x[n], q))`` of each
    position is added to ``rho`` and, multiplied by the matching velocity
    components, to ``j_q``. The loop over q-points is a ``numba.prange`` loop.

    Nothing is returned. The contributions are *added* to whatever ``rho`` and
    ``j_q`` already contain, so the caller has to pass zero-initialized arrays
    to obtain the plain sums.

    Parameters
    ----------
    x
        the positions as a float array with shape (``Nx``, 3)
    v
        the velocities as a float array with shape (``Nx``, 3)
    q
        the q points as a float array with shape (``Nq``, 3)
    rho
        density accumulator as a complex array of length ``Nq``
    j_q
        current accumulator as a complex array with shape (``Nq``, 3)
    """

    n_pos = len(x)
    n_q = len(q)

    # Shape checks; a mismatch raises AssertionError
    assert x.shape == (n_pos, 3)
    assert v.shape == (n_pos, 3)
    assert q.shape == (n_q, 3)
    assert rho.shape == (n_q,)
    assert j_q.shape == (n_q, 3)

    # Iterations are independent: each one touches only rho[iq] and j_q[iq]
    for iq in numba.prange(n_q):
        q_vec = q[iq]
        for ix in range(n_pos):

            phase = dot(x[ix], q_vec)
            factor = np.exp(1.0j * phase)

            rho[iq] += factor

            # the three cartesian components of the current, written out
            vel = v[ix]
            j_q[iq, 0] += factor * vel[0]
            j_q[iq, 1] += factor * vel[1]
            j_q[iq, 2] += factor * vel[2]