1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
#ifndef _COMEX_COMMON_ACC_H_
#define _COMEX_COMMON_ACC_H_
#include <assert.h>
#include <string.h>
#include "comex.h"
/* Complex element types for the complex (COMEX_ACC_DCP / COMEX_ACC_CPL)
 * accumulate and scale kernels below. Each stores the real component
 * followed by the imaginary component. */
typedef struct {
    double real, imag;  /* double-precision components */
} DoubleComplex;

typedef struct {
    float real, imag;   /* single-precision components */
} SingleComplex;
/* BLAS_INT: C integer type for the n/incx/incy arguments of the BLAS
 * prototypes below -- the first of int, long, long long whose size
 * (SIZEOF_*) equals BLAS_SIZE. These size macros are expected from the build
 * configuration (presumably via comex.h -- confirm).
 *
 * An undefined macro evaluates as 0 inside #if, so when none of SIZEOF_INT
 * and BLAS_SIZE are defined the first branch matches and BLAS_INT is int.
 * If BLAS_SIZE is defined but matches no type, a BLAS-enabled build would
 * otherwise fail later with an obscure "unknown type BLAS_INT" error in the
 * prototypes; report the configuration problem directly instead. */
#if SIZEOF_INT == BLAS_SIZE
#define BLAS_INT int
#elif SIZEOF_LONG == BLAS_SIZE
#define BLAS_INT long
#elif SIZEOF_LONG_LONG == BLAS_SIZE
#define BLAS_INT long long
#elif HAVE_BLAS
#error "comex acc.h: no C integer type (int, long, long long) matches BLAS_SIZE"
#endif
/* Prototypes for the BLAS level-1 routines used by the kernels below.
 * The BLAS_xAXPY / BLAS_xCOPY names are macros expected from the build
 * configuration (presumably mapping to the linked BLAS's symbol names --
 * confirm against comex.h). The complex (C/Z) variants take void* so
 * SingleComplex/DoubleComplex buffers can be passed directly.
 *
 * xSCAL overwrites its vector argument in place (x := a*x), so x must not be
 * const-qualified.
 * NOTE(review): sscal/dscal are declared under bare lowercase names, unlike
 * the BLAS_-prefixed routines, and nothing in this file calls them; confirm
 * the symbol names before relying on them. */
#if HAVE_BLAS
void sscal(const BLAS_INT *n, const float *a, float *x, const BLAS_INT *incx);
void dscal(const BLAS_INT *n, const double *a, double *x, const BLAS_INT *incx);

/* y := a*x + y */
void BLAS_SAXPY(const BLAS_INT *n, const float *a, const float *x,
                const BLAS_INT *incx, float *y, const BLAS_INT *incy);
void BLAS_DAXPY(const BLAS_INT *n, const double *a, const double *x,
                const BLAS_INT *incx, double *y, const BLAS_INT *incy);
void BLAS_CAXPY(const BLAS_INT *n, const void *a, const void *x,
                const BLAS_INT *incx, void *y, const BLAS_INT *incy);
void BLAS_ZAXPY(const BLAS_INT *n, const void *a, const void *x,
                const BLAS_INT *incx, void *y, const BLAS_INT *incy);

/* y := x */
void BLAS_SCOPY(const BLAS_INT *n, const float *x,
                const BLAS_INT *incx, float *y, const BLAS_INT *incy);
void BLAS_DCOPY(const BLAS_INT *n, const double *x,
                const BLAS_INT *incx, double *y, const BLAS_INT *incy);
void BLAS_CCOPY(const BLAS_INT *n, const void *x,
                const BLAS_INT *incx, void *y, const BLAS_INT *incy);
void BLAS_ZCOPY(const BLAS_INT *n, const void *x,
                const BLAS_INT *incx, void *y, const BLAS_INT *incy);
#endif
/* Per-element kernels used by _scale() and _acc():
 *   IADD_SCALE_*(A,B,C)   A += B * C   (accumulate)
 *   MUL_*(A,B,C)          A  = B * C   (scale)
 * The _REG forms apply to built-in arithmetic types. The _CPL forms apply to
 * structs with .real/.imag members and implement complex multiplication.
 *
 * The _CPL forms expand to two statements, so they are wrapped in
 * do { } while (0) to act as a single statement (safe in an unbraced
 * if/else). They write A.real before reading B/C again for A.imag, so A must
 * not alias B or C. */
