/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/
#pragma once
#include <string>
#include <stdio.h>
#include <assert.h>
#include "Node.hpp"
#include "Engine.hpp"
#include "Params.hpp"
#include "Tensor.hpp"
#include "proto/gxm.pb.h"
#include "SplitImpl.hpp"
#include "SplitLoop.hpp"
using namespace std;
using namespace gxm;
class SplitParams : public NNParams
{
  public:
    // Initialize members so a default-constructed object never exposes indeterminate values
    SplitParams(void) : compute_engine_(0), data_type_(0) {}
    virtual ~SplitParams(void) {}

    void set_data_type(int t) { data_type_ = t; }
    int get_data_type() { return data_type_; }

    void set_compute_engine(int ce) { compute_engine_ = ce; }
    int get_compute_engine() { return compute_engine_; }

  protected:
    int compute_engine_, data_type_;
};
static MLParams* parseSplitParams(NodeParameter* np)
{
  SplitParams* sp = new SplitParams();

  // Set name of node
  string str = np->name();
  assert(!str.empty());
  sp->set_node_name(str);

  // Set node type (Convolution, FullyConnected, etc.)
  str = np->type();
  assert(!str.empty());
  sp->set_node_type(str);

  // Set tensor names: a Split node has exactly one bottom tensor
  assert(np->bottom_size() == 1);
  assert(!np->bottom(0).empty());
  sp->set_bottom_names(np->bottom(0));

  for(int i=0; i<np->top_size(); i++)
    sp->set_top_names(np->top(i));

  // Set mode for the node
  assert((np->mode() == TRAIN) || (np->mode() == TEST));
  sp->set_mode(np->mode());

  // Set backprop needed/not needed flag for this node
  sp->set_bprop_flag(np->propagate_down());

  // Read Split-specific parameters by const reference instead of copying the message
  const SplitParameter& psp = np->split_param();
  sp->set_data_type(psp.data_type());
  sp->set_compute_engine(psp.engine());

  return sp;
}
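
// Illustrative sketch (hypothetical helpers, NOT part of LIBXSMM/GxM; the real
// kernels live in SplitImpl.hpp / SplitLoop.hpp). Assuming conventional
// split-layer semantics: forward copies the single bottom tensor into every top
// tensor, and backward sums the diffs of all tops into the bottom diff, because
// every top consumer reads the same input. The namespace split_sketch, the
// function names, and the flat float buffers of length count are assumptions.
#include <cassert>
#include <cstring>
#include <vector>
namespace split_sketch {
static inline void split_fwd_ref(const float* bot, const std::vector<float*>& tops, int count)
{
  assert(count >= 0);
  // Each top receives an identical copy of the bottom activations.
  for (size_t t = 0; t < tops.size(); t++)
    std::memcpy(tops[t], bot, sizeof(float) * static_cast<size_t>(count));
}
static inline void split_bwd_ref(const std::vector<const float*>& top_diffs, float* bot_diff, int count)
{
  assert(count >= 0);
  // The bottom gradient is the element-wise sum of all top gradients.
  for (int i = 0; i < count; i++) {
    float s = 0.0f;
    for (size_t t = 0; t < top_diffs.size(); t++)
      s += top_diffs[t][i];
    bot_diff[i] = s;
  }
}
} // namespace split_sketch
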
class SplitNode : public NNNode
{
  public:
    SplitNode(SplitParams* p, MLEngine* e);
    virtual ~SplitNode(void) {}

  protected:
    void forwardPropagate();
    void backPropagate();
    void configure(int engine);
    void convert_bf16_f32(libxsmm_bfloat16*, float*, int);
    void shape_setzero(Shape* s)
    {
      for(int i=0; i<MAX_DIMS; i++)
        s->dims[i] = 0;
    }

    vector<Tensor *> tenTop_;
    Tensor *tenBot_;
    vector<TensorBuf *> tenTopData_, tenTopDiff_;
    TensorBuf *tenBotData_, *tenBotDiff_;
    int bot_cengine_;
    int count_, in_dtype, out_dtype;
    float *stptr=NULL, cbptr[16];
    SplitImplParams gparams_;
    SplitImpl *impl=NULL;
    MLEngine* eptr_;
};
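
// Illustrative sketch (hypothetical helper, NOT the SplitNode::convert_bf16_f32
// implementation). bfloat16 keeps the upper 16 bits of an IEEE-754 binary32
// value, so widening bf16 to float is a 16-bit left shift followed by a bit-level
// reinterpretation; the result is exact. Using a raw uint16_t as a stand-in for
// libxsmm_bfloat16, and the name convert_bf16_f32_ref, are assumptions here.
#include <cassert>
#include <cstdint>
#include <cstring>
namespace split_sketch {
static inline void convert_bf16_f32_ref(const uint16_t* in, float* out, int len)
{
  assert(len >= 0);
  for (int i = 0; i < len; i++) {
    // Place the 16 bf16 bits in the high half of a 32-bit word; the low mantissa bits are zero.
    uint32_t bits = static_cast<uint32_t>(in[i]) << 16;
    // memcpy reinterprets the bits as float without violating strict aliasing.
    std::memcpy(&out[i], &bits, sizeof(float));
  }
}
} // namespace split_sketch
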