1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
# cython: boundscheck=False
# Copyright (c) 2017 - 2022 ExplosionAI GmbH, released under BSD-3-Clause.
cimport numpy as np
from . cimport cy
from .cy cimport reals1d_ft, reals2d_ft, float1d_t, float2d_t
from .cy cimport const_reals1d_ft, const_reals2d_ft, const_float1d_t, const_float2d_t
from .cy cimport const_double1d_t, const_double2d_t
import numpy
def axpy(const_reals1d_ft A, double scale=1., np.ndarray out=None):
    """Accumulate ``scale * A`` into ``out`` with ``cy.axpyv`` and return ``out``.

    A: 1d memoryview, either float32 (``const_float1d_t``) or float64
        (``const_double1d_t``); the fused type selects the branch below.
    scale: multiplier applied to ``A``.
    out: optional array updated in place. When omitted, a zero-filled array of
        length ``A.shape[0]`` with A's dtype is allocated, so the result is
        just ``scale * A``.
        NOTE(review): ``out`` is not validated; it is assumed to be a
        contiguous 1d array of A's dtype with at least ``A.shape[0]``
        elements -- confirm callers guarantee this.

    Raises TypeError for an unhandled fused-type specialisation.
    """
    if const_reals1d_ft is const_float1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='f')
        B = <float*>out.data
        # The kernel must run before returning; without it float32 inputs
        # would come back as the untouched (zero) output buffer.
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    elif const_reals1d_ft is const_double1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='d')
        B = <double*>out.data
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    else:
        B = NULL
        raise TypeError("Unhandled fused type")
def batch_axpy(reals2d_ft A, reals1d_ft B, np.ndarray out=None, bint trans1=False):
    """Placeholder for the 'ab,a->ab' and 'ab,a->ba' einsum patterns.

    Not implemented. Raising here, instead of silently returning ``None``,
    makes callers such as ``einsum`` fail loudly rather than receive a
    ``None`` result.

    A, B, out: accepted for interface compatibility; currently unused.
    trans1: accepted (defaults to False) so that ``einsum``'s
        ``batch_axpy(A, B, trans1=True, out=out)`` call reaches this error
        instead of a TypeError about an unexpected keyword; currently unused.

    Raises NotImplementedError unconditionally.
    """
    raise NotImplementedError("blis.batch_axpy is not implemented")
def ger(const_reals2d_ft A, const_reals1d_ft B, double scale=1., np.ndarray out=None):
    """Rank-1 update through ``cy.ger`` and return the updated ``out``.

    Per the BLIS ``ger`` kernel this presumably computes
    ``out += scale * outer(x, B)`` -- confirm against ``cy.ger``.
    ``x`` is read as ``A.shape[0]`` consecutive elements starting at
    ``A[0, 0]`` with unit stride, so ``A`` is effectively treated as a column
    vector.
    NOTE(review): einsum's 'a,b->ab' pattern passes a 1d first operand,
    which does not match this 2d fused signature -- confirm intent.

    out: optional array updated in place, addressed with row stride
        ``out.shape[1]`` and column stride 1. When omitted, a zero-filled
        ``(A.shape[0], B.shape[0])`` array of the inputs' dtype is allocated.

    Raises TypeError unless A and B are both float32 or both float64.
    """
    cdef cy.dim_t n_rows = A.shape[0]
    cdef cy.dim_t n_cols = B.shape[0]
    if const_reals2d_ft is const_double2d_t and const_reals1d_ft is const_double1d_t:
        if out is None:
            out = numpy.zeros((n_rows, n_cols), dtype='d')
        with nogil:
            cy.ger(
                cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                n_rows, n_cols,
                scale,
                &A[0,0], 1,
                &B[0], 1,
                <double*>out.data, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_float2d_t and const_reals1d_ft is const_float1d_t:
        if out is None:
            out = numpy.zeros((n_rows, n_cols), dtype='f')
        with nogil:
            cy.ger(
                cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                n_rows, n_cols,
                scale,
                &A[0,0], 1,
                &B[0], 1,
                <float*>out.data, out.shape[1], 1)
        return out
    # Mixed float32/float64 specialisations land here.
    raise TypeError("Unhandled fused type")
def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    """Matrix product through ``cy.gemm`` and return ``out``.

    Per the BLIS typed API the kernel computes
    ``out = beta * out + alpha * op(A) @ op(B)``, where ``op(A)`` is ``A.T``
    when ``trans1`` is set (likewise ``op(B)`` with ``trans2``).

    The result has shape ``(nM, nN)``; ``nM``/``nK`` are op(A)'s rows/columns
    and ``nN`` is op(B)'s column count.

    out: optional array updated in place, addressed with row stride
        ``out.shape[1]``. When omitted it is allocated with A's dtype:
        uninitialised when ``beta == 0.`` (its old contents are then ignored),
        zero-filled otherwise.
        NOTE(review): a caller-supplied ``out`` is not shape/dtype checked.

    Raises ValueError when op(A)'s columns differ from op(B)'s rows, and
    TypeError for an unhandled fused-type specialisation.
    """
    # BLIS takes m/n/k as the dimensions of op(A) and op(B); the stored layout
    # is described separately by the (row stride, column stride) pairs below.
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nK_b = B.shape[0] if not trans2 else B.shape[1]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if nK != nK_b:
        msg = "Shape mismatch for blis.gemm: (%d, %d), (%d, %d)"
        raise ValueError(msg % (nM, nK, nK_b, nN))
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            if beta == 0.:
                out = numpy.empty((nM, nN), dtype='f')
            else:
                out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        # Same transpose-aware dims and allocation as the float32 path;
        # using the stored shapes here gave wrong results whenever
        # trans1/trans2 was set.
        if out is None:
            if beta == 0.:
                out = numpy.empty((nM, nN), dtype='d')
            else:
                out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")
def gemv(const_reals2d_ft A, const_reals1d_ft B,
         bint trans1=False, double alpha=1., double beta=1.,
         np.ndarray out=None):
    """Matrix-vector product through ``cy.gemv`` and return ``out``.

    Per the BLIS typed API the kernel computes
    ``out = beta * out + alpha * op(A) @ B``, where ``op(A)`` is ``A.T`` when
    ``trans1`` is set. The result therefore has length ``A.shape[1]`` when
    ``trans1`` is set, else ``A.shape[0]``.

    out: optional array updated in place (unit stride). When omitted, a
        zero-filled array of the inputs' dtype and the result length is
        allocated.
        NOTE(review): a caller-supplied ``out`` is not length/dtype checked.

    Raises ValueError when ``B``'s length differs from op(A)'s column count,
    and TypeError unless A and B are both float32 or both float64.
    """
    # BLIS gemv takes m/n as the dimensions of op(A), not of the stored A;
    # the stored layout is given separately by (A.shape[1], 1) strides.
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nN = A.shape[1] if not trans1 else A.shape[0]
    if B.shape[0] != nN:
        msg = "Shape mismatch for blis.gemv: (%d, %d), (%d,)"
        raise ValueError(msg % (nM, nN, B.shape[0]))
    if const_reals1d_ft is const_float1d_t and const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM,), dtype='f')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                nM, nN,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <float*>out.data, 1)
        return out
    elif const_reals1d_ft is const_double1d_t and const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((nM,), dtype='d')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                nM, nN,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <double*>out.data, 1)
        return out
    else:
        raise TypeError("Unhandled fused type")
