1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
|
#pragma once

#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/sparse_bitset.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
// A set of memory locations, stored as compressed element indices (a sparse
// bit vector) so that alias queries reduce to fast bitset comparisons.
using MemoryLocations = c10::SparseBitVector<256>;
namespace torch {
namespace jit {
// Forward declarations. `MemoryDAG` must be declared before
// `MemoryDAGBuilder` befriends it; `Element` is defined further down.
class MemoryDAG;
struct Element;
struct Value;

// Shorthand for a list of `TypePtr`s.
using AliasTypeSet = std::vector<TypePtr>;
/**
* Helper to build up the points-to graph.
*
* We separate the "building" into a different class because it allows us to
* cache internally to MemoryDAG without worrying about how the DAG structure
* is mutated.
*/
class TORCH_API MemoryDAGBuilder {
public:
MemoryDAGBuilder() = default;
MemoryDAGBuilder(const MemoryDAGBuilder&) = delete;
MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete;
// Make `from` point at `to`.
void makePointerTo(Element* from, Element* to);
void addToContainedElements(Element* contained, Element* container);
// Make a fresh Element (i.e. an Element that doesn't point to anything) and
// return it.
Element* makeFreshValue(const Value* v);
friend MemoryDAG;
private:
// `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
// the map to construct the `MemoryDAG`
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
// class MemoryDAG
//
// This class tracks the "A points to B" graph for all values. It is used by
// AliasDb to provide a higher-level API.
//
// We maintain a DAG where:
// - Vertices (called "Elements") represent Values and
// other aliasing entities (e.g. the stuff inside a list)
// - Edges represent a "points-to" relationship.
//
// Leaves in this DAG are entities that don't point to anything, and thus
// correspond to unique "memory locations".
//
// So, by traversing the "points-to" graph to the leaves, you can determine
// which memory locations an element may point to.
class TORCH_API MemoryDAG {
public:
explicit MemoryDAG(std::unique_ptr<MemoryDAGBuilder> builder)
: indexToElementMap_(std::move(builder->indexToElementMap_)) {}
// explicitly delete copy constructor because otherwise windows build is
// confused for an exported class see
// https://stackoverflow.com/a/51033485/105137
MemoryDAG(const MemoryDAG&) = delete;
MemoryDAG& operator=(const MemoryDAG&) = delete;
// Return the unique memory locations that `Element` might represent.
const MemoryLocations& getMemoryLocations(const Element* e) const;
// Do `a` and `b` potentially share a memory location?
bool mayAlias(const Element* a, const Element* b) const;
// Does `a` hold reference to any memory that is stored in `b`, or vice versa?
bool mayContainAlias(const Element* a, const Element* b) const;
bool mayContainAlias(const Element* a, const at::ArrayRef<Element*> b) const;
bool mayContainAlias(
const at::ArrayRef<Element*> a,
const at::ArrayRef<Element*> b) const;
// Converts from the compressed index representation
const Element* fromIndex(unsigned x) const;
Element* fromIndex(unsigned x);
void collectAllContainedMemoryLocations(
const Element* elem,
MemoryLocations& cont) const;
/**
* The following methods are special cases where we need to mutate the
* internals of MemoryDAG for efficiency reasons. Don't call them unless you
* know what you're doing! In particular, don't add new mutating methods
* without ensuring that you are maintaining cache consistency for memory
* locations.
*/
// Adding wildcards can trigger extremely expensive cache invalidations. This
// method adds them in a more efficient cache-aware way.
void setWildcards(
const std::unordered_set<const Value*>& wildcards,
const ska::flat_hash_map<const Value*, Element*>& elementMap,
const std::function<Element*(const Value*)>& getWildcardElement);
Element* unsafeMakeFreshValue(const Value* v);
private:
const MemoryLocations& getAllContainedMemoryLocations(
const Element* elem) const;
void collectAllContainedMemoryLocationsImpl(
const Element* elem,
MemoryLocations& cont) const;
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
// `Element` represents a vertex in the points-to graph. It represents
// anything that could have an aliasing relationship--mostly IR
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
// in `List[T]`)
struct Element {
  // Element representing `value_`, with compressed index `index_`.
  Element(const Value* value_, unsigned index_);
  // wildcard constructor
  explicit Element(unsigned index_);

  // Index into the owning DAG's bit vector that represents this element.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  unsigned index;

  // All elements that this element *may* point to. It's possible to have
  // multiple elements that you might point to due to control flow/complex ops
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  MemoryLocations pointsTo;

  // Backreference for points-to: the elements that may point to this one.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  MemoryLocations pointedFrom;

  // Elements can contain other elements (e.g. List[Tensor])
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  MemoryLocations containedElements;

  // The values that this element corresponds to. May be empty if this element
  // doesn't represent a first-class value.
  // This is for debug information only.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_set<const Value*> values;

 private:
  // Make `from` point at `to`.
  // NOTE(review): a non-static member that also takes both endpoints
  // explicitly, mirroring MemoryDAGBuilder::makePointerTo -- confirm whether
  // it is defined/used anywhere or is dead.
  void makePointerTo(Element* from, Element* to);

  friend class MemoryDAG;

  // We memoize the results of `getMemoryLocations` to speed up queries.
  // A nullopt means that this cache is not yet populated.
  // MemoryDAG is not strictly immutable: its mutating methods
  // (`setWildcards`, `unsafeMakeFreshValue`) are responsible for keeping
  // these caches consistent, so they may need to be updated or invalidated
  // there.
  mutable c10::optional<MemoryLocations> cachedMemoryLocations_;
  // NOTE(review): presumably the analogous memo for
  // `getAllContainedMemoryLocations` -- confirm.
  mutable c10::optional<MemoryLocations> cachedAllContainedMemoryLocations_;
};
} // namespace jit
} // namespace torch
|