1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
|
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/lazy/core/cache.h>
// SHAPE CACHING CODE
namespace torch {
namespace jit {
namespace {
// A single canonicalized shape-function argument: either a symbolic shape in
// canonical form or a plain IValue.
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
// Ordered list of canonicalized arguments.
using CanonicalArgVec = std::vector<CanonicalArg>;
// List of canonicalized symbolic shapes (same type as the values stored in the
// cache declared further down in this file).
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
// Cache key: the operator name paired with its canonicalized arguments.
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
// Builds the canonical form of a list of shape-function inputs.
//
// Each input is either an IValue or an at::SymbolicShape:
//  - IValues are stored as-is, deep-copied when `deep_copy` is true and
//    copied normally otherwise.
//  - Symbolic shapes are wrapped in CanonicalizedSymbolicShape, which reads
//    and extends `ss_map` (original symbol value -> canonical index).
//
// The output has one entry per input, in the same order.
CanonicalArgVec cannonicalizeVec(
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec result;
  result.reserve(arg_vec.size());
  for (const auto& input : arg_vec) {
    const IValue* as_ivalue = c10::get_if<IValue>(&input);
    if (as_ivalue == nullptr) {
      // Not an IValue, so this input carries a symbolic shape.
      const auto& shape = c10::get<at::SymbolicShape>(input);
      result.emplace_back(CanonicalizedSymbolicShape(shape, ss_map));
      continue;
    }
    if (deep_copy) {
      result.push_back(as_ivalue->deepcopy());
    } else {
      result.push_back(*as_ivalue);
    }
  }
  return result;
}
// Canonicalizes a list of symbolic shapes (e.g. the results of a shape
// function). Every shape is converted through CanonicalizedSymbolicShape using
// the shared `ss_map`, which may gain entries for symbols not seen before.
// The output preserves the order of `ret_vec`.
std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
    const std::vector<at::SymbolicShape>& ret_vec,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  std::vector<CanonicalizedSymbolicShape> result;
  result.reserve(ret_vec.size());
  for (const at::SymbolicShape& shape : ret_vec) {
    result.emplace_back(CanonicalizedSymbolicShape(shape, ss_map));
  }
  return result;
}
// Hash functor for ShapeCacheKey.
//
// The hash combines, in order: the operator name, the number of arguments,
// and one hash per argument. Canonicalized shapes use their own hash();
// IValues use IValue::hash, except for lists, which get a custom
// size-plus-elements hash below.
struct ArgumentsHasher {
  size_t operator()(const ShapeCacheKey& cacheKey) const {
    // TODO: ignore arguments that are not used in shape function (not needed
    // initially)
    const auto& op_name = std::get<0>(cacheKey);
    const auto& arg_vec = std::get<1>(cacheKey);
    size_t hash_val = c10::hash<c10::OperatorName>()(op_name);
    hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
    for (const CanonicalArg& arg : arg_vec) {
      size_t cur_arg = 0;
      if (const IValue* ival = c10::get_if<IValue>(&arg)) {
        // IValue doesn't hash List (as Python doesn't), so we will do a custom
        // list hash
        if (ival->isList()) {
          TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
          const auto& elems = ival->toListRef();
          // Seed with the length, then fold in every element.
          cur_arg = elems.size();
          for (const IValue& elem_ival : elems) {
            cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
          }
        } else {
          // Hash the pointed-to value. Passing the pointer itself would hash
          // an IValue implicitly built from the pointer rather than the
          // argument, collapsing all non-list arguments onto one hash.
          // NOTE(review): IValue::hash may reject some value kinds; confirm
          // every constant argument kind that reaches here is hashable.
          cur_arg = IValue::hash(*ival);
        }
      } else {
        cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
      }
      hash_val = at::hash_combine(hash_val, cur_arg);
    }
    return hash_val;
  }
};
// Cache type: maps a ShapeCacheKey to a vector of canonicalized symbolic
// shapes, hashing keys with ArgumentsHasher.
using ShapeCache = lazy::Cache<
ShapeCacheKey,
std::vector<CanonicalizedSymbolicShape>,
ArgumentsHasher>;
// Size argument handed to the cache constructor below.
// NOTE(review): presumably the maximum number of entries — confirm against
// lazy::Cache.
constexpr size_t kShapeCacheSize = 1024;
// The single cache instance used by every function in this file (declared in
// the anonymous namespace, so it is local to this translation unit).
ShapeCache shapeCache(kShapeCacheSize);
// Builds the cache key for a shape-function call.
//
// schema:    schema of the operator; only its operator name goes into the key.
// arg_vec:   the call's inputs, canonicalized via cannonicalizeVec.
// ss_map:    symbol -> canonical index map, filled in while canonicalizing.
// deep_copy: forwarded to cannonicalizeVec; controls whether IValue arguments
//            are deep-copied into the key.
//
// Returns the tuple (operator name, canonicalized arguments).
ShapeCacheKey get_cache_key(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy);
  // Move the freshly built vector into the tuple instead of copying it.
  return std::make_tuple(schema->operator_name(), std::move(canonical_args));
}
} // namespace
// Records the result of a shape function in the shape cache.
//
// The arguments are canonicalized first (deep-copying IValues) to form the
// key; the return shapes are then canonicalized with the same symbol map, so
// symbols shared between inputs and outputs receive the same canonical index.
TORCH_API void cache_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    const std::vector<at::SymbolicShape>& ret_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
  std::unordered_map<int64_t, int64_t> ss_map;
  // The key must be built before the returns: it populates ss_map with the
  // symbols that appear in the arguments.
  auto cache_key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true);
  auto canonical_rets =
      std::make_shared<CanonicalRet>(cannonicalizeVec(ret_vec, ss_map));
  shapeCache.Add(cache_key, canonical_rets);
}
// Looks up a previously cached shape-function result.
//
// The arguments are canonicalized (without deep-copying IValues) to form the
// lookup key. On a miss, returns c10::nullopt. On a hit, the cached canonical
// shapes are converted back to at::SymbolicShape using the inverse of the
// symbol map built while canonicalizing the arguments.
TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
  // ss_map and inverse_ss_map
  std::unordered_map<int64_t, int64_t> ss_map;
  auto cache_key =
      get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false);
  auto cached_ret_vec = shapeCache.Get(cache_key);
  if (cached_ret_vec == nullptr) {
    return c10::nullopt;
  }
  // Decanonicalize the return values: invert ss_map so canonical indices map
  // back to the caller's original symbol values.
  std::unordered_map<int64_t, int64_t> inverse_ss_map;
  inverse_ss_map.reserve(ss_map.size());
  for (const auto& ss_val : ss_map) {
    inverse_ss_map[ss_val.second] = ss_val.first;
  }
  std::vector<at::SymbolicShape> ret_vec;
  ret_vec.reserve(cached_ret_vec->size());
  for (const auto& css : *cached_ret_vec) {
    // toSymbolicShape may add entries to inverse_ss_map for canonical indices
    // that did not appear in the arguments.
    ret_vec.emplace_back(css.toSymbolicShape(inverse_ss_map));
  }
  return ret_vec;
}
// Function only to access the cache, used for testing.
// Calls Clear() on the file-local shape cache.
TORCH_API void clear_shape_cache() {
shapeCache.Clear();
}
// Cache accessor: returns the value of Numel() on the file-local shape cache.
// NOTE(review): presumably the number of entries currently stored — confirm
// against lazy::Cache.
TORCH_API size_t get_shape_cache_size() {
return shapeCache.Numel();
}
// Fills `values_` with the canonical form of `orig_shape`.
//
// If the shape has no sizes, `values_` becomes nullopt. Otherwise each
// dimension is stored as:
//  - its static size, when the dimension is static;
//  - a negative canonical index for a symbolic dimension. A symbol already in
//    `ss_map` reuses its recorded index; a new symbol receives the next free
//    index (-1, -2, ... continuing after the entries already in `ss_map`) and
//    is recorded in `ss_map`.
void CanonicalizedSymbolicShape::init(
    const c10::SymbolicShape& orig_shape,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  auto dims = orig_shape.sizes();
  if (!dims) {
    values_ = c10::nullopt;
    return;
  }
  values_ = std::vector<int64_t>();
  int64_t next_index = -static_cast<int64_t>(ss_map.size()) - 1;
  for (auto& dim : *dims) {
    if (dim.is_static()) {
      values_->push_back(dim.static_size());
      continue;
    }
    // insert() leaves an existing entry untouched, so a symbol that was seen
    // before (aliasing) keeps the index it was given earlier.
    auto inserted = ss_map.insert({dim.value(), next_index});
    if (inserted.second) {
      next_index--;
    }
    values_->push_back(inserted.first->second);
  }
}
// Converts the canonical form back into a c10::SymbolicShape.
//
// Returns a default-constructed SymbolicShape when `values_` is empty.
// Otherwise, per stored value:
//  - non-negative values become static sizes;
//  - negative values are looked up in `inverse_ss_map` (canonical index ->
//    symbol value). A hit reuses the mapped value; a miss creates a new
//    symbol and records it in `inverse_ss_map` so later occurrences of the
//    same index resolve to the same symbol.
c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
    std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
  if (!values_.has_value()) {
    return c10::SymbolicShape();
  }
  std::vector<at::ShapeSymbol> dims;
  for (long long entry : *values_) {
    if (entry >= 0) {
      dims.push_back(at::ShapeSymbol::fromStaticSize(entry));
    } else {
      auto known = inverse_ss_map.find(entry);
      if (known == inverse_ss_map.end()) {
        auto fresh = at::ShapeSymbol::newSymbol();
        inverse_ss_map.insert({entry, fresh.value()});
        dims.push_back(fresh);
      } else {
        // NOTE(review): the mapped value is passed through fromStaticSize
        // even though it comes from the symbol map — presumably that factory
        // just wraps the raw value; confirm against ShapeSymbol.
        dims.push_back(at::ShapeSymbol::fromStaticSize(known->second));
      }
    }
  }
  return c10::SymbolicShape(std::move(dims));
}
// Hash of the canonical shape: the hash of the stored value vector when
// present, or a fixed constant when `values_` is empty.
size_t CanonicalizedSymbolicShape::hash() const {
  if (values_.has_value()) {
    return c10::hash<std::vector<int64_t>>()(values_.value());
  }
  return 0x8cc80c80; // random value to prevent hash collisions
}
// Two canonicalized shapes are equal when their stored `values_` compare
// equal (including the case where both are empty).
bool operator==(
    const CanonicalizedSymbolicShape& a,
    const CanonicalizedSymbolicShape& b) {
  return a.values_ == b.values_;
}
} // namespace jit
} // namespace torch
|