def dotv(const_reals1d_ft X, const_reals1d_ft Y, bint conjX=False, bint conjY=False):
    """Return the dot product of X and Y as computed by ``cy.dotv``.

    Both vectors share one fused type, so they have the same dtype, and both
    are read with unit stride. ``conjX``/``conjY`` pass ``cy.CONJUGATE``
    instead of ``cy.NO_CONJUGATE`` for the respective operand.

    Raises ValueError when the two vectors have different lengths.
    """
    cdef Py_ssize_t len_x = X.shape[0]
    cdef Py_ssize_t len_y = Y.shape[0]
    if len_x != len_y:
        raise ValueError("Shape mismatch for blis.dotv: (%d,), (%d,)" % (len_x, len_y))
    return cy.dotv(
        cy.CONJUGATE if conjX else cy.NO_CONJUGATE,
        cy.CONJUGATE if conjY else cy.NO_CONJUGATE,
        len_x, &X[0], &Y[0], 1, 1)
def einsum(todo, A, B, out=None):
    """Dispatch a fixed set of two-operand einsum patterns to BLIS kernels.

    todo: einsum subscript string. Only the exact strings matched below are
        supported; there is no whitespace or subscript-letter normalisation.
    A, B: operands forwarded to the selected kernel (swapped for some
        patterns so the output's first axis comes from the first argument).
    out: optional output array forwarded to the kernel.

    Returns whatever the selected kernel returns.
    Raises ValueError for an unsupported pattern.
    """
    # NOTE(review): axpy's second parameter is the scalar ``scale``, so
    # passing the array B for 'a,a->a' looks wrong -- confirm the intended
    # semantics (an elementwise product is not an axpy).
    if todo == 'a,a->a':
        return axpy(A, B, out=out)
    elif todo == 'a,b->ab':
        return ger(A, B, out=out)
    elif todo == 'a,b->ba':
        return ger(B, A, out=out)
    elif todo == 'ab,a->ab':
        return batch_axpy(A, B, out=out)
    elif todo == 'ab,a->ba':
        return batch_axpy(A, B, trans1=True, out=out)
    elif todo == 'ab,b->a':
        return gemv(A, B, out=out)
    elif todo == 'ab,a->b':
        return gemv(A, B, trans1=True, out=out)
    # The rule for matrix products: the first axis of the output must come
    # from arg1, so pass that operand first. Set trans1 if that axis is
    # arg1's second dimension, and set trans2 so arg2's shared axis comes
    # first. E.g. for 'ab,ac->bc' the output starts with b, which lives in
    # ab, so A is arg1 with trans1=True, giving ba,ac->bc.
    elif todo == 'ab,ac->bc':
        return gemm(A, B, trans1=True, trans2=False, out=out)
    elif todo == 'ab,ac->cb':
        # B.T is (c, a) and A is already (a, b): only B is transposed.
        return gemm(B, A, out=out, trans1=True, trans2=False)
    elif todo == 'ab,bc->ac':
        return gemm(A, B, out=out, trans1=False, trans2=False)
    elif todo == 'ab,bc->ca':
        return gemm(B, A, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->bc':
        return gemm(A, B, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->cb':
        return gemm(B, A, out=out, trans1=False, trans2=False)
    elif todo == 'ab,cb->ac':
        return gemm(A, B, out=out, trans1=False, trans2=True)
    elif todo == 'ab,cb->ca':
        return gemm(B, A, out=out, trans1=False, trans2=True)
    else:
        raise ValueError("Invalid einsum: %s" % todo)
|