File: fwd_decls.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (129 lines) | stat: -rw-r--r-- 3,034 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once
#include <c10/core/ScalarType.h>
#include <memory>

namespace torch {
namespace jit {
namespace tensorexpr {

/// Handle type used for IR nodes: an alias for std::shared_ptr<Node>.
template <typename Node>
using NodePtr = std::shared_ptr<Node>;

/// Checked cast between node pointer types.
///
/// Thin wrapper over std::dynamic_pointer_cast. The result shares ownership
/// with `x` when the pointed-to object is a `To`; otherwise (including when
/// `x` is null) the result is null.
///
/// `x` is taken by const reference so that calling this helper does not copy
/// the shared_ptr; the only reference-count change is the one made for the
/// returned pointer.
template <typename To, typename From>
NodePtr<To> to(const NodePtr<From>& x) {
  return std::dynamic_pointer_cast<To>(x);
}

/// Unchecked cast between node pointer types.
///
/// Thin wrapper over std::static_pointer_cast: no runtime type check is
/// performed, so the caller must already know that `x` refers to a `To`.
/// A null `x` yields a null result.
///
/// `x` is taken by const reference for the same reason as in `to`.
template <typename To, typename From>
NodePtr<To> static_to(const NodePtr<From>& x) {
  return std::static_pointer_cast<To>(x);
}

/// Creates a new `Node` via std::make_shared, perfectly forwarding `args` to
/// the `Node` constructor, and returns the owning NodePtr.
template <typename Node, typename... Args>
NodePtr<Node> alloc(Args&&... args) {
  return std::make_shared<Node>(std::forward<Args>(args)...);
}

// Forward declarations of Buf, Expr, Stmt and Var, each immediately followed
// by its NodePtr alias. The classes themselves are defined elsewhere.
class Buf;
using BufPtr = NodePtr<Buf>;

class Expr;
using ExprPtr = NodePtr<Expr>;

class Stmt;
using StmtPtr = NodePtr<Stmt>;

class Var;
using VarPtr = NodePtr<Var>;

// Forward declarations of the *Handle classes (defined elsewhere). No NodePtr
// aliases are declared for these in this header.
class BufHandle;
class ExprHandle;
class VarHandle;

// Forward declarations of IR node classes (defined elsewhere), each paired
// with its `<Name>Ptr` NodePtr alias.
class Add;
using AddPtr = NodePtr<Add>;
class And;
using AndPtr = NodePtr<And>;
class BitCast;
using BitCastPtr = NodePtr<BitCast>;
class Broadcast;
using BroadcastPtr = NodePtr<Broadcast>;
class Cast;
using CastPtr = NodePtr<Cast>;
class CompareSelect;
using CompareSelectPtr = NodePtr<CompareSelect>;
class Div;
using DivPtr = NodePtr<Div>;
class IfThenElse;
using IfThenElsePtr = NodePtr<IfThenElse>;
class Intrinsics;
using IntrinsicsPtr = NodePtr<Intrinsics>;
class Let;
using LetPtr = NodePtr<Let>;
class Load;
using LoadPtr = NodePtr<Load>;
class Lshift;
using LshiftPtr = NodePtr<Lshift>;
class Max;
using MaxPtr = NodePtr<Max>;
class MaxTerm;
using MaxTermPtr = NodePtr<MaxTerm>;
class Min;
using MinPtr = NodePtr<Min>;
class MinTerm;
using MinTermPtr = NodePtr<MinTerm>;
class Mod;
using ModPtr = NodePtr<Mod>;
class Mul;
using MulPtr = NodePtr<Mul>;
class Or;
using OrPtr = NodePtr<Or>;
class Polynomial;
using PolynomialPtr = NodePtr<Polynomial>;
class Ramp;
using RampPtr = NodePtr<Ramp>;
class ReduceOp;
using ReduceOpPtr = NodePtr<ReduceOp>;
class RoundOff;
using RoundOffPtr = NodePtr<RoundOff>;
class Rshift;
using RshiftPtr = NodePtr<Rshift>;
class Store;
using StorePtr = NodePtr<Store>;
class Sub;
using SubPtr = NodePtr<Sub>;
class Term;
using TermPtr = NodePtr<Term>;
class Xor;
using XorPtr = NodePtr<Xor>;

// Forward declarations of further IR node classes (defined elsewhere), each
// paired with its `<Name>Ptr` NodePtr alias.
class Allocate;
using AllocatePtr = NodePtr<Allocate>;
class AtomicAdd;
using AtomicAddPtr = NodePtr<AtomicAdd>;
class Block;
using BlockPtr = NodePtr<Block>;
class Cond;
using CondPtr = NodePtr<Cond>;
class ExternalCall;
using ExternalCallPtr = NodePtr<ExternalCall>;
class ExternalCallWithAlloc;
using ExternalCallWithAllocPtr = NodePtr<ExternalCallWithAlloc>;
class For;
using ForPtr = NodePtr<For>;
class Free;
using FreePtr = NodePtr<Free>;
class FreeExt;
using FreeExtPtr = NodePtr<FreeExt>;
class PlacementAllocate;
using PlacementAllocatePtr = NodePtr<PlacementAllocate>;
class SyncThreads;
using SyncThreadsPtr = NodePtr<SyncThreads>;

// IMM_DECLARE(CppType, Prefix) forward-declares class `PrefixImm` and the
// alias `PrefixImmPtr = NodePtr<PrefixImm>`. The first argument is accepted
// but does not appear in the expansion. The macro is undefined again right
// after use so it does not leak out of this header.
// NOTE(review): AT_FORALL_SCALAR_TYPES_AND3 is presumably supplied by
// c10/core/ScalarType.h and applies IMM_DECLARE once per scalar type, with
// Bool, Half and BFloat16 added to the list; confirm against that header.
#define IMM_DECLARE(CppType, Prefix) \
  class Prefix##Imm;                 \
  using Prefix##ImmPtr = NodePtr<Prefix##Imm>;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE

} // namespace tensorexpr
} // namespace jit
} // namespace torch