File: CompileTimeFunctionPointer.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (56 lines) | stat: -rw-r--r-- 1,677 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#pragma once

#include <c10/util/TypeTraits.h>
#include <type_traits>

namespace c10 {

/**
 * Wraps a function pointer in a C++ type, so that the pointer can be
 * handed to a template as a type argument rather than as a runtime value.
 *
 * Code inside the template calls the wrapped function through
 * `func_ptr()`. Because the pointer is a template argument, it is known
 * at compile time, which allows the compiler to inline the call.
 *
 * Example 1 (the wrapper is passed as a type):
 *  int add(int a, int b) {return a + b;}
 *  using Add = TORCH_FN_TYPE(add);
 *  template<class Func> struct Executor {
 *    int execute(int a, int b) {
 *      return Func::func_ptr()(a, b);
 *    }
 *  };
 *  Executor<Add> executor;
 *  EXPECT_EQ(3, executor.execute(1, 2));
 *
 * Example 2 (the wrapper is passed as a value and its type is deduced):
 *  int add(int a, int b) {return a + b;}
 *  template<class Func> int execute(Func, int a, int b) {
 *    return Func::func_ptr()(a, b);
 *  }
 *  EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
 *
 * Template parameters:
 *  - Fn:     the function type being wrapped (the type itself, not a
 *            pointer to it).
 *  - fn_ptr: pointer to the wrapped function.
 */
template <class Fn, Fn* fn_ptr>
struct CompileTimeFunctionPointer final {
  // The wrapped function type.
  using FuncType = Fn;

  // Only types accepted by guts::is_function_type may be wrapped.
  static_assert(
      guts::is_function_type<FuncType>::value,
      "TORCH_FN can only wrap function types.");

  // Returns the wrapped function pointer; usable in constant expressions.
  static constexpr FuncType* func_ptr() {
    return fn_ptr;
  }
};

/**
 * Type trait that reports whether a type is a specialization of
 * CompileTimeFunctionPointer. It derives from std::true_type for such
 * specializations and from std::false_type for every other type.
 */
// Primary template: any type that the specialization below does not match.
template <typename Candidate>
struct is_compile_time_function_pointer : std::false_type {};

// Partial specialization: matches CompileTimeFunctionPointer<Fn, fn_ptr>.
template <typename Fn, Fn* fn_ptr>
struct is_compile_time_function_pointer<CompileTimeFunctionPointer<Fn, fn_ptr>>
    : std::true_type {};

} // namespace c10

// TORCH_FN_TYPE(func) expands to the ::c10::CompileTimeFunctionPointer
// specialization for `func`, e.g. `using Add = TORCH_FN_TYPE(add);`.
// The expansion is a type, not a value.
//  - First template argument: decltype(func) with any reference and then
//    any pointer stripped, so `func` may be spelled as a function name or
//    as an expression of pointer-to-function type.
//  - Second template argument: `func` itself, used as the function pointer.
#define TORCH_FN_TYPE(func)                                           \
  ::c10::CompileTimeFunctionPointer<                                  \
      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
      func>
// TORCH_FN(func) expands to `TORCH_FN_TYPE(func)()`: a value-initialized
// temporary of the wrapper type, suitable for passing as a function
// argument so the wrapper type can be deduced, e.g.
// `execute(TORCH_FN(add), 1, 2)`.
// NOTE(review): because the expansion has the form `Type()`, a declaration
// such as `X x(TORCH_FN(f));` would parse as a function declaration --
// confirm callers only use this macro in expression contexts.
#define TORCH_FN(func) TORCH_FN_TYPE(func)()