// Member definitions for the nvfuser Python-frontend FusionCache and
// FusionCacheEntry.
#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
#include <mutex>
namespace nvfuser {
// Guards the singleton pointer below; taken by FusionCache::get() and
// FusionCache::reset() for their whole duration.
static std::mutex fusion_cache_lock;

// The one FusionCache instance. Null until FusionCache::get() allocates it.
FusionCache* FusionCache::singleton_{nullptr};
// Builds a cache node that holds `rec` in its `record` member, has an empty
// child map, remembers `_fusion_id`, and starts with a visit count of zero.
FusionCacheEntry::FusionCacheEntry(RecordFunctor* rec, size_t _fusion_id)
    : record(rec),
      record_hash_map(),
      fusion_id(_fusion_id),
      visits(0) {}
// Reports whether this node holds an End record, i.e. whether the record's
// type is RecordType::End.
bool FusionCacheEntry::isTerminal() const {
  const bool holds_end_record = record->recordType() == RecordType::End;
  return holds_end_record;
}
// Returns the singleton cache, allocating it on first use.
//
// Every call, not just the first, stores `max_fusions` as the cache's limit.
// The call fails via TORCH_CHECK if the cache already holds more fusions than
// the requested limit. The whole operation runs under fusion_cache_lock.
FusionCache* FusionCache::get(size_t max_fusions) {
  const std::lock_guard<std::mutex> lock(fusion_cache_lock);

  if (!singleton_) {
    singleton_ = new FusionCache(max_fusions);
  }
  FusionCache* const cache = singleton_;

  // The limit may be changed freely as long as it still covers what is
  // already cached.
  const auto cached_fusions = cache->fusions_.size();
  TORCH_CHECK(
      max_fusions >= cached_fusions,
      "The max fusions is set less than the number of fusions in the cache.");

  cache->max_fusions_ = max_fusions;
  return cache;
}
// Number of fusions currently stored in the cache.
size_t FusionCache::numFusions() const {
  const size_t fusion_count = fusions_.size();
  return fusion_count;
}
// Writes cache statistics to `os`.
//
// Always prints the total number of fusions. When at least one fusion is
// cached it additionally prints, per terminal cache entry (indexed in the
// order the entries were recorded), the number of cache hits, followed by the
// overall lookup count, hit count, and hit rate in percent.
void FusionCache::print(std::ostream& os) {
  os << "Total Fusions: " << fusions_.size() << "\n";

  // Does not make sense to print stats if the cache is disabled.
  if (fusions_.empty()) {
    return;
  }

  os << "Cache Hits by Fusion Id:\n";
  // Accumulate in size_t: an int accumulator would narrow the visit counters
  // it sums.
  size_t total_cache_hits = 0;
  for (size_t i = 0; i < terminal_cache_entries_.size(); ++i) {
    const auto visits = terminal_cache_entries_[i]->visits;
    // The first visit is a miss! An entry that was created but not yet
    // traversed still has zero visits; report zero hits for it rather than
    // letting `visits - 1` wrap around (the counter is presumably unsigned).
    const auto hits = visits > 0 ? visits - 1 : 0;
    total_cache_hits += hits;
    os << "\t" << i << " -> " << hits << " hits\n";
  }

  // The root entry's visit counter doubles as the number of lookups, since it
  // is bumped by every resetFusionCachePtr() call.
  const auto lookups = fusion_cache_start_->visits;
  // Avoid a 0/0 division (NaN) when nothing has been looked up yet.
  const auto hit_rate = lookups > 0
      ? static_cast<float>(total_cache_hits) / static_cast<float>(lookups) *
          100.0
      : 0.0;
  os << "Cache Lookups: " << lookups;
  os << " Cache Hits: " << total_cache_hits;
  os << " Hit Rate: " << hit_rate << "%\n";
}
// Replaces the singleton with a freshly constructed cache that keeps the
// previous max-fusions limit. Does nothing if the singleton has not been
// created yet. Runs under fusion_cache_lock.
void FusionCache::reset() {
  const std::lock_guard<std::mutex> lock(fusion_cache_lock);

  if (singleton_ == nullptr) {
    return;
  }

  // Read the limit before the old instance is destroyed.
  const auto preserved_max_fusions = singleton_->max_fusions_;
  delete singleton_;
  singleton_ = new FusionCache(preserved_max_fusions);
}
// Creates an empty cache limited to `max_fusions` fusions. The cache starts
// with a single root entry wrapping a StartRecord, and the current-entry
// pointer is positioned on that root.
FusionCache::FusionCache(size_t max_fusions)
    : max_fusions_(max_fusions),
      fusion_cache_start_(nullptr),
      fusion_cache_ptr_(nullptr),
      fusions_() {
  // The root entry receives the freshly allocated StartRecord.
  fusion_cache_start_ = std::make_unique<FusionCacheEntry>(new StartRecord());
  fusion_cache_ptr_ = fusion_cache_start_.get();
}
// Searches the children of the current cache entry for `rec`.
//
// Returns the matching child entry, or c10::nullopt when there is none. The
// current-entry pointer is not moved. Fails via TORCH_CHECK if the current
// entry is terminal or if `rec` is null.
c10::optional<FusionCacheEntry*> FusionCache::lookupFusionCacheEntry(
    RecordFunctor* rec) const {
  FusionCacheEntry* const current = fusionCachePtr();
  TORCH_CHECK(
      !current->isTerminal(),
      "There should be no children from a Terminal Cache Entry!");
  TORCH_CHECK(rec, "Record is null!");

  auto& children = current->record_hash_map;
  const auto match = children.find(rec);
  if (match != children.end()) {
    return c10::optional<FusionCacheEntry*>(match->second.get());
  }
  return c10::nullopt;
}
// Adds a child entry for `rec` beneath the current cache entry.
//
// When `rec` is an End record, a new (empty) fusion is appended to the cache,
// the new entry is remembered as a terminal entry for statistics, and the
// index of the new fusion is returned. For every other record type the
// function returns c10::nullopt. The current-entry pointer is not moved.
//
// Fails via TORCH_CHECK if the current entry is terminal, if `rec` is null,
// or if adding another fusion would exceed the configured maximum.
c10::optional<size_t> FusionCache::createFusionCacheEntry(RecordFunctor* rec) {
  FusionCacheEntry* const parent = fusionCachePtr();
  TORCH_CHECK(
      !parent->isTerminal(),
      "Cannot create a cache entry from a terminal entry!");
  TORCH_CHECK(rec, "Record is null!");

  const bool is_end_record = rec->recordType() == RecordType::End;

  c10::optional<size_t> result = c10::nullopt;
  size_t fusion_id = 0;
  if (is_end_record) {
    TORCH_CHECK(
        (fusions_.size() + 1) <= max_fusions_,
        "The number of fusions in nvfuser has exceeded ",
        max_fusions_,
        " fusions. The max_fusions for the FusionCache might need to be ",
        "increased if the max number is not being exceeded due to an error.");
    fusions_.push_back(std::make_unique<Nvf::FusionExecutorCache>(
        std::make_unique<Nvf::Fusion>()));
    fusion_id = fusions_.size() - 1;
    result = c10::optional<size_t>(fusion_id);
  }

  // Copying the record owned by the FusionDefinition that calls this function
  // so the cache owns a copy when the FusionDefinition gets destroyed rather
  // than managing a shared pointer that would only share with
  // FusionDefinition that creates a cache entry but not cache lookups
  RecordFunctor* new_rec = rec->clone();

  // Build the entry first so the clone is handed to it before the map is
  // touched, and keep its address so the map is only looked up once.
  auto new_entry = std::make_unique<FusionCacheEntry>(new_rec, fusion_id);
  FusionCacheEntry* const new_entry_ptr = new_entry.get();
  parent->record_hash_map[new_rec] = std::move(new_entry);

  if (is_end_record) {
    terminal_cache_entries_.push_back(new_entry_ptr);
  }

  if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
    std::stringstream ss;
    new_rec->print(ss);
    std::cout << "\nFusionDefinition: Create new cache entry for: " << ss.str()
              << "\n";
  }
  return result;
}
void FusionCache::resetFusionCachePtr() {
fusion_cache_ptr_ = fusion_cache_start_.get();
TORCH_CHECK(fusionCachePtr()->record->recordType() == RecordType::Start);
++(fusionCachePtr()->visits);
}
// Advances the current-entry pointer to the child of the current entry that
// matches `rec`, and increments that child's visit count.
//
// Fails via TORCH_CHECK if the current entry is terminal, if `rec` is null,
// if no child matches `rec`, or if the matching map slot holds a null entry.
void FusionCache::traverseFusionCache(RecordFunctor* rec) {
  FusionCacheEntry* const current = fusionCachePtr();
  TORCH_CHECK(
      !current->isTerminal(), "Cannot traverse cache from a terminal entry!");
  // Same null guard that lookupFusionCacheEntry() and
  // createFusionCacheEntry() apply before using `rec` as a map key.
  TORCH_CHECK(rec, "Record is null!");

  auto& children = current->record_hash_map;
  const auto cache_entry = children.find(rec);
  TORCH_CHECK(
      cache_entry != children.end(),
      "Cache Entry for Cache Traverse is not found!");
  TORCH_CHECK(cache_entry->second, "Record in Cache Entry is null!");

  fusion_cache_ptr_ = cache_entry->second.get();
  ++(fusionCachePtr()->visits);
}
// Returns the entry the cache is currently positioned on. Trips an internal
// assert if that pointer is null, so callers can dereference the result
// directly.
FusionCacheEntry* FusionCache::fusionCachePtr() const {
  TORCH_INTERNAL_ASSERT(
      fusion_cache_ptr_ != nullptr,
      "The fusion cache entry is unexpectedly null.");
  return fusion_cache_ptr_;
}
} // namespace nvfuser
|