File: axpb.py

package info (click to toggle)
compyle 0.8.1-11
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,100 kB
  • sloc: python: 12,337; makefile: 21
file content (73 lines) | stat: -rw-r--r-- 1,724 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from compyle.api import Elementwise, annotate, wrap, get_config
import numpy as np
from numpy import sin
import time


# Per-element kernel: y[i] = a[i]*sin(x[i]) + b[i].
# ``annotate`` declares ``i`` as 'int' and x, y, a, b as 'doublep'
# (presumably pointers to double); the result is written into ``y`` in
# place and nothing is returned.
# NOTE(review): this function is handed to compyle (via ``Elementwise`` in
# ``setup``) for translation to the selected backend, so it presumably must
# stay within the Python subset compyle can translate -- keep it to plain
# indexing and arithmetic, and confirm before adding a docstring here.
@annotate(i='int', doublep='x, y, a, b')
def axpb(i, x, y, a, b):
    # ``sin`` is numpy's (imported at module top); presumably compyle maps it
    # to the backend's math-library sine -- TODO confirm.
    y[i] = a[i]*sin(x[i]) + b[i]


def setup(backend, openmp=False):
    """Return an ``Elementwise`` callable that runs ``axpb`` on ``backend``.

    ``openmp`` is assigned to ``use_openmp`` on the object returned by
    ``get_config()`` before the ``Elementwise`` wrapper is constructed.
    """
    config = get_config()
    config.use_openmp = openmp
    return Elementwise(axpb, backend=backend)


def data(n, backend):
    """Build the benchmark inputs for a problem of size ``n``.

    Returns ``wrap(x, y, a, b, backend=backend)`` where ``x`` holds ``n``
    evenly spaced points on [0, 1], ``y`` is a zero-filled output array like
    ``x``, ``a = x*x`` and ``b = sqrt(x + 1)``.
    """
    x = np.linspace(0, 1, n)
    # Order matters only for readability: y (output), then the coefficients.
    arrays = (x, np.zeros_like(x), x * x, np.sqrt(x + 1))
    return wrap(*arrays, backend=backend)


def compare(m=20, sizes=None):
    """Time the ``axpb`` kernel on every available backend.

    Parameters
    ----------
    m : int
        Number of timed calls per (backend, size); their mean is recorded.
        Must be at least 1.
    sizes : array_like of int, optional
        Problem sizes to benchmark. Defaults to ``2**np.arange(1, 25)``
        (2 .. 2**24).

    Returns
    -------
    N : numpy.ndarray
        The problem sizes that were benchmarked.
    backends : list of [str, bool]
        ``[backend_name, use_openmp]`` pairs in the order they were timed.
        Serial Cython is first and OpenMP Cython second; OpenCL and CUDA are
        appended only when ``pyopencl`` / ``pycuda`` can be imported.
    timing : numpy.ndarray
        Shape ``(len(backends), len(N))``: mean seconds per call.

    Raises
    ------
    ValueError
        If ``m`` is less than 1 (the mean of zero samples is undefined).
    """
    if m < 1:
        raise ValueError("m must be at least 1, got %r" % (m,))
    N = 2**np.arange(1, 25) if sizes is None else np.asarray(sizes)
    backends = _available_backends()

    timing = []
    for backend in backends:
        e = setup(*backend)
        times = []
        for n in N:
            args = data(n, backend[0])
            times.append(_mean_call_time(e, args, m))
        timing.append(times)

    return N, backends, np.array(timing)


def _available_backends():
    """Return ``[name, use_openmp]`` pairs for every backend usable here."""
    backends = [['cython', False], ['cython', True]]
    try:
        import pyopencl  # noqa: F401  (imported only to probe availability)
    except ImportError:
        pass  # OpenCL is optional; skip it when unavailable.
    else:
        backends.append(['opencl', False])

    try:
        import pycuda  # noqa: F401  (imported only to probe availability)
    except ImportError:
        pass  # CUDA is optional; skip it when unavailable.
    else:
        backends.append(['cuda', False])
    return backends


def _mean_call_time(func, args, m):
    """Return the mean duration in seconds of ``m`` calls ``func(*args)``."""
    # One untimed call first, so any one-time cost of the first invocation
    # (e.g. a kernel the backend builds lazily on first use) is not averaged
    # into the measurement. axpb only writes y, so repeating it is harmless.
    func(*args)
    samples = []
    for _ in range(m):
        # perf_counter is monotonic and high-resolution; time.time() is a
        # wall clock that can jump and is too coarse for microsecond calls.
        start = time.perf_counter()
        func(*args)
        samples.append(time.perf_counter() - start)
    return np.average(samples)


def plot_timing(n, timing, backends):
    """Plot each backend's speedup over the first (serial) backend.

    Parameters
    ----------
    n : array_like
        Problem sizes for the (logarithmic) x axis.
    timing : numpy.ndarray
        Mean seconds per call, one row per entry of ``backends``. Row 0 is the
        baseline; every other row is plotted as ``timing[0] / row``.
    backends : list of [str, bool]
        ``[name, use_openmp]`` pairs matching the rows of ``timing``, as
        returned by ``compare``. The list is not modified.
    """
    from matplotlib import pyplot as plt
    baseline = timing[0]
    for t, (name, openmp) in zip(timing[1:], backends[1:]):
        # Label an OpenMP-enabled run 'openmp' to tell it apart from the
        # serial run of the same backend, without mutating the caller's list.
        label = 'serial/' + ('openmp' if openmp else name)
        plt.semilogx(n, baseline/t, label=label, marker='+')
    plt.grid()
    plt.xlabel('N')
    plt.ylabel('Speedup')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    # Benchmark every available backend, then plot speedups over serial.
    sizes, tested_backends, mean_times = compare()
    plot_timing(sizes, mean_times, tested_backends)