File: fbgemm_pack_matrix_cache.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (78 lines) | stat: -rw-r--r-- 2,152 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include "fbgemm_pack_matrix_cache.h"

#include <map>
#include <memory>
#include <mutex>
#include <tuple>

using namespace std;

namespace caffe2 {

template <typename ACC_T>
shared_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>> GetOrCreateFbgemmPackBMatrix(
    fbgemm::matrix_op_t trans,
    int32_t m,
    int32_t n,
    const void* orig_data,
    const int8_t* quantized_data,
    int32_t ld) {
  static std::map<
      std::tuple<int, int, const void*>,
      weak_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>>>
      cache;
  static mutex cache_mutex;

  // Create a new packed matrix and compare with cached one if there's any.
  // Note that a cache miss is as expensive as a cache hit here, the purpose of
  // this cache is only to deduplicate the quantized tensors for improved
  // memory bandwidth if different nets share copies of the same operator.
  // TODO: make this cheaper by computing hash of fdata.
  auto new_packed = make_shared<fbgemm::PackBMatrix<int8_t, ACC_T>>(
      trans,
      m,
      n,
      quantized_data,
      ld,
      nullptr, // pmat
      1); // groups

  std::tuple<int, int, const void*> key(m, n, orig_data);
  std::shared_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>> cache_entry;
  {
    lock_guard<mutex> lock(cache_mutex);
    auto itr = cache.find(key);
    if (itr != cache.end()) {
      cache_entry = itr->second.lock();
    }
  } // release lock here during expensive equals()

  if (!cache_entry || !cache_entry->metaEquals(*new_packed) ||
      !cache_entry->equals(*new_packed)) {
    // cache miss
    lock_guard<mutex> lock(cache_mutex);
    cache[key] = new_packed;
    return new_packed;
  } else {
    return cache_entry;
  }
}

template shared_ptr<fbgemm::PackBMatrix<int8_t, int16_t>>
GetOrCreateFbgemmPackBMatrix<int16_t>(
    fbgemm::matrix_op_t trans,
    int32_t m,
    int32_t n,
    const void* orig_data,
    const int8_t* quantized_data,
    int32_t ld);

template shared_ptr<fbgemm::PackBMatrix<int8_t, int32_t>>
GetOrCreateFbgemmPackBMatrix<int32_t>(
    fbgemm::matrix_op_t trans,
    int32_t m,
    int32_t n,
    const void* orig_data,
    const int8_t* quantized_data,
    int32_t ld);

} // namespace caffe2