File: Node.hpp

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (180 lines) | stat: -rw-r--r-- 5,037 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/


#pragma once
#include <string>
#include <vector>
#include <algorithm>
#include <list>
#include <algorithm>
#include "Params.hpp"
#include "MLNode.hpp"
#include "Engine.fwd.hpp"
#include "Task.hpp"
#include "proto/gxm.pb.h"

using namespace std;
using namespace gxm;
#ifdef USE_MLSL
#include "mlsl.hpp"
#endif

/**
 * Parameters common to every neural-network node.
 *
 * Holds the node's name and type string, two ordered lists of names filled
 * through set_top_names() / set_bottom_names() (presumably the node's output
 * and input tensors -- confirm against the graph builder), an integer mode,
 * and a backpropagation flag.
 */
class NNParams : public MLParams
{
  protected:
    std::vector<std::string> top_;     // names appended by set_top_names(), in call order
    std::vector<std::string> bottom_;  // names appended by set_bottom_names(), in call order
    std::string nname_;                // node name
    std::string type_;                 // node type string
    int mode_;                         // value stored by set_mode(); 0 until set
    bool bp_flag_;                     // value stored by set_bprop_flag(); false until set

  public:
    // mode_ and bp_flag_ are scalars: without an explicit initializer they would
    // hold indeterminate values, and get_mode()/get_bprop_flag() would be
    // undefined behavior if called before the corresponding setter.
    NNParams(void) : mode_(0), bp_flag_(false) {}
    virtual ~NNParams(void) {}

    // Despite the plural names, each call APPENDS a single name to the list.
    void set_top_names(std::string name) { top_.push_back(name); }
    void set_bottom_names(std::string name) { bottom_.push_back(name); }

    void set_node_name(std::string nname) { nname_ = nname; }
    void set_node_type(std::string type) { type_ = type; }
    void set_mode(int mode) { mode_ = mode; }
    void set_bprop_flag(bool flag) { bp_flag_ = flag; }

    std::string get_node_name() { return nname_; }
    // The list getters return mutable references into this object.
    std::vector<std::string>& get_top_names() { return top_; }
    std::vector<std::string>& get_bottom_names() { return bottom_; }
    std::string get_node_type() { return type_; }
    int get_mode() { return mode_; }
    bool get_bprop_flag() { return bp_flag_; }
};

class NNNode : public MLNode
{
  public:
    NNNode(NNParams* p, MLEngine* e) : MLNode(p, e)
    {
      for(int i = 0; i < 4; i++) tBasic_[i] = NULL;
    }

    virtual ~NNNode(void)
    {
      for(int i = 0; i < 4; i++) if(tBasic_[i] != NULL) { delete tBasic_[i]; tBasic_[i] = NULL; }
    }

    void createTasks(list<Task*>, int) {}
    virtual void createStrategy(int) {}

    virtual void forwardPropagate() {}
    virtual void backPropagate() {}
    virtual void weightUpdate() {}
    virtual void solverStep() {}

    int executeTask(int taskId)
    {
      if(taskId == 0)
      {
        forwardPropagate();
      }
      else if(taskId == 1)
      {
        backPropagate();
      }
      else if(taskId == 2)
      {
        weightUpdate();
      }
      else if(taskId == 3)
      {
        solverStep();
      }
      return 0;
    }

    void enqueTask(int pos) {}

    virtual void createPersistentTask() {}

    void setNextNode(NNNode* next)
    {
      //check if next is already in the nextNodes list
      if(std::find(nextNodes_.begin(), nextNodes_.end(), next) == nextNodes_.end())
      {
        nextNodes_.push_back(next);
        next->prevNodes_.push_back(this);
      }
    }

    void setPrevNode(NNNode* prev)
    {
      //check if prev is already in the prevNodes list
      if(std::find(prevNodes_.begin(), prevNodes_.end(), prev) == prevNodes_.end())
      {
        prevNodes_.push_back(prev);
        prev->nextNodes_.push_back(this);
      }
    }

    Task *getBasicTask(int type)
    {
      int index = -1;
      if(type == 0 || (type == 1 && bp_flag_) || (type > 1 && has_weights_))
        index = type;
      if(index != -1) {
        if(tBasic_[index] == NULL) tBasic_[index] = new Task(this, -1, type);
        return tBasic_[index];
      }
      return NULL;
    }

    void createNNGraph(int mode);

    void setNodeType(string type) { ntype_ = type; }

    string getNodeType() { return ntype_; }
    string getNodeName() { return nname_; }
    int getMode() { return mode_; }

    int getNumPrevNodes() { return prevNodes_.size(); }
    int getNumNextNodes() { return nextNodes_.size(); }

    NNNode* getPrevNode(int i) { if(prevNodes_.size() > 0) return prevNodes_[i]; else return NULL; }
    NNNode* getNextNode(int i) { if(nextNodes_.size() > 0) return nextNodes_[i]; else return NULL; }

    int get_num_tops() { return top_.size(); }
    void set_top_compute_engine(int e) { top_compute_engine_ = e; }
    int get_bot_compute_engine() { return bot_compute_engine_; }
    void set_next_node_type(string s) {next_ntype_ = s;}

    void refineTask(){}

    virtual void createCheckPoint() {}
    virtual void restoreCheckPoint() {}

  protected:
    string nname_, ntype_, next_ntype_;
    vector<string> top_;
    vector<string> bottom_;
    int mode_;
    bool bp_flag_;
    bool has_weights_;
    vector<NNNode*> prevNodes_;
    vector<NNNode*> nextNodes_;
    int top_compute_engine_, bot_compute_engine_;
#ifdef USE_MLSL
    MLSL::Operation* op_;
#endif


    // 0-Forw, 1-Back, 2-WGrad, 3-Solver
    Task *tBasic_[4];
};