File: memory_snapshot.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (412 lines) | stat: -rw-r--r-- 14,211 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
#include <ATen/Context.h>
#include <ATen/record_function.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/csrc/cuda/memory_snapshot.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/profiler/combined_traceback.h>

namespace torch::cuda {

using c10::Dict;
using c10::IValue;
using torch::jit::Pickler;

using c10::cuda::CUDACachingAllocator::SegmentInfo;

namespace {
// Serializes `v` with the JIT Pickler and returns the emitted bytes as a
// std::string.
std::string write_pickle(const IValue& v) {
  std::vector<char> bytes;
  // Sink handed to the pickler: appends every emitted chunk to `bytes`.
  const auto append_chunk = [&bytes](const char* chunk, size_t len) {
    bytes.insert(bytes.end(), chunk, chunk + len);
  };
  {
    // The pickler lives in its own scope so it is destroyed before `bytes`
    // is read below.
    Pickler pickler(append_chunk, nullptr, nullptr, nullptr, nullptr, false);
    pickler.protocol();
    pickler.pushIValue(v);
    pickler.stop();
  }
  return std::string(bytes.cbegin(), bytes.cend());
}
// Creates an empty Dict whose key type and value type are both AnyType.
Dict<IValue, IValue> new_dict() {
  const auto any_type = c10::AnyType::get();
  return Dict<IValue, IValue>(any_type, any_type);
}
// Creates an empty List whose element type is AnyType.
c10::List<IValue> new_list() {
  const auto element_type = c10::AnyType::get();
  return c10::List<IValue>(element_type);
}

// Symbolizes the given tracebacks and converts them to IValues.
//
// Returns one IValue per element of `to_symbolize`, in the same order. Each
// value is a list of frame dicts with the keys "name", "filename" and "line".
std::vector<IValue> ivalue_symbolize(
    std::vector<CapturedTraceback*>& to_symbolize) {
  // Repeated traceback pointers are symbolized only once so that we do not
  // create a pile of duplicated frame objects. `slot_of` maps a traceback to
  // its position in `distinct`.
  std::unordered_map<CapturedTraceback*, uint64_t> slot_of;
  std::vector<CapturedTraceback*> distinct;
  for (CapturedTraceback* tb : to_symbolize) {
    const bool first_seen = slot_of.emplace(tb, distinct.size()).second;
    if (first_seen) {
      distinct.push_back(tb);
    }
  }
  auto symbolized = symbolize(distinct);

  IValue line_key = "line";
  IValue name_key = "name";
  IValue filename_key = "filename";

  // One dict per symbolized frame, indexed like symbolized.all_frames.
  std::vector<IValue> frame_values;
  for (const auto& frame : symbolized.all_frames) {
    auto entry = new_dict();
    entry.insert(name_key, frame.funcname);
    entry.insert(filename_key, frame.filename);
    entry.insert(line_key, static_cast<int64_t>(frame.lineno));
    frame_values.emplace_back(std::move(entry));
  }

  // One list per distinct traceback; its elements are looked up by frame
  // index in `frame_values`.
  std::vector<IValue> distinct_stacks;
  for (const auto& frame_indices : symbolized.tracebacks) {
    auto stack = new_list();
    for (const auto& frame_index : frame_indices) {
      stack.push_back(frame_values.at(frame_index));
    }
    distinct_stacks.emplace_back(std::move(stack));
  }

  // Fan the distinct stacks back out to the original (possibly repeated)
  // ordering of the input.
  std::vector<IValue> stacks;
  stacks.reserve(to_symbolize.size());
  for (CapturedTraceback* tb : to_symbolize) {
    stacks.push_back(distinct_stacks.at(slot_of.at(tb)));
  }
  return stacks;
}

// Context-capture callback used by _record_memory_history when the C++ variant
// (gather_with_cpp) is not selected.
std::shared_ptr<c10::GatheredContext> gather() {
  // NOTE(review): the last flag is the only difference from gather_with_cpp;
  // presumably it toggles C++ frame capture — confirm against
  // CapturedTraceback::gather.
  constexpr bool with_cpp = false;
  return CapturedTraceback::gather(true, true, with_cpp);
}

// Context-capture callback selected by _record_memory_history when C++ stacks
// are requested.
std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
  // NOTE(review): the last flag is the only difference from gather();
  // presumably it toggles C++ frame capture — confirm against
  // CapturedTraceback::gather.
  constexpr bool with_cpp = true;
  return CapturedTraceback::gather(true, true, with_cpp);
}

// Downcasts a gathered context to the CapturedTraceback it is expected to be.
// Fails a TORCH_CHECK when the context has a different dynamic type (or is
// empty, since dynamic_cast of a null pointer yields null).
CapturedTraceback* getFromContext(
    const std::shared_ptr<c10::GatheredContext>& x) {
  auto* traceback = dynamic_cast<CapturedTraceback*>(x.get());
  TORCH_CHECK(
      traceback != nullptr,
      "attempting to gather stack context from the wrong StackContext type.");
  return traceback;
}

void _initRecordAnnotations() {
  static c10::once_flag ra_init;
  c10::call_once(ra_init, [&] {
    // Save user annotations to CCA memory snapshot tool
    at::addThreadLocalCallback(
        at::RecordFunctionCallback(
            [](const at::RecordFunction& fn)
                -> std::unique_ptr<at::ObserverContext> {
              c10::cuda::CUDACachingAllocator::recordAnnotation(
                  {{"name", fn.name()}, {"stage", "START"}});
              return nullptr;
            },
            [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
              c10::cuda::CUDACachingAllocator::recordAnnotation(
                  {{"name", fn.name()}, {"stage", "END"}});
            })
            .scopes({at::RecordScope::USER_SCOPE}));
  });
}

} // namespace

