/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/
#pragma once
#include <vector>
#include "MLNode.hpp"
#include <algorithm>
#include "Tensor.hpp"
#define BASIC_TASK_FORW 0
#define BASIC_TASK_BACK 1
#define BASIC_TASK_WGRAD 2
#define BASIC_TASK_SOLVE 3
#define CUSTOM_TASK_START 100
using namespace std;
using namespace gxm;
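class Task
{
  protected:
    MLNode *node_;            // node whose work this task performs
    int taskId_;              // identifier of this task
    int basicTaskId_;         // BASIC_TASK_* code passed to node_->executeTask()
    int minBin_, maxBin_;     // bin range, see setMinBin()/setMaxBin()
    vector<Task*> inputs_;    // predecessor tasks (this task depends on them)
    vector<Task*> outputs_;   // successor tasks (they depend on this task)
    vector<Task*> subTasks_;  // sub-tasks created by createSubTask()
    Task *parent_;            // creating task for a sub-task, NULL otherwise

  public:
    Task(MLNode* n, int taskId, int basicTaskId)
      : node_(n), taskId_(taskId), basicTaskId_(basicTaskId),
        minBin_(0), maxBin_(0), parent_(NULL)
    {}

    virtual ~Task(void) {}

    // Creates a sub-task that shares this task's node and basic task id.
    // The sub-task is heap-allocated and recorded in subTasks_;
    // note that ~Task() does not free it.
    Task *createSubTask(int taskId)
    {
      Task *subTask = new Task(this->node_, taskId, basicTaskId_);
      this->subTasks_.push_back(subTask);
      subTask->parent_ = this;
      return subTask;
    }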
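    // Adds the edge this -> dest: dest is appended to this task's outputs_
    // and this task to dest's inputs_. Returns false if dest is NULL or
    // the edge is already recorded in outputs_.
    bool addForwDep(Task *dest)
    {
      if(dest == NULL) return false;
      // add only if dest is not already in the list
      if(std::find(outputs_.begin(), outputs_.end(), dest) == outputs_.end())
      {
        this->outputs_.push_back(dest);
        if(std::find(dest->inputs_.begin(), dest->inputs_.end(), this) == dest->inputs_.end())
          dest->inputs_.push_back(this);
        return true;
      }
      else
        return false;
    }

    // Adds the edge src -> this: src is appended to this task's inputs_
    // and this task to src's outputs_. Returns false if src is NULL or
    // the edge is already recorded in inputs_.
    bool addBackDep(Task *src)
    {
      if(src == NULL) return false;
      // add only if src is not already in the list
      if(std::find(inputs_.begin(), inputs_.end(), src) == inputs_.end())
      {
        this->inputs_.push_back(src);
        if(std::find(src->outputs_.begin(), src->outputs_.end(), this) == src->outputs_.end())
          src->outputs_.push_back(this);
        return true;
      }
      else
        return false;
    }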
    vector<Task*>& getForwDepTasks() { return this->outputs_; }
    vector<Task*>& getBackDepTasks() { return this->inputs_; }
    void setMinBin(int bin) { minBin_ = bin; }
    void setMaxBin(int bin) { maxBin_ = bin; }
    int getMinBin() { return minBin_; }
    int getMaxBin() { return maxBin_; }
    int getBasicTaskId() { return basicTaskId_; }
    int getTaskId() { return taskId_; }
    MLNode* getNode() { return node_; }
    // Runs the task by dispatching its basic task id to the owning node.
    void invoke() { node_->executeTask(basicTaskId_); }
    inline int numInputs() { return (int)inputs_.size(); }
    inline int numOutputs() { return (int)outputs_.size(); }
};
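
// Usage sketch (illustrative only, not taken from GxM): `conv` and `relu`
// are hypothetical placeholders for valid MLNode pointers obtained elsewhere.
// It shows that addForwDep() and addBackDep() record each edge on both
// endpoints, so adding the same edge from the other side is a no-op.
//
//   Task *f0 = new Task(conv, 0, BASIC_TASK_FORW);
//   Task *f1 = new Task(relu, 1, BASIC_TASK_FORW);
//   f0->addForwDep(f1);   // true: f1 added to f0's outputs, f0 to f1's inputs
//   f1->addBackDep(f0);   // false: f0 is already in f1's inputs
//   f0->numOutputs();     // 1
//   f1->numInputs();      // 1
//   f1->invoke();         // calls relu->executeTask(BASIC_TASK_FORW)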