File: fused_kernel.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (103 lines) | stat: -rw-r--r-- 3,435 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/codegen/fuser/partition_desc.h>
#include <torch/csrc/jit/codegen/fuser/tensor_desc.h>
#include <torch/csrc/utils/disallow_copy.h>

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Abstract base describing one fused kernel: its name, its code string, the
// descriptors of its inputs and outputs, and the chunk/concat partitioning
// metadata. Subclasses supply the actual launch (launch_raw) and report which
// backend they target (backend). All stored state is const after construction.
struct FusedKernel {
  // Presumably deletes the copy constructor and copy assignment (macro is
  // defined in torch/csrc/utils/disallow_copy.h).
  TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel);

  // Takes every argument by value and moves it into the matching const member,
  // so callers can hand over ownership with std::move and avoid copies.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  FusedKernel(
      std::string name,
      std::string code,
      std::vector<TensorDesc> input_desc,
      std::vector<TensorDesc> output_desc,
      std::vector<PartitionDesc> chunk_desc,
      std::vector<PartitionDesc> concat_desc,
      bool has_random)
      : name_(std::move(name)),
        code_(std::move(code)),
        input_desc_(std::move(input_desc)),
        output_desc_(std::move(output_desc)),
        chunk_desc_(std::move(chunk_desc)),
        concat_desc_(std::move(concat_desc)),
        has_random_(has_random) {}

  // Virtual so that derived kernels are destroyed correctly through a
  // FusedKernel pointer.
  virtual ~FusedKernel() = default;

  // arguments is a list of pointers to the arguments for the compiled CUDA/CPU
  // code.
  // The format of arguments is suitable for directly passing to a call to
  // cuLaunchKernel as the kernel arguments.
  // Currently the first argument is a pointer to numel (for passing to
  // CUDA code), and the remainder are pointers to the TensorInfo<T> structs
  // that compiled code uses to load Tensor data.
  // launch_with_tensors handles packing at::Tensors into this arguments array.
  // CPU code uses the same convention so that launch_with_tensors can be
  // shared.
  //
  // Note: numel is passed by value; a top-level const on it is not part of the
  // function type, so overriders may spell the parameter with or without it.
  virtual void launch_raw(uint32_t numel, std::vector<void*>& arguments)
      const = 0;

  // Returns the backend this kernel is implemented for; defined by subclasses.
  virtual at::Backend backend() const = 0;

  // Getters. Each returns a reference to (or, for hasRandom, a copy of) the
  // corresponding member set at construction; the references stay valid for
  // the lifetime of this object.
  const std::string& name() const {
    return name_;
  }
  const std::string& code() const {
    return code_;
  }
  const std::vector<TensorDesc>& inputDesc() const {
    return input_desc_;
  }
  const std::vector<TensorDesc>& outputDesc() const {
    return output_desc_;
  }
  const std::vector<PartitionDesc>& chunkDesc() const {
    return chunk_desc_;
  }
  const std::vector<PartitionDesc>& concatDesc() const {
    return concat_desc_;
  }
  bool hasRandom() const {
    return has_random_;
  }

 protected:
  // Kernel name, as passed to the constructor.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::string name_;
  // Kernel code string, as passed to the constructor.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::string code_;
  // One descriptor per kernel input.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<TensorDesc> input_desc_;
  // One descriptor per kernel output.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<TensorDesc> output_desc_;

  // same size as input_desc, describes whether an
  // input should be broken into subtensors (chunks)
  // to be consumed by the fusion group
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<PartitionDesc> chunk_desc_;

  // same size as output_desc, describes whether
  // an output is actually a concatenation of
  // many subtensors that the fusion group produces
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<PartitionDesc> concat_desc_;

  // Flag supplied at construction; presumably marks kernels that use random
  // number generation (TODO confirm against the code generator).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool has_random_;
};

} // namespace fuser
} // namespace jit
} // namespace torch