File: fusion_cache.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (111 lines) | stat: -rw-r--r-- 4,000 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once
#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>

#include <memory>

//! nvFuser Fusion IR namespace abbreviation.
//! The alias is declared at global scope (outside namespace nvfuser), so
//! `Nvf` is visible in every translation unit that includes this header.
namespace Nvf = torch::jit::fuser::cuda;

namespace nvfuser {

//! Forward declaration of the record type stored in the cache entries below.
//! NOTE(review): fusion_record.h is included at the top of this header and
//! presumably provides the full definition, which would make this forward
//! declaration redundant -- confirm before removing.
struct RecordFunctor;

//! \struct FusionCacheEntry
//! \brief The container for a single node of the prefix tree that
//! FusionCache uses to organize its cached fusions.
//!
//! An entry holds one record, owns its child entries through
//! record_hash_map, and carries a fusion id plus a visit counter.

struct TORCH_CUDA_CU_API FusionCacheEntry {
  //! \param rec The record associated with this entry.
  //!   NOTE(review): `record` is a std::unique_ptr, so the entry presumably
  //!   takes ownership of `rec` -- confirm against the constructor definition.
  //! \param _fusion_id Presumably the initial value of `fusion_id`; defaults
  //!   to 0 when omitted.
  FusionCacheEntry(RecordFunctor* rec, size_t _fusion_id = 0);

  //! Queries whether the entry denotes a leaf node, which also represents
  //! the end of a Fusion entry in the cache.
  bool isTerminal() const;

  //! An entry's primary data is the record it holds
  std::unique_ptr<RecordFunctor> record;
  //! A hash map of the children for the current node.
  //! The hash map hashes a pointer to a RecordFunctor because
  //! the hash function is virtual.
  //! NOTE(review): no custom hasher or equality functor is named here, so
  //! record-aware hashing/comparison of the raw pointer key must come from
  //! std::hash / std::equal_to specializations declared elsewhere -- confirm.
  //! The keys are non-owning raw pointers; the mapped child entries are owned
  //! by this map.
  std::unordered_map<RecordFunctor*, std::unique_ptr<FusionCacheEntry>>
      record_hash_map;
  //! An index into FusionCache's vector of nvFuser objects, each of which
  //! holds an unscheduled Fusion.  The id is only valid if the entry is
  //! terminal.
  size_t fusion_id;
  //! Count of times the Entry is traversed
  size_t visits;
};

//! \class FusionCache
//! \brief A singleton class used in the nvFuser python interface
//! to manage the caching of fusions.
//!
//! The fusion cache implements a prefix tree of records in order to cache
//! fusions.  A leaf of the tree with a terminal node contains an nvFuser
//! Fusion IR container for a cached instance.
//!
//! A lookup is stateful: the cache keeps a pointer to the "current" entry
//! that is reset to the top of the tree and then advanced one record at a
//! time through the C++-only methods declared below.
//!
//! \todo Add the ability to evict a fusion.  There is currently a max number
//! of fusions that is checked to prevent a runaway case.

class TORCH_CUDA_CU_API FusionCache {
  // Members declared before the first access specifier of a class are
  // private.

  //! The constructor is private given the FusionCache is only constructed
  //! as a singleton.
  //! \param max_fusions The max allowed number of fusions in the cache.
  FusionCache(size_t max_fusions);

  //! Copy and Assignment of the FusionCache is not supported
  FusionCache(const FusionCache&) = delete;
  FusionCache& operator=(const FusionCache&) = delete;

 public:
  //! The public methods from here down to reset() are the python interface
  //! methods

  //! Gets a pointer to the singleton and creates a new one if necessary
  //! \param max_fusions Defaults to 8192.
  //!   NOTE(review): presumably only consulted when the singleton is first
  //!   created -- confirm against the definition.
  static FusionCache* get(size_t max_fusions = 8192);
  //! Number of fusions cached
  size_t numFusions() const;
  //! Print cache stats to the given output stream
  void print(std::ostream& os);
  //! Reset Cache to an empty state
  static void reset();

  //! The rest of the public methods are only used in C++

  //! Queries the current cache entry to see if a record matches one of its
  //! children
  //! \return The matching child entry, or an empty optional
  //!   (presumably when no child matches -- confirm against the definition).
  c10::optional<FusionCacheEntry*> lookupFusionCacheEntry(
      RecordFunctor* rec) const;
  //! Creates a child node for the current cache entry and an optional
  //! fusion_id is returned if the new entry is terminal
  c10::optional<size_t> createFusionCacheEntry(RecordFunctor* rec);
  //! Resets the current cache pointer to the top of the tree
  void resetFusionCachePtr();
  //! Traverses the cache from the current entry to the child associated
  //! with the record given.
  void traverseFusionCache(RecordFunctor* rec);

  //! Grants FusionInterface access to the private members of this class.
  friend class FusionInterface;

 private:
  //! Returns the pointer to the current cache entry
  FusionCacheEntry* fusionCachePtr() const;

  //! The static pointer to the FusionCache
  static FusionCache* singleton_;

  //! The max allowed number of fusions in the cache
  size_t max_fusions_;
  //! The top of the prefix tree used to start a cache look up of a given
  //! fusion definition.
  std::unique_ptr<FusionCacheEntry> fusion_cache_start_;
  //! A pointer to the current cache entry in a cache lookup of a fusion
  //! definition.  This is a non-owning raw pointer.
  FusionCacheEntry* fusion_cache_ptr_;
  //! A vector of nvFuser Fusion IR fusions, owned by the cache.
  //! FusionCacheEntry::fusion_id is documented as an index into this vector.
  std::vector<std::unique_ptr<Nvf::FusionExecutorCache>> fusions_;
  //! A vector of Terminal Cache Entries for Stats collection.
  //! The pointers are non-owning; presumably they refer to entries owned by
  //! the prefix tree -- confirm against the definition.
  std::vector<FusionCacheEntry*> terminal_cache_entries_;
};

} // namespace nvfuser