#define IADD_SCALE_REG(A,B,C) ((A) += (B) * (C))
#define IADD_SCALE_CPL(A,B,C) do {                               \
    (A).real += ((B).real*(C).real) - ((B).imag*(C).imag);       \
    (A).imag += ((B).real*(C).imag) + ((B).imag*(C).real);       \
} while (0)
#define MUL_REG(A,B,C) ((A) = (B) * (C))
#define MUL_CPL(A,B,C) do {                                      \
    (A).real = ((B).real*(C).real) - ((B).imag*(C).imag);        \
    (A).imag = ((B).real*(C).imag) + ((B).imag*(C).real);        \
} while (0)
/* Scale a contiguous buffer element-wise: dst[i] = src[i] * (*scale).
 *
 *   op     COMEX_ACC_* code selecting the element type:
 *            COMEX_ACC_DBL double         COMEX_ACC_FLT float
 *            COMEX_ACC_INT int            COMEX_ACC_LNG long
 *            COMEX_ACC_DCP DoubleComplex  COMEX_ACC_CPL SingleComplex
 *          (complex types use complex multiplication).
 *   bytes  length of src and dst in bytes. The element count is
 *          bytes / sizeof(element); any remainder is ignored, and a
 *          non-positive count does nothing.
 *   dst    destination buffer; must not overlap src (restrict).
 *   src    source buffer.
 *   scale  points to a single scale factor of the element type.
 *
 * An unrecognised op reaches COMEX_ASSERT(0) when that macro is defined,
 * otherwise assert(0) (a no-op under NDEBUG).
 *
 * The element count divides by (int)sizeof(T) so the arithmetic stays signed:
 * dividing an int by a size_t would convert a negative byte count to a huge
 * unsigned value, which on 32-bit targets becomes a large positive loop bound.
 */
static inline void _scale(
        const int op,
        const int bytes,
        void * const restrict dst,
        const void * const restrict src,
        const void * const restrict scale)
{
/* BLAS variant: dst = scale * src, formed by zeroing dst and then applying
 * xAXPY (dst += scale * src). A plain xCOPY before xAXPY would produce
 * (1 + scale) * src. Uses memset from <string.h>. */
#define SCALE_BLAS(COMEX_TYPE, C_TYPE, LETTER)                           \
    if (op == COMEX_TYPE) {                                              \
        const BLAS_INT ONE = 1;                                          \
        const BLAS_INT N = bytes / (int)sizeof(C_TYPE);                  \
        if (N > 0) {                                                     \
            memset(dst, 0, (size_t)N * sizeof(C_TYPE));                  \
            BLAS_##LETTER##AXPY(&N, scale, src, &ONE, dst, &ONE);        \
        }                                                                \
    } else
/* Portable variant: one MUL_REG / MUL_CPL per element. */
#define SCALE(WHICH, COMEX_TYPE, C_TYPE)                                 \
    if (op == COMEX_TYPE) {                                              \
        int m;                                                           \
        const int m_lim = bytes / (int)sizeof(C_TYPE);                   \
        C_TYPE * const restrict iterator = (C_TYPE *)dst;                \
        const C_TYPE * const restrict value = (const C_TYPE *)src;       \
        const C_TYPE calc_scale = *(const C_TYPE *)scale;                \
        for (m = 0; m < m_lim; ++m) {                                    \
            MUL_##WHICH(iterator[m], value[m], calc_scale);              \
        }                                                                \
    } else
/* The BLAS scale path is disabled, so the portable loops are always used.
 * To enable it, replace 0 with HAVE_BLAS. */
#if 0 // HAVE_BLAS
    SCALE_BLAS(COMEX_ACC_DBL, double, D)
    SCALE_BLAS(COMEX_ACC_FLT, float, S)
    SCALE(REG, COMEX_ACC_INT, int)
    SCALE(REG, COMEX_ACC_LNG, long)
    SCALE_BLAS(COMEX_ACC_DCP, DoubleComplex, Z)
    SCALE_BLAS(COMEX_ACC_CPL, SingleComplex, C)
#else
    SCALE(REG, COMEX_ACC_DBL, double)
    SCALE(REG, COMEX_ACC_FLT, float)
    SCALE(REG, COMEX_ACC_INT, int)
    SCALE(REG, COMEX_ACC_LNG, long)
    SCALE(CPL, COMEX_ACC_DCP, DoubleComplex)
    SCALE(CPL, COMEX_ACC_CPL, SingleComplex)
#endif
    {
        /* no branch matched: unsupported op */
#ifdef COMEX_ASSERT
        COMEX_ASSERT(0);
#else
        assert(0);
#endif
    }
#undef SCALE_BLAS
#undef SCALE
}
/* Accumulate a scaled buffer element-wise: dst[i] += src[i] * (*scale).
 *
 *   op     COMEX_ACC_* code selecting the element type:
 *            COMEX_ACC_DBL double         COMEX_ACC_FLT float
 *            COMEX_ACC_INT int            COMEX_ACC_LNG long
 *            COMEX_ACC_DCP DoubleComplex  COMEX_ACC_CPL SingleComplex
 *          (complex types use complex multiplication).
 *   bytes  length of src and dst in bytes. The element count is
 *          bytes / sizeof(element); any remainder is ignored, and a
 *          non-positive count does nothing.
 *   dst    buffer accumulated into; must not overlap src (restrict).
 *   src    source buffer.
 *   scale  points to a single scale factor of the element type.
 *
 * When HAVE_BLAS is set, double, float and the complex types go through
 * BLAS xAXPY (dst := scale*src + dst); int and long always use the loop.
 * An unrecognised op reaches COMEX_ASSERT(0) when that macro is defined,
 * otherwise assert(0) (a no-op under NDEBUG).
 *
 * The element count divides by (int)sizeof(T) so the arithmetic stays signed:
 * dividing an int by a size_t would convert a negative byte count to a huge
 * unsigned value, which on 32-bit targets becomes a large positive count.
 */
