# Copyright (c) 2017 - 2022 ExplosionAI GmbH, released under BSD-3-Clause.
import numpy
import numpy.random
from blis.py import gemm, einsum
from timeit import default_timer as timer
# Seed numpy's global RNG so the random matrices drawn by create_data are
# reproducible from run to run.
numpy.random.seed(0)
def create_data(nO, nI, batch_size):
    """Build a random float32 input batch and weight matrix.

    Values are drawn with ``numpy.random.uniform(-1.0, 1.0, ...)`` from
    numpy's global RNG; the inputs are sampled before the weights.

    nO (int): Number of rows of the weight matrix.
    nI (int): Number of columns of both matrices.
    batch_size (int): Number of rows of the input matrix.
    RETURNS (tuple): ``(X, W)`` with shapes ``(batch_size, nI)`` and
        ``(nO, nI)``, both of dtype float32.
    """
    inputs = numpy.random.uniform(-1.0, 1.0, (batch_size, nI)).astype("f")
    weights = numpy.random.uniform(-1.0, 1.0, (nO, nI)).astype("f")
    return inputs, weights
def get_numpy_blas():
    """Return a name for the BLAS library numpy reports being built with.

    The lookup is best-effort: the attributes available on
    ``numpy.__config__`` differ between numpy builds, so each known location
    is probed in turn instead of assuming ``blas_ilp64_opt_info`` exists.

    RETURNS (str): The first entry of the ``"libraries"`` list from
        ``blas_ilp64_opt_info`` or ``blas_opt_info`` when present, otherwise
        the BLAS ``"name"`` from the ``CONFIG`` dict, otherwise ``"unknown"``.
    """
    config = numpy.__config__
    # Per-probe info dicts; the ILP64 one is checked first to keep the
    # original preference.
    for attr in ("blas_ilp64_opt_info", "blas_opt_info"):
        info = getattr(config, attr, None)
        libraries = info.get("libraries") if isinstance(info, dict) else None
        if libraries:
            return libraries[0]
    # NOTE(review): newer numpy builds appear to expose a nested ``CONFIG``
    # dict instead of the ``*_info`` attributes -- confirm against the numpy
    # versions this benchmark is expected to run on.
    build_config = getattr(config, "CONFIG", None)
    if isinstance(build_config, dict):
        dependencies = build_config.get("Build Dependencies")
        blas = dependencies.get("blas") if isinstance(dependencies, dict) else None
        name = blas.get("name") if isinstance(blas, dict) else None
        if name:
            return str(name)
    # The name is only used as a label in the printed report, so do not let
    # a missing config entry abort the benchmark.
    return "unknown"
