File: bench_vm.py

package info (click to toggle)
compyle 0.9.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 972 kB
  • sloc: python: 12,853; makefile: 21
file content (82 lines) | stat: -rw-r--r-- 2,404 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import time

from compyle.config import get_config
import vm_numba as VN
import vm_elementwise as VE
import vm_kernel as VK


def setup(mod, backend, openmp):
    """Configure compyle and return the velocity callable to benchmark.

    Parameters
    ----------
    mod : module
        One of the ``vm_numba`` (``VN``), ``vm_elementwise`` (``VE``) or
        ``vm_kernel`` (``VK``) modules imported by this script.
    backend : str
        Backend name (e.g. ``'cython'`` or ``'opencl'``) handed to
        ``VE.Elementwise`` / ``VK.Kernel``; ignored for ``VN``.
    openmp : bool
        Value written to the global ``get_config().use_openmp`` flag before
        the callable is built.

    Returns
    -------
    callable
        ``VE.Elementwise(VE.velocity, backend)``, ``VN.velocity`` or
        ``VK.Kernel(VK.velocity, backend)`` depending on ``mod``.

    Raises
    ------
    ValueError
        If ``mod`` is none of the supported modules. Previously this case
        surfaced as an ``UnboundLocalError``.
    """
    get_config().use_openmp = openmp
    if mod is VE:
        return VE.Elementwise(VE.velocity, backend)
    if mod is VN:
        return VN.velocity
    if mod is VK:
        return VK.Kernel(VK.velocity, backend)
    raise ValueError(
        "unsupported module %r; expected vm_numba, vm_elementwise or "
        "vm_kernel" % (mod,)
    )


def data(n, mod, backend):
    """Create the vortex inputs of size ``n`` for the module ``mod``.

    ``VN.make_vortices`` is called with ``n`` only. Every other module also
    receives ``backend``. The return value of ``make_vortices`` is passed
    through unchanged.
    """
    # The numba variant's factory takes no backend argument.
    extra = () if mod == VN else (backend,)
    return mod.make_vortices(n, *extra)


def compare(m=5):
    """Time every configured backend over a fixed range of problem sizes.

    Parameters
    ----------
    m : int
        Repetitions per (backend, size) pair. The minimum of the ``m``
        measured durations is kept. Must be at least 1.

    Returns
    -------
    N : numpy.ndarray
        The problem sizes that were benchmarked.
    timing : numpy.ndarray
        Array of shape ``(len(backends), len(N))``: the best time in seconds
        for each backend (rows, in the order listed below) and size.

    Raises
    ------
    ValueError
        If ``m < 1``; there would be no measurements to take a minimum of.
    """
    if m < 1:
        raise ValueError("m must be at least 1, got %r" % (m,))
    # Warm up the jit to prevent the timing from going off for the first point.
    VN.velocity(*VN.make_vortices(100))
    N = np.array([10, 50, 100, 200, 500, 1000, 2000, 4000, 6000,
                  8000, 10000, 15000, 20000])
    # (module, backend name, use_openmp) for each row of the result.
    backends = [(VN, '', False), (VE, 'cython', False), (VE, 'cython', True),
                (VE, 'opencl', False), (VK, 'opencl', False)]
    timing = []
    for backend in backends:
        e = setup(*backend)
        times = []
        for n in N:
            args = data(n, backend[0], backend[1])
            t = []
            for _ in range(m):
                # perf_counter is monotonic and uses the highest-resolution
                # clock available; time.time() can be coarse and can jump
                # when the wall clock is adjusted.
                start = time.perf_counter()
                e(*args)
                t.append(time.perf_counter() - start)
            times.append(np.min(t))
        timing.append(times)

    return N, np.array(timing)


def plot_timing(n, timing):
    """Plot speedup over numba and achieved GFLOPS for every backend.

    Parameters
    ----------
    n : array_like
        Problem sizes, e.g. the first value returned by ``compare``.
    timing : array_like
        2-D array of best times in seconds with one row per backend, aligned
        with ``n``, in the row order used by ``compare``: numba, cython,
        cython + OpenMP, OpenCL elementwise, OpenCL kernel.

    Draws two matplotlib figures (speedup relative to row 0, and GFLOPS
    using the script's estimate of ``12*n*n`` flops) and calls
    ``plt.show()``. It then prints the fastest time for the largest ``n``.
    """
    from matplotlib import pyplot as plt
    n = np.asarray(n)
    timing = np.asarray(timing)

    speedup_labels = ('numba/cython', 'numba/openmp', 'numba/opencl',
                      'numba/opencl local')
    for row, label in enumerate(speedup_labels, start=1):
        plt.plot(n, timing[0]/timing[row], label=label, marker='+')
    plt.grid()
    plt.xlabel('N')
    plt.ylabel('Speedup')
    plt.legend()

    plt.figure()
    # Compute the flop estimate in floating point. With an int32 ``n``
    # (NumPy's default integer on Windows before NumPy 2.0), 12*n*n
    # overflows once n exceeds roughly 13,000 and yields nonsense values.
    nf = n.astype(np.float64)
    gflop = 12.0*nf*nf/1e9
    gflops_labels = ('numba', 'Cython', 'OpenMP', 'OpenCL', 'OpenCL Local')
    for row, label in enumerate(gflops_labels):
        plt.plot(n, gflop/timing[row], label=label, marker='+')
    plt.grid()
    plt.xlabel('N')
    plt.ylabel('GFLOPS')
    plt.legend()
    plt.show()

    # Best time across all backends for the largest problem size.
    best = timing[:, -1].min()
    print("Fastest time for n=", n[-1], best, "secs")


if __name__ == '__main__':
    # Run the full benchmark sweep, then plot speedups and GFLOPS.
    sizes, timings = compare()
    plot_timing(sizes, timings)