1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Anand Venkat (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include <libxsmm_macros.h>
/* Check that a batch geometry can be expanded into A/B address arrays.
 * blocks must be positive (a zero-length VLA is undefined) and an exact
 * multiple of r*s; otherwise trailing address entries would be left
 * uninitialized while the kernel still reads all `blocks` of them.
 * Returns 1 when the arguments are usable, 0 otherwise. */
static int batch_reduce_args_valid(int blocks, int r, int s)
{
  if (blocks <= 0 || r <= 0 || s <= 0) return 0;
  /* long long: r*s must not overflow int before the modulo */
  return 0 == blocks % ((long long)r * s);
}

/* Populate the per-batch operand addresses for the batch-reduce GEMM.
 * Batch entry n = k*r*s + i*s + j uses
 *   A[n] = weight + n*ofmblock*ifmblock
 *   B[n] = input  + k*ifw*ifh*ifmblock + (i*ifw + j)*ifmblock
 * where k indexes groups of r*s taps (presumably input-channel blocks) and
 * (i, j) is the tap inside the r x s window. With r == s == 1 this reduces to
 * one ofmblock*ifmblock weight block and one ifw*ifh*ifmblock input plane per
 * k, so no special case is needed. Offsets use long long so large tensors do
 * not overflow int. Precondition: batch_reduce_args_valid(blocks, r, s). */
static void batch_reduce_fill_addresses(const float **A, const float **B,
                                        const float *weight, const float *input,
                                        int blocks, int ofmblock, int ifmblock,
                                        int r, int s, int ifh, int ifw)
{
  const long long taps = (long long)r * s;
  const long long weight_block = (long long)ofmblock * ifmblock;
  const long long input_plane = (long long)ifw * ifh * ifmblock;
  const long long kblocks = blocks / taps;
  for (long long k = 0; k < kblocks; ++k) {
    for (int i = 0; i < r; ++i) {
      for (int j = 0; j < s; ++j) {
        const long long n = k * taps + (long long)i * s + j;
        A[n] = weight + n * weight_block;
        B[n] = input + k * input_plane + ((long long)i * ifw + j) * ifmblock;
      }
    }
  }
}

/* Run one batch-reduce GEMM of shape m=ofmblock, n=ofw, k=ifmblock over
 * `blocks` (weight, input) pairs, writing into `output`.
 * The kernel is dispatched with libxsmm's default alpha/beta (beta pointer is
 * NULL) and default lda/ldc; ldb = stride_w*ifmblock so consecutive output
 * columns step stride_w input pixels. NOTE(review): the default beta is
 * presumably 1 (accumulate into output) -- confirm against the libxsmm build.
 * Returns 0 on success, -1 if the geometry is invalid (see
 * batch_reduce_args_valid) or the kernel could not be dispatched; output is
 * untouched on failure. */
extern "C" int batch_reduce_kernel_update(const float *weight, const float *input, float *output, int blocks, int ofmblock, int ifmblock, int ofw, int stride_w, int r, int s, int ifh, int ifw){
  if (!batch_reduce_args_valid(blocks, r, s)) return -1;
  int ld_b = stride_w*ifmblock;
  libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr(ofmblock, ofw, ifmblock, NULL, &ld_b, NULL, NULL, NULL, NULL, NULL);
  /* dispatch yields NULL when no kernel can be generated */
  if (NULL == batchreduce_kernela) return -1;
  const unsigned long long cblocks = (unsigned long long)blocks;
  /* stack arrays sized at runtime (VLA: a compiler extension in C++) */
  const float * A[cblocks];
  const float * B[cblocks];
  batch_reduce_fill_addresses(A, B, weight, input, blocks, ofmblock, ifmblock, r, s, ifh, ifw);
  /* Reduce batch gemm call */
  batchreduce_kernela(A, B, output, &cblocks);
  return 0;
}

/* Same batch-reduce GEMM as batch_reduce_kernel_update, but dispatched with
 * beta = 0, so `output` is overwritten rather than accumulated into, and with
 * explicit lda = ldc = ofmblock and non-transposed ('N','N') flags.
 * Note the parameter order differs from batch_reduce_kernel_update
 * (r, s, ifh, ifw, ofw, stride_w).
 * Returns 0 on success, -1 if the geometry is invalid or the kernel could not
 * be dispatched; output is untouched on failure. */
extern "C" int batch_reduce_kernel_init_update(const float *weight, const float *input, float *output, int blocks, int ofmblock, int ifmblock,int r, int s, int ifh, int ifw,int ofw, int stride_w ){
  if (!batch_reduce_args_valid(blocks, r, s)) return -1;
  float beta = 0.0;
  int lda = ofmblock;
  int ldx = ofmblock;
  int ld_b = stride_w*ifmblock;
  int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') );
  libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr(ofmblock, ofw, ifmblock, &lda, &ld_b, &ldx, NULL, &beta, &l_flags, NULL);
  /* dispatch yields NULL when no kernel can be generated */
  if (NULL == batchreduce_kernela) return -1;
  const unsigned long long cblocks = (unsigned long long)blocks;
  /* stack arrays sized at runtime (VLA: a compiler extension in C++) */
  const float * A[cblocks];
  const float * B[cblocks];
  batch_reduce_fill_addresses(A, B, weight, input, blocks, ofmblock, ifmblock, r, s, ifh, ifw);
  /* Reduce batch gemm call */
  batchreduce_kernela(A, B, output, &cblocks);
  return 0;
}
extern "C" int batch_reduce_kernel_init(float *output, int ofmblock, int ofw){
int num_elements = ofw*ofmblock;
LIBXSMM_PRAGMA_SIMD
for(int i=0; i < num_elements; i++)
output[i] = 0.0;
return 0;
}
|