def numpy_gemm(X, W, n=1000):
    """Run ``numpy.dot(X, W)`` *n* times and print the accumulated output sum.

    X (numpy.ndarray): Left operand; its first dimension sizes the output rows.
    W (numpy.ndarray): Two-dimensional right operand; its first dimension
        sizes the output columns.
    n (int): Number of repetitions.
    """
    n_out, _ = W.shape
    n_rows = X.shape[0]
    result = numpy.zeros((n_rows, n_out), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        numpy.dot(X, W, out=result)
        running_sum += result.sum()
        # Reset the shared output buffer before the next repetition.
        result.fill(0)
    print("Total:", running_sum)
def blis_gemm(X, W, n=1000):
    """Call blis ``gemm(X, W, out=...)`` *n* times and print the accumulated
    output sum.

    X (numpy.ndarray): Left operand; its first dimension sizes the output rows.
    W (numpy.ndarray): Two-dimensional right operand; its first dimension
        sizes the output columns.
    n (int): Number of repetitions.
    """
    n_out, _ = W.shape
    n_rows = X.shape[0]
    result = numpy.zeros((n_rows, n_out), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        gemm(X, W, out=result)
        running_sum += result.sum()
        # Reset the shared output buffer before the next repetition.
        result.fill(0.0)
    print("Total:", running_sum)
def numpy_einsum(X, W, n=1000):
    """Run ``numpy.einsum("ab,cb->ca", X, W)`` *n* times and print the
    accumulated output sum.

    X (numpy.ndarray): The ``ab`` operand; its first dimension sizes the
        output columns.
    W (numpy.ndarray): Two-dimensional ``cb`` operand; its first dimension
        sizes the output rows.
    n (int): Number of repetitions.
    """
    n_out, _ = W.shape
    n_rows = X.shape[0]
    result = numpy.zeros((n_out, n_rows), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        numpy.einsum("ab,cb->ca", X, W, out=result)
        running_sum += result.sum()
        # Reset the shared output buffer before the next repetition.
        result.fill(0.0)
    print("Total:", running_sum)
def blis_einsum(X, W, n=1000):
    """Call blis ``einsum("ab,cb->ca", X, W, out=...)`` *n* times and print
    the accumulated output sum.

    X (numpy.ndarray): The ``ab`` operand; its first dimension sizes the
        output columns.
    W (numpy.ndarray): Two-dimensional ``cb`` operand; its first dimension
        sizes the output rows.
    n (int): Number of repetitions.
    """
    n_out, _ = W.shape
    n_rows = X.shape[0]
    result = numpy.zeros((n_out, n_rows), dtype="f")
    running_sum = 0.0
    for _ in range(n):
        einsum("ab,cb->ca", X, W, out=result)
        running_sum += result.sum()
        # Reset the shared output buffer before the next repetition.
        result.fill(0.0)
    print("Total:", running_sum)
def _time_benchmark(label, benchmark, X, W, n_iter):
    """Print *label*, time ``benchmark(X, W, n=n_iter)`` and print the seconds."""
    print(label)
    start = timer()
    benchmark(X, W, n=n_iter)
    elapsed = timer() - start
    print("%.2f seconds" % elapsed)
    return elapsed


def main(nI=128 * 3, nO=128 * 3, batch_size=2000, n_iter=1000):
    """Time the blis and numpy gemm/einsum benchmarks on the same random data.

    For each of the four benchmarks a label, the benchmark's own "Total:"
    line and the elapsed wall-clock seconds are printed, in that order.

    NOTE: numpy_gemm calls ``numpy.dot(X, W)`` with X of shape
    ``(batch_size, nI)`` and W of shape ``(nO, nI)``, so the gemm benchmarks
    need ``nI == nO``.

    nI (int): Number of columns of the input and weight matrices.
    nO (int): Number of rows of the weight matrix.
    batch_size (int): Number of rows of the input matrix.
    n_iter (int): Repetitions per benchmark; the default matches the
        previously hard-coded 1000.
    """
    print(
        "Setting up data for gemm. {n_iter} iters, "
        "nO={nO} nI={nI} batch_size={batch_size}".format(
            n_iter=n_iter, nO=nO, nI=nI, batch_size=batch_size
        )
    )
    numpy_blas = get_numpy_blas()
    # create_data takes (nO, nI, batch_size); pass by keyword so the two
    # sizes cannot be swapped, as they were in the positional call.
    X1, W1 = create_data(nO=nO, nI=nI, batch_size=batch_size)
    # Separate copies, so the first blis/numpy pair does not share arrays.
    X2 = X1.copy()
    W2 = W1.copy()
    _time_benchmark("Blis gemm...", blis_gemm, X2, W2, n_iter)
    _time_benchmark("Numpy (%s) gemm..." % numpy_blas, numpy_gemm, X1, W1, n_iter)
    _time_benchmark("Blis einsum ab,cb->ca", blis_einsum, X2, W2, n_iter)
    _time_benchmark(
        "Numpy (%s) einsum ab,cb->ca" % numpy_blas, numpy_einsum, X2, W2, n_iter
    )
# Run the benchmark only when this file is executed as a script; the previous
# ``if __name__:`` test was always true, so a plain import ran it as well.
if __name__ == "__main__":
    main()