File: net_async_task_graph.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (78 lines) | stat: -rw-r--r-- 2,253 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// AsyncTaskGraph represents an execution of a net: it owns the tasks and
// their associated futures, sets up future callbacks and propagates errors.
// AsyncTaskGraphBase below is the abstract interface for such graphs.
//
// Usage steps:
// - Add graph nodes and edges through CreateNode/AddDependency;
// - Freeze the graph (FreezeGraph); once frozen, a future can be obtained
//   using GetFuture;
// - Schedule execution of the graph through ExecuteGraph. After each
//   execution, Reset must be called to prepare the graph for the next run.

class AsyncTaskGraphBase {
 public:
  // Adds a graph node identified by `node_id` that is associated with the
  // operators in `ops`.
  // NOTE(review): the meaning of the bool result is implementation-defined
  // (presumably false on failure) — confirm in the implementation.
  virtual bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) = 0;

  // Adds graph edges between `child_node_id` and each node in
  // `parent_node_ids` (presumably the child waits on its parents — confirm).
  // NOTE(review): bool result presumably signals success — confirm.
  virtual bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) = 0;

  // Freezes the graph structure; afterwards a future can be obtained
  // through GetFuture.
  virtual void FreezeGraph() = 0;

  // Schedules execution of the graph and returns a future.
  // NOTE(review): the returned pointer is presumably non-owning (owned by
  // the graph) — confirm in the implementation.
  virtual AsyncTaskFuture* ExecuteGraph() = 0;

  // Returns the graph's future; obtainable once the graph has been frozen.
  virtual AsyncTaskFuture* GetFuture() = 0;

  // Prepares the graph for the next run; must be called after each
  // execution.
  virtual void Reset() = 0;

  // Virtual so implementations can be destroyed through a base pointer.
  virtual ~AsyncTaskGraphBase() noexcept = default;
};

// Concrete AsyncTaskGraphBase implementation. Tasks are owned through
// `nodes_`; futures created by the graph are owned through `edge_futures_`
// and `run_future_`.
class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  // `helper` is stored as a raw (non-owning) pointer; `options` is kept in a
  // value member.
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);

  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;

  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
      override;

  void FreezeGraph() override;

  AsyncTaskFuture* ExecuteGraph() override;

  AsyncTaskFuture* GetFuture() override;

  void Reset() override;

 private:
  // Non-owning; used to, e.g., get access to executor's thread pools.
  // TODO: pass tracer and counters through ExecutorHelper
  ExecutorHelper* helper_ = nullptr;
  ExecutionOptions options_;

  // Presumably set once FreezeGraph has run. The in-class initializer
  // guarantees the flag is never read indeterminate, regardless of how the
  // out-of-line constructor initializes members.
  bool frozen_ = false;

  // Node id -> owned task for that node.
  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
  // Node id -> ids of related nodes; presumably the edges recorded by
  // AddDependency, stored in both directions — confirm in the implementation.
  std::unordered_map<int, std::unordered_set<int>> parents_;
  std::unordered_map<int, std::unordered_set<int>> children_;
  // Owned futures; presumably one per dependency edge — confirm.
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;

  // Raw, non-owning task pointers (tasks are presumably owned via nodes_);
  // presumably the nodes without parents — confirm.
  std::vector<AsyncTask*> root_tasks_;

  // Owned future for a whole run (presumably what GetFuture returns).
  std::unique_ptr<AsyncTaskFuture> run_future_;
};

} // namespace caffe2

#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H