File: memory_dag.cpp

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (234 lines) | stat: -rw-r--r-- 7,631 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#include <torch/csrc/jit/passes/utils/memory_dag.h>

#include <c10/util/flat_hash_map.h>
#include <algorithm>
#include <queue>

namespace torch::jit {
namespace {

// Records a directed edge `from -> to` on both endpoints: `to`'s index is set
// in `from->pointsTo`, and `from`'s index is set in `to->pointedFrom`.
void makePointerToImpl(Element* from, Element* to) {
  const auto sourceIndex = from->index;
  const auto targetIndex = to->index;
  to->pointedFrom.set(sourceIndex);
  from->pointsTo.set(targetIndex);
}

// Appends a newly allocated Element to `indexToElementMap_` and returns a
// non-owning pointer to it. The element's index is the size of the vector
// before the append, i.e. the slot the element ends up in. A null `v` creates
// a wildcard element, with no corresponding value.
Element* makeFreshValueImpl(
    const Value* v,
    std::vector<std::unique_ptr<Element>>& indexToElementMap_) {
  const auto nextIndex = indexToElementMap_.size();
  std::unique_ptr<Element> fresh = (v != nullptr)
      ? std::make_unique<Element>(v, nextIndex)
      : std::make_unique<Element>(nextIndex);
  Element* const created = fresh.get();
  indexToElementMap_.push_back(std::move(fresh));
  return created;
}
} // namespace

// Constructs an element with index `index_` whose `values` member is
// initialized from the single value `value_`.
Element::Element(const Value* value_, unsigned index_)
    : index(index_), values({value_}) {}
// Constructs an element with index `index_` and `values` initialized from an
// empty braced list. makeFreshValueImpl uses this overload for wildcard
// elements, which have no corresponding value.
Element::Element(unsigned index_) : index(index_), values({}) {}

// Read-only lookup of the element stored at position `x` of
// `indexToElementMap_`. Asserts that `x` is within bounds.
const Element* MemoryDAG::fromIndex(unsigned x) const {
  TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
  const auto& owner = indexToElementMap_[x];
  return owner.get();
}

// Mutable lookup of the element stored at position `x` of
// `indexToElementMap_`. Asserts that `x` is within bounds.
Element* MemoryDAG::fromIndex(unsigned x) {
  TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
  const auto& owner = indexToElementMap_[x];
  return owner.get();
}

// Two elements may alias when their memory-location sets share at least one
// member.
bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
  return getMemoryLocations(a).intersects(getMemoryLocations(b));
}

bool MemoryDAG::mayContainAlias(const Element* a, const Element* b) const {
  return getAllContainedMemoryLocations(a).intersects(
      getAllContainedMemoryLocations(b));
}

// Returns the contained-memory-location set of `elem`: the element's own
// index merged with the sets of each of its memory locations and each of its
// contained elements. The result is computed on first request and memoized in
// `elem->cachedAllContainedMemoryLocations_`; the returned reference points at
// that cache.
const MemoryLocations& MemoryDAG::getAllContainedMemoryLocations(
    const Element* elem) const {
  if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
    // Install an empty set first and fill it in place, so the cache slot
    // already holds a value while the traversal below runs.
    elem->cachedAllContainedMemoryLocations_ = MemoryLocations();
    collectAllContainedMemoryLocationsImpl(
        elem, *elem->cachedAllContainedMemoryLocations_);
  }
  return *elem->cachedAllContainedMemoryLocations_;
}

void MemoryDAG::collectAllContainedMemoryLocations(
    const Element* elem,
    MemoryLocations& cont) const {
  // we have already recursed on this element
  unsigned compIdx = elem->index;
  if (cont.test(compIdx)) {
    return;
  }

  if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
    MemoryLocations cache;
    collectAllContainedMemoryLocationsImpl(elem, cache);
    elem->cachedAllContainedMemoryLocations_ = std::move(cache);
  }
  cont |= *elem->cachedAllContainedMemoryLocations_;
}

