File: PoolingXSMM.hpp

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (48 lines) | stat: -rw-r--r-- 2,084 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/


#pragma once
#include "PoolingImpl.hpp"
#include "libxsmm.h"
#include "check.hpp"

/*
 * Report a failed LIBXSMM DNN call.
 *
 * A is evaluated exactly once. If its value differs from LIBXSMM_DNN_SUCCESS,
 * one line of the form "<node name>, <error text>" is written to stdout and
 * stdout is flushed. Nothing else happens: the macro neither aborts nor
 * returns, so execution continues after the expansion.
 *
 * Requirement at the expansion site: a pointer-like object named `gp` whose
 * `node_name` member offers c_str() must be in scope.
 *
 * The argument is captured in a local so that passing a call expression does
 * not run the call a second time on the error path, and it is parenthesized
 * so that operators binding looser than `!=` are grouped as written.
 */
#define CHKERR_LIBXSMM_DNN(A) \
{\
  const libxsmm_dnn_err_t chkerr_libxsmm_dnn_status_ = (A);\
  if ( chkerr_libxsmm_dnn_status_ != LIBXSMM_DNN_SUCCESS )\
  {\
    fprintf(stdout, "%s, %s\n", gp->node_name.c_str(), libxsmm_dnn_get_error(chkerr_libxsmm_dnn_status_) );\
    fflush(stdout);\
  }\
}

// Pooling implementation built on LIBXSMM's DNN pooling types; derives from
// PoolImpl. The handle and every tensor binding are kept in arrays with
// NUM_NUMA_NODES entries. The constructor and both propagate methods are only
// declared here; their definitions live outside this header.
class PoolXSMM : public PoolImpl
{
  protected:
    // NOTE(review): not referenced anywhere in this header -- confirm whether
    // the implementation file still uses it.
    PoolImpl *gp_ = NULL;

    // Descriptor passed to LIBXSMM; presumably filled in by the constructor
    // -- confirm in the implementation file.
    libxsmm_dnn_pooling_desc pooling_desc;

    // All pointer arrays below start out all-NULL so that an entry that was
    // never set up can be told apart from a valid handle/tensor.
    libxsmm_dnn_pooling* libxsmm_handle[NUM_NUMA_NODES] = {NULL};
    libxsmm_dnn_tensor*  libxsmm_input[NUM_NUMA_NODES] = {NULL};
    libxsmm_dnn_tensor*  libxsmm_delinput[NUM_NUMA_NODES] = {NULL};
    libxsmm_dnn_tensor*  libxsmm_output[NUM_NUMA_NODES] = {NULL};
    libxsmm_dnn_tensor*  libxsmm_deloutput[NUM_NUMA_NODES] = {NULL};
    libxsmm_dnn_tensor*  libxsmm_mask[NUM_NUMA_NODES] = {NULL};

    libxsmm_dnn_tensor_datalayout* libxsmm_layout = NULL;

    // Result of the most recent LIBXSMM call, and an overall status that
    // starts as success.
    libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
    libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;

    // Scratch bookkeeping. Presumably the flags record whether scratch has
    // been bound for the forward / backward pass and prev_scratch_size holds
    // the size of the last allocation -- TODO confirm against the .cpp.
    bool updated_scratch_fwd = false;
    bool updated_scratch_bwd = false;
    void *scratch = NULL;
    int prev_scratch_size = 0;

  public:
    PoolXSMM(PoolImplParams* gp, int engine);
    virtual ~PoolXSMM(void) {}

    // Assume external threading, e.g., #pragma omp
    // (i.e. the caller is expected to invoke these from inside its own
    // parallel region and pass the calling thread's id as `tid`).

    // Forward pass. Presumably reads `inp`, writes `outp` and records
    // pooling indices in `maskp` -- confirm in the implementation file.
    void forwardPropagate(TensorBuf *inp, TensorBuf *outp, int *maskp, int tid);

    // Backward pass. Presumably reads `deloutp` and `maskp` and writes the
    // input gradient to `delinp` -- confirm in the implementation file.
    void backPropagate(TensorBuf *deloutp, int *maskp, TensorBuf *delinp, int tid);
};