/*
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#if defined(TARGET_LINUX_POWER)
#error "Source cannot be compiled for POWER architectures"
#include "xmm2altivec.h"
#else
#include <immintrin.h>
#include "mth_avx512helper.h"
#endif
#include "fdlog_defs.h"
extern "C" __m512d FCN_AVX512(__fvd_log_fma3)(__m512d);
// Converts packed 64-bit integers (here, small exponent values) to doubles
// using the classic magic-number trick: overlay a fixed exponent bit pattern,
// then subtract the same pattern as a double.
inline
__m512d __internal_fast_int2dbl(__m512i a)
{
  __m512i const INT2DBL_HI = _mm512_set1_epi64(INT2DBL_HI_D);
  __m512i const INT2DBL_LO = _mm512_set1_epi64(INT2DBL_LO_D);
  __m512d const INT2DBL = (__m512d)_mm512_set1_epi64(INT2DBL_D);

  // Bias the low dwords of each element, keep the magic exponent pattern in
  // the high dwords (mask 0x5555 selects the even, i.e. low, 32-bit lanes
  // from t).
  __m512i t = _mm512_xor_si512(INT2DBL_LO, a);
  t = _mm512_mask_blend_epi32(0x5555, INT2DBL_HI, t);
  // Subtracting the magic constant as a double recovers the integer value.
  return _mm512_sub_pd((__m512d)t, INT2DBL);
}
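
/*
 * Scalar sketch of the same trick (illustrative only; the actual bit patterns
 * INT2DBL_HI_D / INT2DBL_LO_D / INT2DBL_D come from fdlog_defs.h, and the
 * constants below are the conventional choices for a 32-bit input):
 *
 *   double fast_int2dbl(int32_t i) {
 *     union { uint64_t u; double d; } v;
 *     // 2^52 exponent pattern | (i + 2^31) in the low 32 bits
 *     v.u = 0x4330000000000000ull | (uint32_t)(i ^ 0x80000000u);
 *     return v.d - 4503601774854144.0;  // 2^52 + 2^31, bits 0x4330000080000000
 *   }
 */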
// Special-case handling for log: propagate inf/NaN, return NaN for negative
// inputs and -inf for zero.
static __m512d __attribute__((noinline)) __pgm_log_d_vec512_special_cases(__m512d const a, __m512d z)
{
  __m512d const ZERO = _mm512_set1_pd(ZERO_D);
  __m512i const ALL_ONES_EXPONENT = _mm512_set1_epi64(ALL_ONES_EXPONENT_D);
  __m512d const NAN_VAL = (__m512d)_mm512_set1_epi64(NAN_VAL_D);
  __m512d const NEG_INF = (__m512d)_mm512_set1_epi64(NEG_INF_D);

  // a - a is 0 for finite a but NaN for +/-inf and NaN, so an all-ones
  // exponent field in (a - a) flags exactly the inf/NaN lanes.
  __m512i detect_inf_nan = (__m512i)_mm512_sub_pd(a, a);
  __m512d inf_nan_mask = (__m512d)_MM512_CMPEQ_EPI64(_mm512_and_si512(detect_inf_nan, ALL_ONES_EXPONENT), ALL_ONES_EXPONENT);
  // inf + inf = inf = log(inf); NaN + NaN = NaN = log(NaN).
  __m512i inf_nan = (__m512i)_mm512_add_pd(a, a);
  z = _MM512_BLENDV_PD(z, (__m512d)inf_nan, inf_nan_mask);

  // log(negative number) = NaN; this also overrides the -inf lanes set above.
  __m512d non_positive_mask = _MM512_CMP_PD(a, ZERO, _CMP_LT_OQ);
  z = _MM512_BLENDV_PD(z, NAN_VAL, non_positive_mask);

  // log(0) = -inf
  __m512d zero_mask = _MM512_CMP_PD(a, ZERO, _CMP_EQ_OQ);
  z = _MM512_BLENDV_PD(z, NEG_INF, zero_mask);

  return z;
}
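
/*
 * The same semantics in scalar form (a sketch using <math.h> names; not part
 * of the library):
 *
 *   double log_special_cases(double a, double z) {
 *     if (isinf(a) || isnan(a)) z = a + a;     // propagate inf and NaN
 *     if (a < 0.0)              z = NAN;       // also catches -inf
 *     if (a == 0.0)             z = -INFINITY;
 *     return z;
 *   }
 */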
__m512d FCN_AVX512(__fvd_log_fma3)(__m512d const a)
{
  __m512d const HI_CONST_1 = (__m512d)_mm512_set1_epi64(HI_CONST_1_D);
  __m512d const HI_CONST_2 = (__m512d)_mm512_set1_epi64(HI_CONST_2_D);
  __m512i const HALFIFIER = _mm512_set1_epi64(HALFIFIER_D);
  __m512i const HI_THRESH = _mm512_set1_epi64(HI_THRESH_D);
  __m512d const ONE_F = _mm512_set1_pd(ONE_F_D);
  __m512d const ZERO = _mm512_set1_pd(ZERO_D);
  __m512d const LN2_HI = _mm512_set1_pd(LN2_HI_D);
  __m512d const LN2_LO = _mm512_set1_pd(LN2_LO_D);
  __m512i const HI_MASK = _mm512_set1_epi64(HI_MASK_D);
  __m512i const ONE = _mm512_set1_epi64(ONE_D);
  __m512i const TEN_23 = _mm512_set1_epi64(TEN_23_D);
  __m512i const ALL_ONES_EXPONENT = _mm512_set1_epi64(ALL_ONES_EXPONENT_D);

  // Coefficients of the degree-24 polynomial approximating log(1+m) on the
  // reduced interval; the numeric values are defined in fdlog_defs.h.
  __m512d const LOG_C1_VEC = _mm512_set1_pd(LOG_C1_VEC_D);
  __m512d const LOG_C2_VEC = _mm512_set1_pd(LOG_C2_VEC_D);
  __m512d const LOG_C3_VEC = _mm512_set1_pd(LOG_C3_VEC_D);
  __m512d const LOG_C4_VEC = _mm512_set1_pd(LOG_C4_VEC_D);
  __m512d const LOG_C5_VEC = _mm512_set1_pd(LOG_C5_VEC_D);
  __m512d const LOG_C6_VEC = _mm512_set1_pd(LOG_C6_VEC_D);
  __m512d const LOG_C7_VEC = _mm512_set1_pd(LOG_C7_VEC_D);
  __m512d const LOG_C8_VEC = _mm512_set1_pd(LOG_C8_VEC_D);
  __m512d const LOG_C9_VEC = _mm512_set1_pd(LOG_C9_VEC_D);
  __m512d const LOG_C10_VEC = _mm512_set1_pd(LOG_C10_VEC_D);
  __m512d const LOG_C11_VEC = _mm512_set1_pd(LOG_C11_VEC_D);
  __m512d const LOG_C12_VEC = _mm512_set1_pd(LOG_C12_VEC_D);
  __m512d const LOG_C13_VEC = _mm512_set1_pd(LOG_C13_VEC_D);
  __m512d const LOG_C14_VEC = _mm512_set1_pd(LOG_C14_VEC_D);
  __m512d const LOG_C15_VEC = _mm512_set1_pd(LOG_C15_VEC_D);
  __m512d const LOG_C16_VEC = _mm512_set1_pd(LOG_C16_VEC_D);
  __m512d const LOG_C17_VEC = _mm512_set1_pd(LOG_C17_VEC_D);
  __m512d const LOG_C18_VEC = _mm512_set1_pd(LOG_C18_VEC_D);
  __m512d const LOG_C19_VEC = _mm512_set1_pd(LOG_C19_VEC_D);
  __m512d const LOG_C20_VEC = _mm512_set1_pd(LOG_C20_VEC_D);
  __m512d const LOG_C21_VEC = _mm512_set1_pd(LOG_C21_VEC_D);
  __m512d const LOG_C22_VEC = _mm512_set1_pd(LOG_C22_VEC_D);
  __m512d const LOG_C23_VEC = _mm512_set1_pd(LOG_C23_VEC_D);
  __m512d const LOG_C24_VEC = _mm512_set1_pd(LOG_C24_VEC_D);
  __m512d a_mut, m, f;
  __m512i expo, expo_plus1;
  __m512d thresh_mask;

  // Isolate the mantissa and force the exponent field to that of 1.0,
  // normalizing a_mut into [1.0, 2.0).
  a_mut = _MM512_AND_PD(a, HI_CONST_1);
  a_mut = _MM512_OR_PD(a_mut, HI_CONST_2);

  // Accuracy trick: where the mantissa exceeds HI_THRESH, divide it by 2
  // (via an integer subtraction on the exponent bits) and increase the
  // exponent by 1.
  thresh_mask = _MM512_CMP_PD(a_mut, (__m512d)HI_THRESH, _CMP_GT_OS);
  m = (__m512d)_mm512_sub_epi32((__m512i)a_mut, HALFIFIER);
  m = _MM512_BLENDV_PD(a_mut, m, thresh_mask);
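  // Why this helps: with a = mant * 2^expo, log(a) = log(mant) + expo*ln(2).
  // For mant just below 2, mant - 1 approaches 1 and the polynomial in
  // (mant - 1) loses accuracy; halving mant and bumping expo keeps the
  // reduced argument small. Assuming HI_THRESH is near sqrt(2) (the usual
  // choice; the actual value lives in fdlog_defs.h), |mant - 1| stays below
  // about 0.42 in every lane.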
  // Extract the unbiased exponent, selecting expo+1 in the lanes where the
  // mantissa was halved above.
  expo = _mm512_srli_epi64((__m512i)a, D52_D);
  expo = _mm512_sub_epi64(expo, TEN_23);
  expo_plus1 = _mm512_add_epi64(expo, ONE);
  expo = (__m512i)_MM512_BLENDV_PD((__m512d)expo, (__m512d)expo_plus1, thresh_mask);

  // Reduce to m - 1 and evaluate the polynomial approximation of log(1+m)
  // at the reduced argument (stored back in m).
  m = _mm512_sub_pd(m, ONE_F);
  // Estrin's scheme for the highest 16 terms (C9..C24), then Estrin again
  // for the next 8; finish off with Horner.
  __m512d z9 = _mm512_fmadd_pd(LOG_C10_VEC, m, LOG_C9_VEC);
  __m512d z11 = _mm512_fmadd_pd(LOG_C12_VEC, m, LOG_C11_VEC);
  __m512d z13 = _mm512_fmadd_pd(LOG_C14_VEC, m, LOG_C13_VEC);
  __m512d z15 = _mm512_fmadd_pd(LOG_C16_VEC, m, LOG_C15_VEC);
  __m512d z17 = _mm512_fmadd_pd(LOG_C18_VEC, m, LOG_C17_VEC);
  __m512d z19 = _mm512_fmadd_pd(LOG_C20_VEC, m, LOG_C19_VEC);
  __m512d z21 = _mm512_fmadd_pd(LOG_C22_VEC, m, LOG_C21_VEC);
  __m512d z23 = _mm512_fmadd_pd(LOG_C24_VEC, m, LOG_C23_VEC);

  __m512d m2 = _mm512_mul_pd(m, m);
  z9 = _mm512_fmadd_pd(z11, m2, z9);
  z13 = _mm512_fmadd_pd(z15, m2, z13);
  z17 = _mm512_fmadd_pd(z19, m2, z17);
  z21 = _mm512_fmadd_pd(z23, m2, z21);

  __m512d m4 = _mm512_mul_pd(m2, m2);
  z9 = _mm512_fmadd_pd(z13, m4, z9);
  z17 = _mm512_fmadd_pd(z21, m4, z17);

  __m512d m8 = _mm512_mul_pd(m4, m4);
  // z9 now holds C9 + C10*m + ... + C24*m^15.
  z9 = _mm512_fmadd_pd(z17, m8, z9);
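  // Estrin in miniature: for p(x) = c0 + c1*x + c2*x^2 + c3*x^3, Horner's
  //   ((c3*x + c2)*x + c1)*x + c0
  // is a serial chain of three dependent FMAs, whereas Estrin's
  //   (c2 + c3*x)*x^2 + (c0 + c1*x)
  // computes the two inner FMAs independently and joins them with one more,
  // cutting the dependency chain from O(n) to O(log n) for an n-term
  // polynomial. The pairings above apply the same idea to C9..C24.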
  // Estrin for the next 8 terms (C1..C8), folding in the z9 chain
  __m512d z8 = _mm512_fmadd_pd(z9, m, LOG_C8_VEC);
  __m512d z6 = _mm512_fmadd_pd(LOG_C7_VEC, m, LOG_C6_VEC);
  __m512d z4 = _mm512_fmadd_pd(LOG_C5_VEC, m, LOG_C4_VEC);
  __m512d z2 = _mm512_fmadd_pd(LOG_C3_VEC, m, LOG_C2_VEC);

  z6 = _mm512_fmadd_pd(z8, m2, z6);
  z2 = _mm512_fmadd_pd(z4, m2, z2);
  __m512d z = _mm512_fmadd_pd(z6, m4, z2);

  // Finish the computation with Horner; the trailing multiply by m reflects
  // that the series for log(1+m) has no constant term.
  z = _mm512_fmadd_pd(z, m, LOG_C1_VEC);
  z = _mm512_mul_pd(z, m);
  // Add the exponent contribution expo * ln(2), converting expo on the fly.
  f = __internal_fast_int2dbl(expo);
  z = _mm512_fmadd_pd(f, LN2_HI, z);
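  // Net result for ordinary positive inputs, with a = mant * 2^expo and
  // x = mant - 1:
  //   log(a) = expo * ln(2) + log(1 + x)
  //          ~= expo * LN2_HI + x*(C1 + C2*x + ... + C24*x^23),
  // which is what z now holds.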
  // Detect special cases (inf, NaN, negative, zero) and branch to the slow
  // path only when at least one lane needs it.
  __m512i detect_inf_nan = (__m512i)_mm512_sub_pd(a, a);
  __m512i detect_non_positive = (__m512i)_MM512_CMP_PD(a, ZERO, _CMP_LE_OQ);
  __m512i inf_nan_mask = _MM512_CMPEQ_EPI64(_mm512_and_si512(detect_inf_nan, ALL_ONES_EXPONENT), ALL_ONES_EXPONENT);
  int specMask = _MM512_MOVEMASK_PD((__m512d)_mm512_or_si512(detect_non_positive, inf_nan_mask));
  if (__builtin_expect(specMask, 0)) {
    return __pgm_log_d_vec512_special_cases(a, z);
  }
  return z;
}
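
/*
 * Usage sketch (hypothetical caller, not part of this file). FCN_AVX512 from
 * mth_avx512helper.h decorates the exported symbol name, so call the routine
 * through the same macro:
 *
 *   __m512d x = _mm512_set1_pd(2.718281828459045);   // e in all 8 lanes
 *   __m512d y = FCN_AVX512(__fvd_log_fma3)(x);       // each lane ~= 1.0
 *   double out[8];
 *   _mm512_storeu_pd(out, y);
 */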