void MemoryDAG::collectAllContainedMemoryLocationsImpl(
    const Element* elem,
    MemoryLocations& cont) const {
  unsigned compIdx = elem->index;
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!cont.test(compIdx));
  cont.set(compIdx);

  for (const auto& mem_loc : getMemoryLocations(elem)) {
    collectAllContainedMemoryLocations(fromIndex(mem_loc), cont);
  }

  for (const auto& contained : elem->containedElements) {
    collectAllContainedMemoryLocations(fromIndex(contained), cont);
  }
}

// True when the contained-memory-location set of `a` overlaps with that of at
// least one element of `b`. An empty `b` yields false without touching `a`.
bool MemoryDAG::mayContainAlias(
    const Element* a,
    const at::ArrayRef<Element*> b) const {
  if (b.empty()) {
    return false;
  }

  const MemoryLocations& containedInA = getAllContainedMemoryLocations(a);
  for (Element* candidate : b) {
    if (containedInA.intersects(getAllContainedMemoryLocations(candidate))) {
      return true;
    }
  }
  return false;
}

// True when the union of contained memory locations over all of `a` overlaps
// with the union over all of `b`. If either list is empty the answer is false.
bool MemoryDAG::mayContainAlias(
    const at::ArrayRef<Element*> a,
    const at::ArrayRef<Element*> b) const {
  const bool eitherEmpty = a.empty() || b.empty();
  if (eitherEmpty) {
    return false;
  }

  // Folds the contained memory locations of every list member into `into`.
  const auto accumulate = [this](
                              const at::ArrayRef<Element*>& members,
                              MemoryLocations& into) {
    for (const Element* member : members) {
      collectAllContainedMemoryLocations(member, into);
    }
  };

  MemoryLocations unionOfA;
  accumulate(a, unionOfA);

  MemoryLocations unionOfB;
  accumulate(b, unionOfB);

  return unionOfA.intersects(unionOfB);
}

// Adds a pointer edge `from -> to`. Forwards to makePointerToImpl, which sets
// `to`'s index in `from->pointsTo` and `from`'s index in `to->pointedFrom`.
void MemoryDAGBuilder::makePointerTo(Element* from, Element* to) {
  makePointerToImpl(from, to);
}

// Registers `elem` as held inside `container` by setting `elem`'s index in
// `container->containedElements`. An element may not contain itself; that is
// asserted.
void MemoryDAGBuilder::addToContainedElements(
    Element* elem,
    Element* container) {
  TORCH_INTERNAL_ASSERT(
      elem != container, "Elements cannot contain themselves");
  const auto containedIndex = elem->index;
  container->containedElements.set(containedIndex);
}

// Give `v` a fresh alias (i.e. it does not point to any value).
// Delegates to makeFreshValueImpl, which appends a new Element to this
// builder's `indexToElementMap_` and returns a non-owning pointer to it. A
// null `v` yields a wildcard element with no corresponding value.
Element* MemoryDAGBuilder::makeFreshValue(const Value* v) {
  return makeFreshValueImpl(v, indexToElementMap_);
}

// Builds (and memoizes on the element) a bitset representing the "alias set"
// for `e`. `MemoryLocations` is just a typedef'd c10::SparseBitVector.
const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const {
  // Note on cache invalidation: all mutation should occur through
  // MemoryDAGBuilder. Once the builder has been consumed to create an
  // immutable MemoryDAG, the memoized result is not expected to be
  // invalidated.
  auto& memo = e->cachedMemoryLocations_;
  if (!memo) {
    MemoryLocations locations;
    if (!e->pointsTo.empty()) {
      // Recursive case: the alias set is the union of the alias sets of
      // everything this element points to.
      for (const auto pointee : e->pointsTo) {
        locations |= getMemoryLocations(fromIndex(pointee));
      }
    } else {
      // Base case: an element that points to nothing is itself a memory
      // location, so its alias set is just its own index.
      locations.set(e->index);
    }
    memo = std::move(locations);
  }
  return *memo;
}

