File: benchmark.py

package info (click to toggle)
python-cython-blis 1.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 43,676 kB
  • sloc: ansic: 645,510; sh: 2,354; asm: 1,466; python: 821; cpp: 585; makefile: 14
file content (107 lines) | stat: -rw-r--r-- 2,695 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) 2017 - 2022 ExplosionAI GmbH, released under BSD-3-Clause.
import numpy
import numpy.random
from blis.py import gemm, einsum
from timeit import default_timer as timer

numpy.random.seed(0)


def create_data(nO, nI, batch_size):
    """Build random float32 inputs for the benchmarks.

    Returns a tuple ``(X, W)`` where ``X`` has shape ``(batch_size, nI)`` and
    ``W`` has shape ``(nO, nI)``.  Both are filled with samples drawn from
    ``numpy.random.uniform(-1.0, 1.0)``; ``X`` is drawn first, then ``W``.
    """

    def _uniform_f32(shape):
        # Draw in numpy's default precision, then narrow to float32.
        return numpy.random.uniform(-1.0, 1.0, shape).astype("f")

    inputs = _uniform_f32((batch_size, nI))
    weights = _uniform_f32((nO, nI))
    return inputs, weights


def get_numpy_blas():
    """Return the name of the BLAS library numpy was built against.

    The lookup tries, in order:

    1. ``numpy.__config__.blas_ilp64_opt_info["libraries"][0]`` (the original
       behaviour of this function),
    2. ``numpy.__config__.blas_opt_info["libraries"][0]``,
    3. the ``"Build Dependencies" -> "blas" -> "name"`` entry of
       ``numpy.show_config(mode="dicts")``.

    If none of these is available the string ``"unknown"`` is returned, so
    the benchmark can still run and label its output instead of crashing.
    """
    config = numpy.__config__
    for attr in ("blas_ilp64_opt_info", "blas_opt_info"):
        # The attribute may be missing entirely, or be an empty dict when
        # that BLAS flavour was not found at build time.
        libraries = getattr(config, attr, {}).get("libraries")
        if libraries:
            return libraries[0]
    try:
        build_info = numpy.show_config(mode="dicts")
        name = build_info["Build Dependencies"]["blas"]["name"]
    except (TypeError, KeyError, AttributeError):
        # TypeError: show_config() without a ``mode`` argument;
        # KeyError/AttributeError: unexpected layout of the returned data.
        return "unknown"
    return name or "unknown"


def numpy_gemm(X, W, n=1000):
    """Run ``numpy.dot(X, W)`` ``n`` times into one reusable buffer.

    The float32 output buffer has shape ``(X.shape[0], W.shape[0])``.  After
    every product its sum is added to a running total and the buffer is
    zeroed.  The total is printed as ``Total: <value>``; nothing is returned.

    NOTE: ``W`` is handed to ``numpy.dot`` untransposed while the buffer is
    sized from ``W.shape[0]``, so for 2-D ``X`` the call only succeeds when
    ``W`` is square.
    """
    n_out, _n_in = W.shape  # unpacking also rejects a W that is not 2-D
    result = numpy.zeros((X.shape[0], n_out), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        numpy.dot(X, W, out=result)
        running_sum += result.sum()
        result.fill(0)
    print("Total:", running_sum)


def blis_gemm(X, W, n=1000):
    """Run the blis ``gemm(X, W)`` ``n`` times into one reusable buffer.

    The float32 output buffer has shape ``(X.shape[0], W.shape[0])``.  After
    every call its sum is added to a running total and the buffer is zeroed.
    The total is printed as ``Total: <value>``; nothing is returned.

    NOTE(review): ``W`` is passed without any transpose flag, so this
    presumably needs a square ``W`` just like the numpy variant -- confirm
    against the blis ``gemm`` signature.
    """
    n_out, _n_in = W.shape  # unpacking also rejects a W that is not 2-D
    result = numpy.zeros((X.shape[0], n_out), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        gemm(X, W, out=result)
        running_sum += result.sum()
        result.fill(0.0)
    print("Total:", running_sum)


def numpy_einsum(X, W, n=1000):
    """Run ``numpy.einsum("ab,cb->ca", X, W)`` ``n`` times into one buffer.

    The float32 output buffer has shape ``(W.shape[0], X.shape[0])``.  After
    every contraction its sum is added to a running total and the buffer is
    zeroed.  The total is printed as ``Total: <value>``; nothing is returned.
    """
    n_out, _n_in = W.shape  # unpacking also rejects a W that is not 2-D
    result = numpy.zeros((n_out, X.shape[0]), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        numpy.einsum("ab,cb->ca", X, W, out=result)
        running_sum += result.sum()
        result.fill(0.0)
    print("Total:", running_sum)


def blis_einsum(X, W, n=1000):
    """Run the blis ``einsum("ab,cb->ca", X, W)`` ``n`` times into one buffer.

    The float32 output buffer has shape ``(W.shape[0], X.shape[0])``.  After
    every call its sum is added to a running total and the buffer is zeroed.
    The total is printed as ``Total: <value>``; nothing is returned.
    """
    n_out, _n_in = W.shape  # unpacking also rejects a W that is not 2-D
    result = numpy.zeros((n_out, X.shape[0]), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        einsum("ab,cb->ca", X, W, out=result)
        running_sum += result.sum()
        result.fill(0.0)
    print("Total:", running_sum)


def _run_timed(label, func, *args, **kwargs):
    """Print ``label``, call ``func`` and print how long the call took.

    Returns the elapsed wall-clock time in seconds.
    """
    print(label)
    start = timer()
    func(*args, **kwargs)
    elapsed = timer() - start
    print("%.2f seconds" % elapsed)
    return elapsed


def main(nI=128 * 3, nO=128 * 3, batch_size=2000, n_iter=1000):
    """Time the blis and numpy gemm/einsum benchmarks on identical data.

    Parameters:
        nI: number of input features (columns of ``X`` and ``W``).
        nO: number of output features (rows of ``W``).
        batch_size: number of rows of ``X``.
        n_iter: how many times each benchmark repeats its operation.

    The numpy benchmarks run on the arrays returned by ``create_data`` and
    the blis benchmarks on copies of them.  For each benchmark a label, the
    ``Total:`` line printed by the benchmark itself and the elapsed time are
    written to stdout.

    NOTE: ``numpy_gemm`` multiplies ``X`` by ``W`` untransposed, so the gemm
    benchmarks require ``nI == nO``.
    """
    print(
        "Setting up data for gemm. {n_iter} iters,  "
        "nO={nO} nI={nI} batch_size={batch_size}".format(
            n_iter=n_iter, nO=nO, nI=nI, batch_size=batch_size
        )
    )
    numpy_blas = get_numpy_blas()
    # create_data's signature is (nO, nI, batch_size); pass the sizes in
    # that order so W really is (nO, nI) and X is (batch_size, nI).
    X1, W1 = create_data(nO, nI, batch_size)
    X2 = X1.copy()
    W2 = W1.copy()
    _run_timed("Blis gemm...", blis_gemm, X2, W2, n=n_iter)
    _run_timed(
        "Numpy (%s) gemm..." % numpy_blas, numpy_gemm, X1, W1, n=n_iter
    )
    _run_timed("Blis einsum ab,cb->ca", blis_einsum, X2, W2, n=n_iter)
    _run_timed(
        "Numpy (%s) einsum ab,cb->ca" % numpy_blas,
        numpy_einsum,
        X1,
        W1,
        n=n_iter,
    )


# Run the benchmark only when this file is executed as a script.  A bare
# ``if __name__:`` is always true and would also run it on import.
if __name__ == "__main__":
    main()