File: FCImpl.hpp

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (89 lines) | stat: -rw-r--r-- 2,952 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/


#pragma once

#include <omp.h>
#include <assert.h>
#include <sys/time.h>
#include "common.hpp"
#include "check.hpp"
#include "Tensor.hpp"

// Configuration record handed (by pointer) to an FCImpl back-end.
// NOTE(review): no field has a default initializer; whoever fills this in must
// set every field the chosen implementation reads. Field order is significant
// for brace/aggregate initialization and must not be changed.
struct FCImplParams {
  string node_name;

  // Input / output channel counts.
  int nInput;
  int nOutput;

  int batch_size;

  // Spatial extents of the input and output tensors.
  int iHeight;
  int iWidth;
  int oHeight;
  int oWidth;

  // Kernel extents.
  int kh;
  int kw;

  bool bias_term;   // presumably: layer has a bias term -- confirm with callers

  int in_data_type;
  int out_data_type;

  int algType;
  int num_numa_nodes;
  int num_threads;
};

// Abstract base for FC-layer compute back-ends.
//
// Concrete implementations override the three pure-virtual entry points that
// take a trailing `tid` argument. The overloads without `tid` are convenience
// wrappers: when `engine == XSMM` they forward to the `tid` overload with
// tid 0; for any other engine value they do nothing.
//
// NOTE: a derived class that overrides only the `tid` overloads hides these
// convenience overloads (C++ name hiding) unless it re-exposes them, e.g.
// `using FCImpl::forwardPropagate;`.
class FCImpl
{
  protected:
    FCImplParams* gp;   // layer configuration; not owned -- never deleted by this class
    int engine;         // engine selector; only XSMM is dispatched by the wrappers below
    TensorLayoutType bot_layout_type, top_layout_type, gbot_layout_type;
    void *bot_layout=NULL, *top_layout=NULL, *gbot_layout=NULL;
    int top_compute_engine=-1;
    int bot_compute_engine=-1;
    string nname;
    // Initialised like the other pointer members so that reading it before
    // set_scratch_buffer() yields NULL instead of an indeterminate value.
    TensorBuf* scratchp=NULL;

  public:
    FCImpl(FCImplParams* gp_, int engine_): gp(gp_), engine(engine_) {}

    // This is a polymorphic base: without a virtual destructor, deleting a
    // derived implementation through an FCImpl* is undefined behaviour.
    virtual ~FCImpl() {}

    void set_top_compute_engine(int e) { top_compute_engine = e;}
    void set_bot_compute_engine(int e) { bot_compute_engine = e;}
    void set_node_name(string s) { nname = s; }
    void set_scratch_buffer(TensorBuf* sb) { scratchp = sb; }

    // Per-thread entry points implemented by each back-end.
    // NOTE(review): `tid` is presumably the calling thread's index -- confirm.
    virtual void forwardPropagate(TensorBuf *inp, TensorBuf* weightp, TensorBuf *hweightp, TensorBuf* biasp, TensorBuf *outp, int tid) = 0;
    virtual void backPropagate(TensorBuf *deloutp, TensorBuf* weightp, TensorBuf *delinp, int tid) = 0;
    virtual void weightUpdate(TensorBuf *deloutp, TensorBuf *inp, TensorBuf *delweightp, TensorBuf *delbiasp, int tid) = 0;

    // Forwards to the tid overload with tid 0 for XSMM; no-op otherwise.
    virtual void forwardPropagate(TensorBuf *inp, TensorBuf* weightp, TensorBuf *hweightp, TensorBuf* biasp, TensorBuf *outp)
    {
      switch(engine)
      {
        case XSMM:
          forwardPropagate(inp, weightp, hweightp, biasp, outp, 0);
          break;
        default:
          // Other engines are not dispatched here (intentional no-op).
          break;
      }
    }

    // Forwards to the tid overload with tid 0 for XSMM; no-op otherwise.
    virtual void backPropagate(TensorBuf *deloutp, TensorBuf *weightp, TensorBuf *delinp)
    {
      switch(engine)
      {
        case XSMM:
          backPropagate(deloutp, weightp, delinp, 0);
          break;
        default:
          // Other engines are not dispatched here (intentional no-op).
          break;
      }
    }

    // Forwards to the tid overload with tid 0 for XSMM; no-op otherwise.
    virtual void weightUpdate(TensorBuf *deloutp, TensorBuf *inp, TensorBuf *delweightp, TensorBuf *delbiasp)
    {
      switch(engine)
      {
        case XSMM:
          weightUpdate(deloutp, inp, delweightp, delbiasp, 0);
          break;
        default:
          // Other engines are not dispatched here (intentional no-op).
          break;
      }
    }
};