File: net_async_task_graph.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (78 lines) | stat: -rw-r--r-- 2,253 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H

#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// AsyncTaskGraph represents an execution of a net. It owns the tasks and
// associated futures, sets up future callbacks, and propagates errors.
// Usage steps:
// - Add graph nodes and edges through CreateNode/AddDependency;
// - Freeze the graph (FreezeGraph); after freezing, a future can be
//   obtained using GetFuture;
// - Schedule execution of the graph through ExecuteGraph; after each
//   execution, Reset must be called to prepare the graph for the next run.

// Abstract interface of a task graph: nodes and edges are added first, the
// graph is then frozen, and afterwards it can be executed and reset for
// further runs. Concrete graphs override every method declared here.
class AsyncTaskGraphBase {
 public:
  // Adds a graph node identified by `node_id` for the operators in `ops`.
  // NOTE(review): the bool result is presumably a success flag -- confirm
  // against the implementation.
  virtual bool
  CreateNode(int node_id, const std::vector<OperatorBase*>& ops) = 0;

  // Adds graph edges between `child_node_id` and each id listed in
  // `parent_node_ids`.
  // NOTE(review): the bool result is presumably a success flag -- confirm
  // against the implementation.
  virtual bool
  AddDependency(int child_node_id, const std::vector<int>& parent_node_ids) = 0;

  // Freezes the graph; a future can be obtained through GetFuture once the
  // graph has been frozen.
  virtual void FreezeGraph() = 0;

  // Schedules an execution of the graph and returns a pointer to a future.
  // NOTE(review): presumably the same future GetFuture returns, and
  // presumably owned by the graph -- confirm against the implementation.
  virtual AsyncTaskFuture* ExecuteGraph() = 0;

  // Returns a pointer to the graph's future; intended to be called after
  // FreezeGraph.
  virtual AsyncTaskFuture* GetFuture() = 0;

  // Prepares the graph for the next run; must be called after each execution.
  virtual void Reset() = 0;

  // Virtual so that concrete graphs can be destroyed through a base pointer.
  virtual ~AsyncTaskGraphBase() = default;
};

// Concrete AsyncTaskGraphBase implementation. It stores the graph's tasks and
// futures in owning containers (std::unique_ptr) and keeps per-node
// parent/child bookkeeping. The method definitions are not part of this
// header.
class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  // Builds a graph bound to `helper` and `options`.
  // NOTE(review): `helper` is stored as a raw pointer, so it presumably must
  // outlive the graph -- confirm against the constructor definition.
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);

  // Adds a graph node identified by `node_id` for the operators in `ops`.
  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;

  // Adds graph edges between `child_node_id` and each id in
  // `parent_node_ids`.
  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
      override;

  // Freezes the graph; afterwards a future can be obtained via GetFuture.
  void FreezeGraph() override;

  // Schedules an execution of the graph and returns a pointer to a future.
  AsyncTaskFuture* ExecuteGraph() override;

  // Returns a pointer to the graph's future; intended to be called after
  // FreezeGraph.
  AsyncTaskFuture* GetFuture() override;

  // Prepares the graph for the next run; must be called after each execution.
  void Reset() override;

 private:
  // used to, e.g., get access to executor's thread pools
  // TODO: pass tracer and counters through ExecutorHelper
  // Defaulted to nullptr so the pointer never holds an indeterminate value,
  // whatever the out-of-line constructor does.
  ExecutorHelper* helper_ = nullptr;
  ExecutionOptions options_;

  // Defaulted to false so the flag is never read uninitialized.
  // NOTE(review): presumably set by FreezeGraph -- confirm in the .cc file.
  bool frozen_ = false;

  // Owned tasks, keyed by int (presumably the node id passed to CreateNode).
  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
  // Per-key sets of ints (presumably node id -> ids of its parents/children,
  // filled in by AddDependency -- confirm in the .cc file).
  std::unordered_map<int, std::unordered_set<int>> parents_;
  std::unordered_map<int, std::unordered_set<int>> children_;
  // Owned futures (presumably one per dependency edge or edge group).
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;

  // Non-owning task pointers (presumably the tasks without parents).
  std::vector<AsyncTask*> root_tasks_;

  // Owned future (presumably the one exposed through GetFuture/ExecuteGraph).
  std::unique_ptr<AsyncTaskFuture> run_future_;
};

} // namespace caffe2

#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H