void _record_memory_history(
    bool enabled,
    bool record_context,
    int64_t trace_alloc_max_entries,
    bool trace_alloc_record_context,
    bool record_cpp_context) {
  c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
  if (enabled && record_cpp_context &&
      (trace_alloc_record_context || record_context)) {
    recorder = gather_with_cpp;
    // warm up C++ stack unwinding
    unwind::unwind();
  }
  auto when = c10::cuda::CUDACachingAllocator::RecordContext::NEVER;
  if (trace_alloc_record_context) {
    when = c10::cuda::CUDACachingAllocator::RecordContext::ALLOC;
  } else if (record_context) {
    when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
  }
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled, recorder, trace_alloc_max_entries, when);
}

static void checkOptionIn(
    const std::string& option,
    std::initializer_list<std::string> valid,
    const char* error) {
  TORCH_CHECK(
      valid.end() != std::find(valid.begin(), valid.end(), option), error);
}

// String-option overload: configures allocator history recording.
//
//   enabled     - nullopt disables recording; "state" or "all" enables it.
//                 Only "all" keeps the caller's max_entries; otherwise the
//                 entry limit passed to the allocator is 1.
//   context     - nullopt (RecordContext::NEVER), "state", "alloc" or "all";
//                 selects when a context is captured.
//   stacks      - "python" or "all"; "all" selects the C++-capturing recorder
//                 when both `enabled` and `context` are set.
//   max_entries - entry limit forwarded to recordHistory (see `enabled`).
//
// Fails a TORCH_CHECK when any option has an unrecognized value.
void _record_memory_history(
    std::optional<std::string> enabled,
    std::optional<std::string> context,
    const std::string& stacks,
    size_t max_entries) {
  if (enabled) {
    // The message names the argument that is actually being validated
    // (`enabled`); it previously said "state".
    checkOptionIn(
        *enabled,
        {"state", "all"},
        "expected enabled to be 'state', 'all', or None");
  }
  if (context) {
    checkOptionIn(
        *context,
        {"state", "alloc", "all"},
        "expected context to be 'state', 'alloc', 'all', or None");
  }
  checkOptionIn(
      stacks, {"python", "all"}, "expected stacks to be 'python', or 'all'");

  c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
  if (enabled && context && stacks == "all") {
    recorder = gather_with_cpp;
    // warm up C++ stack unwinding
    unwind::unwind();
  }
  // Anything other than enabled == "all" collapses the entry limit to 1.
  max_entries = (enabled && *enabled == "all") ? max_entries : 1;
  auto when = c10::cuda::CUDACachingAllocator::RecordContext::NEVER;
  if (context) {
    if (*context == "all") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::ALL;
    } else if (*context == "alloc") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::ALLOC;
    } else if (*context == "state") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
    }
  }
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled.has_value(), recorder, max_entries, when);
}

