File: mpscnn_context.mm

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (102 lines) | stat: -rw-r--r-- 3,302 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

#include "caffe2/core/common.h"

#ifdef C10_MOBILE

#include "mpscnn_context.h"
#include "mpscnn_kernels.h"

#include "caffe2/core/logging.h"
#include "caffe2/core/timer.h"

#include <array>
#include <mutex>
#include <thread>

#import <Metal/MTLFunctionConstantValues.h>

namespace caffe2 {

// Returns the process-wide MPSCNN context, initializing it on first use.
//
// Initialization runs once under std::call_once. It:
//   1. creates the system default Metal device;
//   2. compiles the MPSCNN_KERNELS source into a library;
//   3. creates a command queue.
// If initialization throws, std::call_once leaves the flag unset, so a later
// call retries initialization.
//
// Throws (via CAFFE_THROW) in three cases:
//   - no Metal device is available;
//   - the kernel source fails to compile;
//   - a command queue cannot be created.
MPSCNNContext& getMPSCNNContext() {
  static std::once_flag once;
  static MPSCNNContext ctx;
  std::call_once(once, []() {
    ctx.device = MTLCreateSystemDefaultDevice();
    // Without a device, every later message would silently return nil.
    if (ctx.device == nil) {
      CAFFE_THROW("Failed to load kernels: no Metal device available");
    }

    NSError* compileError = nil;
    ctx.library = [ctx.device newLibraryWithSource:[NSString stringWithUTF8String:MPSCNN_KERNELS]
                                           options:nil
                                             error:&compileError];
    // The return value signals success. Metal may set the error to compiler
    // warnings even when the library was created, so the error pointer alone
    // must not be treated as failure.
    if (ctx.library == nil) {
      CAFFE_THROW("Failed to load kernels: ",
                  compileError != nil ? [[compileError localizedDescription] UTF8String]
                                      : "unknown error");
    }
    if (compileError != nil) {
      LOG(WARNING) << "Kernel compilation reported: "
                   << [[compileError localizedDescription] UTF8String];
    }

    ctx.commandQueue = [ctx.device newCommandQueue];
    if (ctx.commandQueue == nil) {
      CAFFE_THROW("Failed to create Metal command queue");
    }
  });
  return ctx;
}

// Returns a compute pipeline state for the function named `kernel` in the
// context's Metal library. The pipeline is built on first request and cached.
//
// The cache is keyed by the kernel name. pipelineCacheMutex_ is held for the
// whole lookup-and-build, so concurrent callers never build the same pipeline
// twice.
//
// Throws in three cases:
//   - `kernel` is nil;
//   - the function is not found in the library;
//   - pipeline-state creation fails. The Metal error text is included in the
//     message when available.
id<MTLComputePipelineState> MPSCNNContext::getPipelineState(NSString* kernel) {
  // std::string(nullptr) is undefined behavior, so reject nil up front.
  CAFFE_ENFORCE(kernel != nil, "Kernel name must not be nil");
  const std::string kernelStr([kernel UTF8String]);

  std::lock_guard<std::mutex> g(pipelineCacheMutex_);
  auto it = pipelineCache_.find(kernelStr);
  if (it != pipelineCache_.end()) {
    VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
    return it->second;
  }
  LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;

  id<MTLFunction> func = [library newFunctionWithName:kernel];
  if (!func) {
    CAFFE_THROW("Couldn't get function: ", kernelStr);
  }

  NSError* errors = nil;
  id<MTLComputePipelineState> state =
      [device newComputePipelineStateWithFunction:func error:&errors];
  if (!state) {
    // Guard against a nil error: [nil UTF8String] would pass a null char*.
    CAFFE_THROW("Couldn't get state: ",
                kernelStr,
                " error: ",
                errors != nil ? [[errors localizedDescription] UTF8String] : "unknown");
  }
  pipelineCache_[kernelStr] = state;
  return state;
}

// Returns a compute pipeline state for `kernel`, specialized with Metal
// function constants. The pipeline is built on first request and cached.
//
// Each constants[i] is bound as a ushort function constant at index i.
// The cache key is the kernel name followed by "_<value>" for each constant,
// in order. The cache (pipelineCache_, guarded by pipelineCacheMutex_) is the
// same one getPipelineState uses.
// NOTE(review): "_" is also a legal identifier character, so a specialized key
// such as "foo_1" could collide with an unspecialized kernel literally named
// "foo_1". Confirm that kernel naming rules this out.
//
// Throws in three cases:
//   - `kernel` is nil;
//   - the specialized function cannot be created;
//   - pipeline-state creation fails.
// The Metal error text is included in the message when available.
id<MTLComputePipelineState> MPSCNNContext::getSpecializedPipelineState(
    NSString* kernel, const std::vector<ushort>& constants) {
  // std::string(nullptr) is undefined behavior, so reject nil up front.
  CAFFE_ENFORCE(kernel != nil, "Kernel name must not be nil");
  std::string kernelStr([kernel UTF8String]);
  for (const ushort constant : constants) {
    kernelStr += "_" + std::to_string(constant);
  }

  std::lock_guard<std::mutex> g(pipelineCacheMutex_);
  auto it = pipelineCache_.find(kernelStr);
  if (it != pipelineCache_.end()) {
    VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
    return it->second;
  }
  LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;

  MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues alloc] init];
  for (size_t i = 0; i < constants.size(); ++i) {
    [constantValues setConstantValue:&constants[i] type:MTLDataTypeUShort atIndex:i];
  }

  NSError* errors = nil;
  id<MTLFunction> func =
      [library newFunctionWithName:kernel constantValues:constantValues error:&errors];
  if (!func) {
    CAFFE_THROW("Couldn't get function: ",
                kernelStr,
                " error: ",
                errors != nil ? [[errors localizedDescription] UTF8String] : "unknown");
  }

  // Reset so a message left by the previous call is not reported as the
  // cause of a pipeline-state failure.
  errors = nil;
  id<MTLComputePipelineState> state =
      [device newComputePipelineStateWithFunction:func error:&errors];
  if (!state) {
    CAFFE_THROW("Couldn't get state: ",
                kernelStr,
                " error: ",
                errors != nil ? [[errors localizedDescription] UTF8String] : "unknown");
  }
  pipelineCache_[kernelStr] = state;
  return state;
}
}

#endif