/*
-- MAGMA (version 2.9.0) --
Univ. of Tennessee, Knoxville
Univ. of California, Berkeley
Univ. of Colorado, Denver
@date January 2025
@generated from magmablas/zgemm_batched.cpp, normal z -> d, Wed Jan 22 14:41:59 2025
@author Jakub Kurzak
@author Stan Tomov
@author Mark Gates
@author Azzam Haidar
*/
#include "magma_internal.h"
#include "commonblas_d.h"
#define PRECISION_d
/* on some platforms (i.e. hipMAGMA on ROCm stack), we define custom types
 * So, to keep the C++ compiler from giving errors, we cast arguments to internal
 * BLAS routines. The hipify script should replace `cu*Complex` with appropriate HIP types
 *
 * FUTURE READERS: If hipBLAS changes names to `hipblas*Complex` rather than `hip*Complex`,
 * these will need more complicated macro if/else blocks
 */
#ifdef PRECISION_z
#ifdef MAGMA_HAVE_HIP
typedef double BackendFloat_t;
#else
typedef double BackendFloat_t;
#endif
#elif defined(PRECISION_c)
#ifdef MAGMA_HAVE_HIP
typedef hipblasComplex BackendFloat_t;
#else
typedef cuFloatComplex BackendFloat_t;
#endif
#elif defined(PRECISION_d)
typedef double BackendFloat_t;
#else
typedef float BackendFloat_t;
#endif
void
magma_dgemm_batched_core(
magma_trans_t transA, magma_trans_t transB,
magma_int_t m, magma_int_t n, magma_int_t k,
double alpha,
double const * const * dA_array, magma_int_t Ai, magma_int_t Aj, magma_int_t ldda,
double const * const * dB_array, magma_int_t Bi, magma_int_t Bj, magma_int_t lddb,
double beta,
double **dC_array, magma_int_t Ci, magma_int_t Cj, magma_int_t lddc,
magma_int_t batchCount, magma_queue_t queue )
{
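    // decide whether to dispatch to the vendor batched GEMM or to MAGMA's own
    // kernels, and whether any submatrix offsets were requested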
magma_int_t use_cublas = magma_drecommend_cublas_gemm_batched(transA, transB, m, n, k);
magma_int_t zero_offset = (Ai == 0 && Aj == 0 && Bi == 0 && Bj == 0 && Ci == 0 && Cj == 0);
if(use_cublas){
if(zero_offset){
cublasDgemmBatched(
queue->cublas_handle(), cublas_trans_const(transA), cublas_trans_const(transB),
int(m), int(n), int(k),
(BackendFloat_t*)&alpha, (const BackendFloat_t**)dA_array, int(ldda),
(const BackendFloat_t**)dB_array, int(lddb),
(BackendFloat_t*)&beta, (BackendFloat_t**)dC_array, int(lddc), int(batchCount) );
}
else{
double** dAarray = (double**)queue->get_dAarray();
double** dBarray = (double**)queue->get_dBarray();
double** dCarray = (double**)queue->get_dCarray();
magma_int_t max_batchCount = queue->get_maxBatch();
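        // process the batch in chunks of at most max_batchCount matrices,
        // shifting each chunk's A/B/C pointers by the requested (row,col)
        // offsets into the queue's workspace arrays before calling the vendor GEMM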
for(magma_int_t i = 0; i < batchCount; i+=max_batchCount){
magma_int_t batch = min(max_batchCount, batchCount-i);
magma_ddisplace_pointers(dAarray, (double**)dA_array + i, ldda, Ai, Aj, batch, queue);
magma_ddisplace_pointers(dBarray, (double**)dB_array + i, lddb, Bi, Bj, batch, queue);
magma_ddisplace_pointers(dCarray, (double**)dC_array + i, lddc, Ci, Cj, batch, queue);
cublasDgemmBatched(
queue->cublas_handle(), cublas_trans_const(transA), cublas_trans_const(transB),
int(m), int(n), int(k),
(BackendFloat_t*)&alpha, (const BackendFloat_t**)dAarray, int(ldda),
(const BackendFloat_t**)dBarray, int(lddb),
(BackendFloat_t*)&beta, (BackendFloat_t**)dCarray, int(lddc), int(batch) );
}
}
}
else{
magmablas_dgemm_batched_core(
transA, transB,
m, n, k,
alpha, dA_array, Ai, Aj, ldda,
dB_array, Bi, Bj, lddb,
beta, dC_array, Ci, Cj, lddc,
batchCount, queue);
}
}
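/*
    Illustrative sketch (an assumption, not taken from this file): the row/column
    offsets (Ai,Aj), (Bi,Bj), (Ci,Cj) let every batch entry operate on a submatrix
    of its matrix without rebuilding the pointer arrays. For example, updating the
    trailing blocks that start at row/column ib of each matrix:

        magma_dgemm_batched_core(
            MagmaNoTrans, MagmaNoTrans,
            m-ib, n-ib, k,
            alpha, dA_array, ib, 0,  ldda,
                   dB_array, 0,  ib, lddb,
            beta,  dC_array, ib, ib, lddc,
            batchCount, queue );
*/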
/***************************************************************************//**
Purpose
-------
    DGEMM performs one of the matrix-matrix operations

        C = alpha*op( A )*op( B ) + beta*C,

    where op( X ) is one of

        op( X ) = X,      or
        op( X ) = X**T,   or
        op( X ) = X**H,

    alpha and beta are scalars, and A, B and C are matrices, with
    op( A ) an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.
    This batched routine applies the operation independently to each of the
    batchCount (A, B, C) triples given by dA_array, dB_array, and dC_array.
Parameters
----------
@param[in]
transA magma_trans_t.
On entry, transA specifies the form of op( A ) to be used in
the matrix multiplication as follows:
- = MagmaNoTrans: op( A ) = A.
- = MagmaTrans: op( A ) = A**T.
- = MagmaConjTrans: op( A ) = A**H.
@param[in]
transB magma_trans_t.
On entry, transB specifies the form of op( B ) to be used in
the matrix multiplication as follows:
- = MagmaNoTrans: op( B ) = B.
- = MagmaTrans: op( B ) = B**T.
- = MagmaConjTrans: op( B ) = B**H.
@param[in]
m INTEGER.
On entry, M specifies the number of rows of the matrix
op( A ) and of the matrix C. M must be at least zero.
@param[in]
n INTEGER.
On entry, N specifies the number of columns of the matrix
op( B ) and the number of columns of the matrix C. N must be
at least zero.
@param[in]
k INTEGER.
On entry, K specifies the number of columns of the matrix
op( A ) and the number of rows of the matrix op( B ). K must
be at least zero.
@param[in]
alpha DOUBLE PRECISION
On entry, ALPHA specifies the scalar alpha.
@param[in]
dA_array Array of pointers, dimension (batchCount).
Each is a DOUBLE PRECISION array A of DIMENSION ( ldda, ka ), where ka is
k when transA = MagmaNoTrans, and is m otherwise.
Before entry with transA = MagmaNoTrans, the leading m by k
part of the array A must contain the matrix A, otherwise
the leading k by m part of the array A must contain the
matrix A.
@param[in]
ldda INTEGER.
On entry, ldda specifies the first dimension of each array A as declared
in the calling (sub) program. When transA = MagmaNoTrans then
ldda must be at least max( 1, m ), otherwise ldda must be at
least max( 1, k ).
@param[in]
dB_array Array of pointers, dimension (batchCount).
Each is a DOUBLE PRECISION array B of DIMENSION ( lddb, kb ), where kb is
n when transB = MagmaNoTrans, and is k otherwise.
Before entry with transB = MagmaNoTrans, the leading k by n
part of the array B must contain the matrix B, otherwise
the leading n by k part of the array B must contain the
matrix B.
@param[in]
lddb INTEGER.
On entry, lddb specifies the first dimension of each array B as declared
in the calling (sub) program. When transB = MagmaNoTrans then
lddb must be at least max( 1, k ), otherwise lddb must be at
least max( 1, n ).
@param[in]
beta DOUBLE PRECISION.
On entry, BETA specifies the scalar beta. When BETA is
supplied as zero then C need not be set on input.
@param[in,out]
dC_array Array of pointers, dimension (batchCount).
Each is a DOUBLE PRECISION array C of DIMENSION ( lddc, n ).
Before entry, the leading m by n part of the array C must
contain the matrix C, except when beta is zero, in which
case C need not be set on entry.
On exit, the array C is overwritten by the m by n matrix
( alpha*op( A )*op( B ) + beta*C ).
@param[in]
lddc INTEGER.
On entry, lddc specifies the first dimension of each array C as declared
in the calling (sub) program. lddc must be at least
max( 1, m ).
@param[in]
batchCount INTEGER
The number of matrices to operate on.
@param[in]
queue magma_queue_t
Queue to execute in.
@ingroup magma_gemm_batched
*******************************************************************************/
extern "C" void
magmablas_dgemm_batched( magma_trans_t transA, magma_trans_t transB,
magma_int_t m, magma_int_t n, magma_int_t k,
double alpha,
double const * const * dA_array, magma_int_t ldda,
double const * const * dB_array, magma_int_t lddb,
double beta,
double **dC_array, magma_int_t lddc,
magma_int_t batchCount, magma_queue_t queue )
{
magmablas_dgemm_batched_core(
transA, transB, m, n, k,
alpha, dA_array, 0, 0, ldda,
dB_array, 0, 0, lddb,
beta, dC_array, 0, 0, lddc,
batchCount, queue );
}
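/*
    Illustrative usage sketch (an assumption, not part of this file): h_A, h_B,
    and h_C are hypothetical host arrays of batchCount device pointers, each
    pointing to an already-populated m x k, k x n, and m x n matrix, respectively.

        double **dA_array, **dB_array, **dC_array;
        magma_malloc( (void**)&dA_array, batchCount*sizeof(double*) );
        magma_malloc( (void**)&dB_array, batchCount*sizeof(double*) );
        magma_malloc( (void**)&dC_array, batchCount*sizeof(double*) );

        // copy the host pointer arrays to the device
        magma_setvector( batchCount, sizeof(double*), h_A, 1, dA_array, 1, queue );
        magma_setvector( batchCount, sizeof(double*), h_B, 1, dB_array, 1, queue );
        magma_setvector( batchCount, sizeof(double*), h_C, 1, dC_array, 1, queue );

        // C_i = alpha*A_i*B_i + beta*C_i for i = 0, ..., batchCount-1
        magmablas_dgemm_batched( MagmaNoTrans, MagmaNoTrans, m, n, k,
                                 alpha, dA_array, ldda,
                                        dB_array, lddb,
                                 beta,  dC_array, lddc,
                                 batchCount, queue );
        magma_queue_sync( queue );

        magma_free( dA_array );
        magma_free( dB_array );
        magma_free( dC_array );
*/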
/******************************************************************************/
extern "C" void
magmablas_dgemm_batched_strided( magma_trans_t transA, magma_trans_t transB,
magma_int_t m, magma_int_t n, magma_int_t k,
double alpha,
double const * dA, magma_int_t ldda, magma_int_t strideA,
double const * dB, magma_int_t lddb, magma_int_t strideB,
double beta,
double * dC, magma_int_t lddc, magma_int_t strideC,
magma_int_t batchCount, magma_queue_t queue )
{
double** dAarray = (double**)queue->get_dAarray();
double** dBarray = (double**)queue->get_dBarray();
double** dCarray = (double**)queue->get_dCarray();
magma_int_t max_batchCount = queue->get_maxBatch();
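    // convert the strided layout to per-matrix pointer arrays in chunks of at
    // most max_batchCount, then run the pointer-array GEMM on each chunk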
for(magma_int_t i = 0; i < batchCount; i+=max_batchCount){
magma_int_t batch = min(max_batchCount, batchCount-i);
magma_dset_pointer(dAarray, (double*)(dA + i * strideA), ldda, 0, 0, strideA, batch, queue);
magma_dset_pointer(dBarray, (double*)(dB + i * strideB), lddb, 0, 0, strideB, batch, queue);
magma_dset_pointer(dCarray, dC + i * strideC, lddc, 0, 0, strideC, batch, queue);
magmablas_dgemm_batched_core(
transA, transB,
m, n, k,
alpha, dAarray, 0, 0, ldda,
dBarray, 0, 0, lddb,
beta, dCarray, 0, 0, lddc,
batch, queue);
}
}
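/*
    Illustrative usage sketch (an assumption, not part of this file): dA, dB, dC
    are single contiguous device allocations holding batchCount column-major
    matrices, with consecutive matrices strideA, strideB, strideC elements apart
    (e.g. strideA = ldda*k for a NoTrans A of size m x k).

        // C_i = alpha*A_i*B_i + beta*C_i, where A_i = dA + i*strideA, etc.
        magmablas_dgemm_batched_strided(
            MagmaNoTrans, MagmaNoTrans, m, n, k,
            alpha, dA, ldda, ldda*k,
                   dB, lddb, lddb*n,
            beta,  dC, lddc, lddc*n,
            batchCount, queue );
*/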
/******************************************************************************/
extern "C" void
magma_dgemm_batched( magma_trans_t transA, magma_trans_t transB,
magma_int_t m, magma_int_t n, magma_int_t k,
double alpha,
double const * const * dA_array, magma_int_t ldda,
double const * const * dB_array, magma_int_t lddb,
double beta,
double **dC_array, magma_int_t lddc,
magma_int_t batchCount, magma_queue_t queue )
{
magma_dgemm_batched_core(
transA, transB, m, n, k,
alpha, dA_array, 0, 0, ldda,
dB_array, 0, 0, lddb,
beta, dC_array, 0, 0, lddc,
batchCount, queue );
}
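/*
    Note on the two entry points: magma_dgemm_batched goes through
    magma_dgemm_batched_core, which may dispatch to the vendor batched GEMM
    (cublasDgemmBatched) when magma_drecommend_cublas_gemm_batched suggests it,
    whereas magmablas_dgemm_batched above always uses MAGMA's own kernels via
    magmablas_dgemm_batched_core. Both take the same argument list, so the usage
    sketch shown after magmablas_dgemm_batched applies here as well.
*/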