// Redirects the memory locations of every value in `wildcards` to that value's
// wildcard element, then patches the per-element caches to match.
//
// Parameters:
//   wildcards          - values to treat as wildcards.
//   elementMap         - maps a value to its Element. Looked up with `at`, so
//                        every value in `wildcards` is expected to have an
//                        entry.
//   getWildcardElement - returns the wildcard Element for a value. The result
//                        is asserted non-null for every value in `wildcards`;
//                        in the cache-update pass a null result is treated as
//                        "this value is not a wildcard".
void MemoryDAG::setWildcards(
    const std::unordered_set<const Value*>& wildcards,
    const ska::flat_hash_map<const Value*, Element*>& elementMap,
    const std::function<Element*(const Value*)>& getWildcardElement) {
  // For each wildcard element: the union of memory locations that were given
  // a pointer to it in the first pass.
  std::unordered_map<Element*, MemoryLocations> cacheUpdates;
  // If an element is set as a wildcard, that means that all its memory
  // locations must point to the wildcard element.
  for (const Value* v : wildcards) {
    auto wildcardElement = getWildcardElement(v);
    TORCH_INTERNAL_ASSERT(wildcardElement);

    const MemoryLocations& pointeeSet = getMemoryLocations(elementMap.at(v));
    for (const auto& pointee : pointeeSet) {
      auto from = this->fromIndex(pointee);
      // avoid cycles where the wildcard points to itself
      if (from != wildcardElement) {
        makePointerToImpl(from, wildcardElement);
      }
    }
    // Track which memory locations we edited with a new pointer to the wildcard
    // element.
    cacheUpdates[wildcardElement] |= pointeeSet;
  }

  // Update caches in-place.
  // We take advantage of the fact that we only edited memory locations.
  //
  // Say we added a pointer from `MemoryLocationFoo -> WildcardBar`.
  // For every element, if the cache contains `MemoryLocationFoo`, then we must
  // add `WildcardBar` to it.
  for (const std::unique_ptr<Element>& e : this->indexToElementMap_) {
    // The contained-locations cache is dropped for every element; it is
    // rebuilt on the next getAllContainedMemoryLocations /
    // collectAllContainedMemoryLocations call for that element.
    e->cachedAllContainedMemoryLocations_.reset();
    if (e->values.empty()) {
      // This element is a wildcard element, we can skip it.
      continue;
    }

    // Only the first entry of `values` is consulted to find the wildcard
    // element.
    auto wildcardElement = getWildcardElement(*(e->values.begin()));
    if (!wildcardElement) {
      // This value is not a wildcard.
      continue;
    }
    auto it = cacheUpdates.find(wildcardElement);
    if (it == cacheUpdates.end()) {
      // We didn't rewrite any MemoryLocations to point to this element.
      continue;
    }
    // If this element contains an edited memory location, update the cache to
    // contain the pointed-to wildcard element as well.
    // The getMemoryLocations call populates `cachedMemoryLocations_` if it was
    // empty, so dereferencing it inside the branch is safe.
    if (getMemoryLocations(e.get()).intersects(it->second)) {
      e->cachedMemoryLocations_->set(wildcardElement->index);
    }
  }
}

// Creates a new Element for `v` directly on this MemoryDAG by delegating to
// makeFreshValueImpl with the DAG's own `indexToElementMap_`. A null `v`
// yields a wildcard element with no corresponding value.
// NOTE(review): the `unsafe` prefix presumably flags that this mutates a DAG
// whose caches assume mutation only happens through MemoryDAGBuilder --
// confirm against callers.
Element* MemoryDAG::unsafeMakeFreshValue(const Value* v) {
  return makeFreshValueImpl(v, indexToElementMap_);
}
} // namespace torch::jit