1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
""" This module manages different implementations of matrix multiplication
related routines.
"""
import generic as _gen
import numarrayall as _na
# Optional BLAS acceleration.
#
# The numarray._dotblas extension provides BLAS-optimized (i.e. much faster)
# replacements for `dot`, `innerproduct` and `vdot` on numarray arrays.
# BLAS routines exist only for 32- and 64-bit float and complex types; for
# other types the wrappers in this module defer to the standard numarray
# implementations.  The extension is optional because not everyone needs or
# wants to install a BLAS, so a failed import just leaves USING_BLAS false.
__author__ = "Richard Everson (R.M.Everson@exeter.ac.uk)"
__revision__ = "$Revision: 1.2 $"
__version__ = "1.0"

try:
    import numarray._dotblas as _dotblas
    import numarray._numarray as _numarray
except ImportError:
    USING_BLAS = 0
else:
    # Both extension modules imported; BLAS-backed wrappers can be used.
    USING_BLAS = 1
if not USING_BLAS:
    # No BLAS extension: expose numarray's own implementations directly.
    # NOTE(review): `vdot` is only defined in the BLAS branch below, so it is
    # not available from this module when the BLAS import failed.
    from numarray._numarray import dot, innerproduct
else:
    def dot(a, b):
        """Return the matrix product of `a` and `b`.

        The product-sum runs over the last dimension of `a` and the
        second-to-last dimension of `b`.  No conjugation of complex
        arguments is performed.

        The BLAS routine is tried first.  If it rejects the arguments with
        TypeError, numarray's own `dot` is used.  If that also raises
        TypeError and either argument is a scalar (shape ``()``), the plain
        product ``a*b`` is returned; otherwise the TypeError raised by
        numarray's `dot` propagates unchanged.
        """
        try:
            return _dotblas.dot(a, b)
        except TypeError:
            pass  # type not supported by BLAS; fall back below
        try:
            return _numarray.dot(a, b)
        except TypeError:
            if _na.shape(a) == () or _na.shape(b) == ():
                return a * b
            # Bare raise keeps the original message and traceback.
            raise

    def innerproduct(a, b):
        """Return the inner product of `a` and `b`.

        The product-sum runs over the last dimension of both `a` and `b`.
        No conjugation of complex arguments is performed.

        The BLAS routine is tried first.  If it rejects the arguments with
        TypeError, numarray's own `innerproduct` is used.  If that also
        raises TypeError and either argument is a scalar (shape ``()``),
        the plain product ``a*b`` is returned; otherwise the TypeError
        raised by numarray's `innerproduct` propagates unchanged.
        """
        try:
            return _dotblas.innerproduct(a, b)
        except TypeError:
            pass  # type not supported by BLAS; fall back below
        try:
            return _numarray.innerproduct(a, b)
        except TypeError:
            if _na.shape(a) == () or _na.shape(b) == ():
                return a * b
            # Bare raise keeps the original message and traceback.
            raise

    def vdot(a, b):
        """Return the dot product of two vectors (or anything ravel-able).

        Both arguments are flattened first.  This is not the same as `dot`:
        the BLAS routine takes the conjugate of its first argument if it is
        complex and always returns a scalar.  When BLAS rejects the argument
        types with TypeError, numarray's plain `dot` of the flattened
        arrays is returned instead (that path applies no conjugation).
        """
        a, b = _na.ravel(a), _na.ravel(b)
        try:
            return _dotblas.vdot(a, b)
        except TypeError:
            # e.g. integer arrays, which the BLAS routine does not handle.
            return _numarray.dot(a, b)
# `matrixmultiply` is an alias for whichever `dot` is active in this module
# (the BLAS-backed wrapper or numarray's plain `dot`, chosen by USING_BLAS).
matrixmultiply = dot
def outerproduct(array1, array2):
    """Return the N x M outer product of an N-vector and an M-vector.

    Each input is converted with `asarray` and reshaped so that all of its
    elements lie along one axis.  The result satisfies
    ``result[i, j] == array1[i] * array2[j]`` for the flattened inputs.
    """
    as_column = _gen.reshape(_na.asarray(array1), (-1, 1))  # N x 1
    as_row = _gen.reshape(_na.asarray(array2), (1, -1))     # 1 x M
    # (N x 1) times (1 x M) contracts over a length-1 axis, giving N x M.
    return matrixmultiply(as_column, as_row)
def tensormultiply(array1, array2):
    r"""Return the tensor product of two arrays of rank >= 1.

    The product-sum runs over the last axis of `array1` and the first axis
    of `array2`:

        r_{xxx, yyy} = \sum_k array1_{xxx, k} array2_{k, yyy}

    where xxx and yyy denote the remaining axes of `array1` and `array2`.
    The result has shape ``array1.shape[:-1] + array2.shape[1:]``.

    Raises ValueError if the contracted axes differ in length.
    """
    array1, array2 = _na.asarray(array1), _na.asarray(array2)
    if array1.shape[-1] != array2.shape[0]:
        raise ValueError("Unmatched dimensions")
    shape = array1.shape[:-1] + array2.shape[1:]
    # Collapse each operand to 2-D so that one matrix product performs the
    # contraction, then restore the free axes.
    lhs = _gen.reshape(array1, (-1, array1.shape[-1]))
    rhs = _gen.reshape(array2, (array2.shape[0], -1))
    return _gen.reshape(dot(lhs, rhs), shape)
def kroneckerproduct(a, b):
    """Return the Kronecker product ``a otimes b`` of two 2-D arrays.

    The Kronecker product is also known as the matrix direct product or
    tensor product.  For ``shape(a) == (m, n)`` and ``shape(b) == (p, q)``
    the result c has shape ``(m*p, n*q)``, and its (i, j)-th p x q
    submatrix equals ``a[i, j] * b``.

    >>> print(kroneckerproduct([[1, 2]], [[3], [4]]))
    [[3 6]
     [4 8]]
    >>> print(kroneckerproduct([[1, 2]], [[3, 4]]))
    [[3 4 6 8]]
    >>> print(kroneckerproduct([[1], [2]], [[3], [4]]))
    [[3]
     [4]
     [6]
     [8]]

    Raises ValueError if either input is not 2-D.
    """
    a = _na.asarray(a)
    b = _na.asarray(b)
    if len(a.shape) != 2 or len(b.shape) != 2:
        raise ValueError('Input must be 2D arrays.')

    def contiguous(arr):
        # Reshaping to the array's own shape is used here to obtain a
        # contiguous array; presumably needed by the reshapes below and in
        # outerproduct -- TODO confirm.
        if arr.iscontiguous():
            return arr
        return _gen.reshape(arr, arr.shape)

    a = contiguous(a)
    b = contiguous(b)

    # outerproduct flattens both inputs: (m*n) x (p*q), viewed as (m, n, p, q).
    blocks = outerproduct(a, b)
    blocks.shape = a.shape + b.shape
    # Concatenating the m slices of shape (n, p, q) along axis 1 gives
    # (n, m*p, q); concatenating those n slices along axis 1 gives (m*p, n*q).
    stacked = _gen.concatenate(blocks, axis=1)
    return _gen.concatenate(stacked, axis=1)
|