1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// AsyncTaskGraph represents an execution of a net. It owns the tasks and
// associated futures, sets up future callbacks, and propagates errors.
// Usage steps:
// - Add graph nodes and edges through CreateNode/AddDependency;
// - Freeze the graph (FreezeGraph); once the graph is frozen, a future
//   can be obtained using GetFuture;
// - Schedule execution of the graph through ExecuteGraph; after each
//   execution, Reset must be called to prepare the graph for the next run.
// Abstract interface of a task graph. Every operation is pure virtual, so
// this class only fixes the set of calls a concrete graph must provide:
// node/edge construction, freezing, execution, future access and reset.
class AsyncTaskGraphBase {
 public:
  // Registers a graph node under `node_id` for the given operators.
  // NOTE(review): the meaning of the bool result is not visible in this
  // header - presumably true on success; confirm against the implementation.
  virtual bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) = 0;

  // Declares that node `child_node_id` depends on every node listed in
  // `parent_node_ids`.
  // NOTE(review): bool result semantics are not visible here - confirm.
  virtual bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) = 0;

  // Freezes the graph.
  // NOTE(review): presumably no nodes/edges may be added afterwards - confirm.
  virtual void FreezeGraph() = 0;

  // Starts an execution of the graph and returns a future for it.
  // NOTE(review): ownership of the returned pointer is not expressed by the
  // type - presumably it stays with the graph; confirm before deleting it.
  virtual AsyncTaskFuture* ExecuteGraph() = 0;

  // Returns the future associated with the graph.
  // NOTE(review): same ownership caveat as ExecuteGraph - confirm.
  virtual AsyncTaskFuture* GetFuture() = 0;

  // Resets the graph.
  // NOTE(review): presumably required between consecutive executions - confirm.
  virtual void Reset() = 0;

  // Virtual so that a concrete graph can be destroyed through a pointer to
  // this interface.
  virtual ~AsyncTaskGraphBase() noexcept = default;
};
// Concrete task graph. It overrides every operation of AsyncTaskGraphBase
// and holds the graph state: the tasks (one per node id), the parent/child
// relations between node ids, and the futures it owns.
// Member definitions are not part of this header.
class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  // Builds a graph bound to `helper` and configured by `options`.
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);

  // Registers a graph node under `node_id` for the given operators.
  // NOTE(review): bool result semantics are defined in the .cc - confirm.
  bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) override;

  // Declares that node `child_node_id` depends on every node listed in
  // `parent_node_ids`.
  // NOTE(review): bool result semantics are defined in the .cc - confirm.
  bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) override;

  // Freezes the graph (see frozen_).
  void FreezeGraph() override;

  // Starts an execution of the graph and returns a future for it.
  // NOTE(review): presumably the returned pointer refers to run_future_ and
  // must not be deleted by the caller - confirm.
  AsyncTaskFuture* ExecuteGraph() override;

  // Returns the future associated with the graph.
  // NOTE(review): presumably run_future_.get() - confirm.
  AsyncTaskFuture* GetFuture() override;

  // Resets the graph.
  // NOTE(review): presumably required between consecutive executions - confirm.
  void Reset() override;

 private:
  // used to, e.g., get access to executor's thread pools
  // TODO: pass tracer and counters through ExecutorHelper
  // Raw pointer: the helper's lifetime is not managed through this member.
  ExecutorHelper* helper_;

  // Execution options, stored by value.
  ExecutionOptions options_;

  // NOTE(review): presumably set once FreezeGraph has run - confirm.
  bool frozen_;

  // Tasks owned by the graph, keyed by int (presumably the node id).
  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;

  // For each key, a set of ints - presumably node id -> ids of its parents.
  std::unordered_map<int, std::unordered_set<int>> parents_;

  // For each key, a set of ints - presumably node id -> ids of its children.
  std::unordered_map<int, std::unordered_set<int>> children_;

  // Futures owned by the graph.
  // NOTE(review): name suggests one future per edge/parent set - confirm.
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;

  // Raw (non-owning by type) task pointers.
  // NOTE(review): presumably the tasks that have no parents - confirm.
  std::vector<AsyncTask*> root_tasks_;

  // Future owned by the graph.
  // NOTE(review): presumably completes when a whole run finishes - confirm.
  std::unique_ptr<AsyncTaskFuture> run_future_;
};
} // namespace caffe2
#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
|