File: Split.hpp

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (110 lines) | stat: -rw-r--r-- 3,116 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/


#pragma once
#include <string>
#include <stdio.h>
#include "assert.h"
#include "Node.hpp"
#include "Engine.hpp"
#include "Params.hpp"
#include "Tensor.hpp"
#include "proto/gxm.pb.h"
#include "SplitImpl.hpp"
#include "SplitLoop.hpp"

using namespace std;
using namespace gxm;

/**
 * Parameters for a Split node.
 *
 * On top of the generic NNParams fields, it stores the data type and
 * compute engine. parseSplitParams fills these from the node's
 * SplitParameter.
 */
class SplitParams : public NNParams
{
    public:
      SplitParams(void) {}
      virtual ~SplitParams(void) {}

      // Data type code for this node (parseSplitParams stores
      // SplitParameter::data_type() here).
      void set_data_type(int t) { data_type_ = t; }
      int get_data_type() { return data_type_; }

      // Compute engine code for this node (parseSplitParams stores
      // SplitParameter::engine() here).
      void set_compute_engine(int ce) { compute_engine_ = ce; }
      int get_compute_engine() { return compute_engine_; }

    protected:
      // Initialized so that a getter called before its setter returns a
      // defined value. Otherwise it would read an indeterminate int, which
      // is undefined behaviour. 0 is only a placeholder here, not a
      // meaningful engine/type code.
      int compute_engine_ = 0, data_type_ = 0;
};

/**
 * Build a SplitParams object from a parsed protobuf NodeParameter.
 *
 * Preconditions, checked only via assert (compiled out under NDEBUG):
 * - the node name and type are non-empty;
 * - there is exactly one bottom, and its name is non-empty;
 * - the mode is TRAIN or TEST.
 *
 * @param np  Node description to read from; it is not modified.
 * @return    A heap-allocated SplitParams, returned as MLParams*. It is not
 *            freed here, so releasing it is the caller's responsibility.
 */
static MLParams* parseSplitParams(NodeParameter* np)
{
    SplitParams* sp = new SplitParams();

    // Set name of node
    string str = np->name();
    assert(!str.empty());
    sp->set_node_name(str);

    // Set node type (Convolution, FullyConnected, etc)
    str = np->type();
    assert(!str.empty());
    sp->set_node_type(str);

    // Set tensor names: a split has exactly one bottom and any number of tops
    assert(np->bottom_size() == 1);
    assert(!np->bottom(0).empty());
    sp->set_bottom_names(np->bottom(0));

    for(int i=0; i<np->top_size(); i++)
      sp->set_top_names(np->top(i));

    // Set Mode for the node
    assert((np->mode() == TRAIN) || (np->mode() == TEST));
    sp->set_mode(np->mode());

    // Set backprop needed/not needed flag for this node
    sp->set_bprop_flag(np->propagate_down());

    // Bind by const reference instead of copying the protobuf sub-message.
    // Only two scalar fields are read from it.
    const SplitParameter& psp = np->split_param();

    sp->set_data_type(psp.data_type());
    sp->set_compute_engine(psp.engine());

    return sp;
}

/**
 * Graph node that takes a single bottom tensor and exposes several top
 * tensors. parseSplitParams requires exactly one bottom and accepts any
 * number of tops.
 *
 * The constructor, forward/backward passes, configure() and the bf16
 * conversion are defined outside this header.
 */
class SplitNode : public NNNode
{
    public:
      SplitNode(SplitParams* p, MLEngine* e);
      // NOTE(review): this destructor does not release `impl`. Confirm
      // whether ownership is handled elsewhere or whether it leaks.
      virtual ~SplitNode(void) {}

    protected:
      void forwardPropagate();
      void backPropagate();
      void configure(int engine);
      // Presumably converts a bf16 buffer to fp32 (element count in the int
      // argument). TODO confirm against the definition.
      void convert_bf16_f32(libxsmm_bfloat16*, float*, int);

      // Set all MAX_DIMS entries of s->dims to 0.
      void shape_setzero(Shape* s)
      {
        for(int i=0; i<MAX_DIMS; i++)
          s->dims[i] = 0;
      }

      // Output side: one Tensor and its data/diff buffers per top.
      vector<Tensor *>tenTop_;
      vector<TensorBuf *> tenTopData_, tenTopDiff_;

      // Input side: the single bottom tensor and its data/diff buffers.
      // Members below have in-class initializers so they hold defined values
      // until the constructor/configure() assigns them. Otherwise they would
      // be indeterminate.
      Tensor *tenBot_ = NULL;
      TensorBuf *tenBotData_ = NULL, *tenBotDiff_ = NULL;
      int bot_cengine_ = 0;
      int count_ = 0, in_dtype = 0, out_dtype = 0;
      float *stptr = NULL;
      float cbptr[16] = {};

      SplitImplParams gparams_;
      SplitImpl *impl = NULL;
      MLEngine* eptr_ = NULL;
};