std::string _memory_snapshot_pickled() {
  IValue device_s = "device";
  IValue address_s = "address";
  IValue total_size_s = "total_size";
  IValue allocated_size_s = "allocated_size";
  IValue active_size_s = "active_size";
  IValue requested_size_s = "requested_size";
  IValue stream_s = "stream";
  IValue segment_type_s = "segment_type";
  IValue segment_pool_id = "segment_pool_id";
  IValue large_s = "large";
  IValue small_s = "small";
  IValue size_s = "size";
  IValue state_s = "state";
  IValue active_allocated_s = "active_allocated";
  IValue active_pending_free_s = "active_pending_free";
  IValue inactive_s = "inactive";
  IValue addr_s = "addr";
  IValue filename_s = "filename";
  IValue name_s = "name";
  IValue line_s = "line";
  IValue frames_s = "frames";
  IValue blocks_s = "blocks";
  IValue is_expandable_s = "is_expandable";
  IValue time_us_s = "time_us";

  auto empty_frames = new_list();

  std::vector<CapturedTraceback*> frame_tracebacks;
  std::vector<Dict<IValue, IValue>> frame_dict;

  auto add_frame_key = [&](const c10::Dict<IValue, IValue>& d,
                           const std::shared_ptr<c10::GatheredContext>& ctx) {
    if (ctx) {
      frame_tracebacks.push_back(getFromContext(ctx));
      frame_dict.push_back(d);
    } else {
      d.insert(frames_s, empty_frames);
    }
  };

  const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
    auto segmentDict = new_dict();
    segmentDict.insert(device_s, segmentInfo.device);
    segmentDict.insert(address_s, static_cast<int64_t>(segmentInfo.address));
    segmentDict.insert(
        total_size_s, static_cast<int64_t>(segmentInfo.total_size));
    segmentDict.insert(
        allocated_size_s, static_cast<int64_t>(segmentInfo.allocated_size));
    segmentDict.insert(
        active_size_s, static_cast<int64_t>(segmentInfo.active_size));
    segmentDict.insert(
        requested_size_s, static_cast<int64_t>(segmentInfo.requested_size));
    segmentDict.insert(stream_s, int64_t(segmentInfo.stream));
    segmentDict.insert(
        segment_type_s, (segmentInfo.is_large ? large_s : small_s));
    segmentDict.insert(
        segment_pool_id,
        std::tuple<int64_t, int64_t>(segmentInfo.owner_private_pool_id));
    segmentDict.insert(is_expandable_s, segmentInfo.is_expandable);

    add_frame_key(segmentDict, segmentInfo.context_when_allocated);

    auto address = segmentInfo.address;
    auto blocks = new_list();
    for (const auto& blockInfo : segmentInfo.blocks) {
      auto blockDict = new_dict();
      blockDict.insert(address_s, static_cast<int64_t>(address));
      blockDict.insert(size_s, static_cast<int64_t>(blockInfo.size));
      blockDict.insert(
          requested_size_s, static_cast<int64_t>(blockInfo.requested_size));
      blockDict.insert(
          state_s,
          (blockInfo.allocated
               ? active_allocated_s
               : (blockInfo.active ? active_pending_free_s : inactive_s)));
      add_frame_key(blockDict, blockInfo.context_when_allocated);
      address += blockInfo.size;
      blocks.push_back(blockDict);
    }
    segmentDict.insert(blocks_s, blocks);

    return segmentDict;
  };

  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();

  auto segments = new_list();
  for (const auto& segmentInfo : snapshot.segments) {
    segments.push_back(segmentInfoToDict(segmentInfo));
  }

  auto traces = new_list();
  IValue action_s = "action";
  IValue alloc_s = "alloc";
  IValue free_requested_s = "free_requested";
  IValue free_completed_s = "free_completed";
  IValue segment_alloc_s = "segment_alloc";
  IValue segment_free_s = "segment_free";
  IValue segment_map_s = "segment_map";
  IValue segment_unmap_s = "segment_unmap";
  IValue snapshot_s = "snapshot";
  IValue oom_s = "oom";
  IValue device_free_s = "device_free";

  using namespace c10::cuda::CUDACachingAllocator;

  auto action_to_str = [&](TraceEntry::Action action) {
    switch (action) {
      case TraceEntry::ALLOC:
        return alloc_s;
      case TraceEntry::FREE_REQUESTED:
        return free_requested_s;
      case TraceEntry::FREE_COMPLETED:
        return free_completed_s;
      case TraceEntry::SEGMENT_ALLOC:
        return segment_alloc_s;
      case TraceEntry::SEGMENT_FREE:
        return segment_free_s;
      case TraceEntry::OOM:
        return oom_s;
      case TraceEntry::SNAPSHOT:
        return snapshot_s;
      case TraceEntry::SEGMENT_UNMAP:
        return segment_unmap_s;
      case TraceEntry::SEGMENT_MAP:
        return segment_map_s;
    }
    throw std::runtime_error("unreachable");
  };

  for (const auto& traceInfo : snapshot.device_traces) {
    auto trace = new_list();
    for (const auto& te : traceInfo) {
      auto trace_entry = new_dict();
      trace_entry.insert(action_s, action_to_str(te.action_));
      trace_entry.insert(
          TraceEntry::OOM == te.action_ ? device_free_s : addr_s,
          static_cast<int64_t>(te.addr_));
      trace_entry.insert(size_s, (int64_t)te.size_);
      trace_entry.insert(stream_s, int64_t(te.stream_));
      if (te.context_) {
        auto sc = getFromContext(te.context_);
        frame_tracebacks.push_back(sc);
        frame_dict.push_back(trace_entry);
      }
      trace_entry.insert(time_us_s, te.time_.t_);
      trace.push_back(trace_entry);
    }
    traces.push_back(trace);
  }

  auto external_annotations = new_list();
  for (const auto& ae : snapshot.external_annotations) {
    auto annotation_entry = new_dict();
    for (const auto& md : ae.metadata_) {
      annotation_entry.insert((IValue)md.first, md.second);
    }
    annotation_entry.insert(device_s, ae.device_);
    annotation_entry.insert(time_us_s, ae.time_.t_);
    external_annotations.push_back(annotation_entry);
  }

  auto allocator_settings = new_dict();
  IValue last_allocator_settings_s = "PYTORCH_CUDA_ALLOC_CONF";
  IValue max_split_size_s = "max_split_size";
  IValue garbage_collection_threshold_s = "garbage_collection_threshold";
  IValue expandable_segments_s = "expandable_segments";
  IValue pinned_num_register_threads_s = "pinned_num_register_threads";
  IValue release_lock_on_malloc_s = "release_lock_on_cudamalloc";
  IValue pinned_use_host_register_s = "pinned_use_cuda_host_register";
  IValue roundup_power2_divisions_s = "roundup_power2_divisions";

  allocator_settings.insert(
      last_allocator_settings_s,
      snapshot.config_metadata.last_allocator_settings);
  allocator_settings.insert(
      max_split_size_s, int64_t(snapshot.config_metadata.max_split_size));
  allocator_settings.insert(
      garbage_collection_threshold_s,
      snapshot.config_metadata.garbage_collection_threshold);
  allocator_settings.insert(
      expandable_segments_s, snapshot.config_metadata.expandable_segments);
  allocator_settings.insert(
      pinned_num_register_threads_s,
      int64_t(snapshot.config_metadata.pinned_num_register_threads));
  allocator_settings.insert(
      release_lock_on_malloc_s,
      snapshot.config_metadata.release_lock_on_malloc);
  allocator_settings.insert(
      pinned_use_host_register_s,
      snapshot.config_metadata.pinned_use_host_register);
  unsigned int roundup_key = 1;
  auto roundup_settings = new_dict();
  for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {
    IValue roundup_key_s = std::to_string(roundup_key);
    roundup_settings.insert(roundup_key_s, int64_t(v));
    roundup_key *= 2;
  }
  allocator_settings.insert(roundup_power2_divisions_s, roundup_settings);

  auto result = new_dict();
  result.insert("segments", segments);
  result.insert("device_traces", traces);
  result.insert("allocator_settings", allocator_settings);
  result.insert("external_annotations", external_annotations);

  auto frames = ivalue_symbolize(frame_tracebacks);
  for (auto i : c10::irange(frames.size())) {
    frame_dict.at(i).insert(frames_s, frames.at(i));
  }

  return write_pickle(result);
}
} // namespace torch::cuda