static inline void _acc(
        const int op,
        const int bytes,
        void * const restrict dst,
        const void * const restrict src,
        const void * const restrict scale)
{
/* BLAS variant: a single xAXPY call over the whole buffer. */
#define ACC_BLAS(COMEX_TYPE, C_TYPE, LETTER)                             \
    if (op == COMEX_TYPE) {                                              \
        const BLAS_INT ONE = 1;                                          \
        const BLAS_INT N = bytes / (int)sizeof(C_TYPE);                  \
        if (N > 0) {                                                     \
            BLAS_##LETTER##AXPY(&N, scale, src, &ONE, dst, &ONE);        \
        }                                                                \
    } else
/* Portable variant: one IADD_SCALE_REG / IADD_SCALE_CPL per element. */
#define ACC(WHICH, COMEX_TYPE, C_TYPE)                                   \
    if (op == COMEX_TYPE) {                                              \
        int m;                                                           \
        const int m_lim = bytes / (int)sizeof(C_TYPE);                   \
        C_TYPE * const restrict iterator = (C_TYPE *)dst;                \
        const C_TYPE * const restrict value = (const C_TYPE *)src;       \
        const C_TYPE calc_scale = *(const C_TYPE *)scale;                \
        for (m = 0; m < m_lim; ++m) {                                    \
            IADD_SCALE_##WHICH(iterator[m], value[m], calc_scale);       \
        }                                                                \
    } else
#if HAVE_BLAS
    ACC_BLAS(COMEX_ACC_DBL, double, D)
    ACC_BLAS(COMEX_ACC_FLT, float, S)
    ACC(REG, COMEX_ACC_INT, int)
    ACC(REG, COMEX_ACC_LNG, long)
    ACC_BLAS(COMEX_ACC_DCP, DoubleComplex, Z)
    ACC_BLAS(COMEX_ACC_CPL, SingleComplex, C)
#else
    ACC(REG, COMEX_ACC_DBL, double)
    ACC(REG, COMEX_ACC_FLT, float)
    ACC(REG, COMEX_ACC_INT, int)
    ACC(REG, COMEX_ACC_LNG, long)
    ACC(CPL, COMEX_ACC_DCP, DoubleComplex)
    ACC(CPL, COMEX_ACC_CPL, SingleComplex)
#endif
    {
        /* no branch matched: unsupported op */
#ifdef COMEX_ASSERT
        COMEX_ASSERT(0);
#else
        assert(0);
#endif
    }
#undef ACC_BLAS
#undef ACC
}
/* Retire this header's private helper macros so they do not remain defined
 * in files that include it. */
#undef BLAS_INT
#undef MUL_CPL
#undef MUL_REG
#undef IADD_SCALE_CPL
#undef IADD_SCALE_REG
#endif /* _COMEX_COMMON_ACC_H_ */
|