File: arith.h

#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

class Val;

/*
 * The operations defined in this header are intended as user-facing functions.
 * Generally, users should not directly instantiate temporary TensorViews;
 * they should instead use the functions below, which automatically create the
 * IR nodes and return a resulting TensorView with correctly tracked shapes.
 */

namespace torch {
namespace jit {
namespace fuser {

// Inserts a cast op to dtype and returns the new resulting val
TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1);
TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1);
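
// Usage sketch (illustrative, not part of this header): assuming "tv0" is a
// TensorView* obtained elsewhere (e.g. a fusion input), a cast could look like:
//
//   TensorView* tv0_float = castOp(DataType::Float, tv0);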

// Performs the unary op of the given type and returns the output
TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1);
TORCH_CUDA_API TensorView* unaryOp(UnaryOpType type, TensorView* v1);
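
// Usage sketch (illustrative): the generic entry point takes an explicit
// UnaryOpType; assuming "tv0" is a TensorView* defined elsewhere:
//
//   TensorView* tv0_neg = unaryOp(UnaryOpType::Neg, tv0);  // cf. neg() below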

// Performs the binary op of the given type on v1 and v2 and returns a
// type-promoted output. Mod, CeilDiv, and LT are considered Int-only output
// operations for now.
TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2);
TORCH_CUDA_API TensorView* binaryOp(BinaryOpType type, TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* binaryOp(BinaryOpType type, Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    TensorView* v2);
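
// Usage sketch (illustrative; "tv0" is an assumed TensorView* and "s" an
// assumed scalar Val*, neither defined in this header):
//
//   TensorView* tv_prod = binaryOp(BinaryOpType::Mul, tv0, s);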

// Performs a reduction of v1 across the given axes. init is the initial value
// of the reduction, and the reduction operation is defined by the BinaryOpType.
TORCH_CUDA_API TensorView* reductionOp(
    BinaryOpType reduction_op_type,
    const std::vector<int>& axes,
    Val* init,
    TensorView* v1);
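
// Usage sketch (illustrative): reducing an assumed 2-D "tv0" over axis 0 with
// addition, where "zero" is an assumed scalar Val* holding the initial value:
//
//   TensorView* tv_red = reductionOp(BinaryOpType::Add, {0}, zero, tv0);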

// UNARY OPERATIONS
TORCH_CUDA_API Val* neg(Val* v);
TORCH_CUDA_API TensorView* neg(TensorView* v);

// Broadcasts inp based on the bool vector. The size of the broadcast bool
// vector should be the number of dims desired in the broadcasted tensor. An
// entry should be true if the corresponding output dim is a broadcast dim and
// false if it is not. The number of false entries must match the number of
// input dims.
TORCH_CUDA_API TensorView* broadcast(
    TensorView* inp,
    const std::vector<bool>& is_broadcast_dim);
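
// Usage sketch (illustrative): turning an assumed 1-D "tv0" of shape [i0] into
// a 2-D output of shape [b, i0], where the leading dim is the new broadcast dim:
//
//   TensorView* tv_b = broadcast(tv0, {true, false});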

// BINARY OPERATIONS
// add
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* add(TensorView* v1, TensorView* v2);
// sub
TORCH_CUDA_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* sub(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* sub(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* sub(TensorView* v1, TensorView* v2);
// mul
TORCH_CUDA_API Val* mul(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* mul(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* mul(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* mul(TensorView* v1, TensorView* v2);
// div
TORCH_CUDA_API Val* div(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* div(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* div(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* div(TensorView* v1, TensorView* v2);
// mod
TORCH_CUDA_API Val* mod(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* mod(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* mod(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* mod(TensorView* v1, TensorView* v2);
// lt
TORCH_CUDA_API Val* lt(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* lt(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* lt(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* lt(TensorView* v1, TensorView* v2);
// eq
TORCH_CUDA_API Val* eq(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* eq(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* eq(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* eq(TensorView* v1, TensorView* v2);
// ceilDiv
TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* ceilDiv(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* ceilDiv(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* ceilDiv(TensorView* v1, TensorView* v2);
// andOp
TORCH_CUDA_API Val* andOp(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* andOp(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* andOp(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* andOp(TensorView* v1, TensorView* v2);
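
// Usage sketch (illustrative): the named helpers above follow the same overload
// pattern as binaryOp; with assumed TensorView* inputs "tv0" and "tv1":
//
//   TensorView* tv_add = add(tv0, tv1);
//   TensorView* tv_cdiv = ceilDiv(tv0, tv1);  // division rounded up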

// REDUCTION OPERATIONS
TORCH_CUDA_API TensorView* sum(
    TensorView* v1,
    const std::vector<int>& reduction_axes);
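
// Usage sketch (illustrative): summing an assumed 3-D "tv0" over its last two
// axes; this is expected to behave like reductionOp with BinaryOpType::Add:
//
//   TensorView* tv_s = sum(tv0, {1, 2});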

// COMPOUND OPERATIONS
// add_alpha
TORCH_CUDA_API Val* add_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_API TensorView* add_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_API TensorView* add_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_API TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* s);
// sub_alpha
TORCH_CUDA_API Val* sub_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_API TensorView* sub_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_API TensorView* sub_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_API TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* s);
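
// Usage sketch (illustrative): the *_alpha variants appear to mirror
// torch.add/torch.sub with an alpha scalar, i.e. roughly v1 + s * v2 and
// v1 - s * v2; "alpha" is an assumed scalar Val*:
//
//   TensorView* tv_aa = add_alpha(tv0, tv1, alpha);
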
// lerp
TORCH_CUDA_API Val* lerp(Val* start, Val* end, Val* weight);
TORCH_CUDA_API TensorView* lerp(TensorView* start, Val* end, Val* weight);
TORCH_CUDA_API TensorView* lerp(Val* start, TensorView* end, Val* weight);
TORCH_CUDA_API TensorView* lerp(Val* start, Val* end, TensorView* weight);
TORCH_CUDA_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    Val* weight);
TORCH_CUDA_API TensorView* lerp(
    TensorView* start,
    Val* end,
    TensorView* weight);
TORCH_CUDA_API TensorView* lerp(
    Val* start,
    TensorView* end,
    TensorView* weight);
TORCH_CUDA_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    TensorView* weight);
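
// Usage sketch (illustrative): lerp is the usual linear interpolation, roughly
// start + weight * (end - start); "w" is an assumed scalar Val*:
//
//   TensorView* tv_l = lerp(tv0, tv1, w);
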
// addcmul
TORCH_CUDA_API Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_API TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_API TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* s);
TORCH_CUDA_API TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* s);
TORCH_CUDA_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    Val* v3,
    Val* s);
TORCH_CUDA_API TensorView* addcmul(
    TensorView* v1,
    Val* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_API TensorView* addcmul(
    Val* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);
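
// Usage sketch (illustrative): addcmul appears to mirror torch.addcmul,
// roughly v1 + s * (v2 * v3); "s" is an assumed scalar Val*:
//
//   TensorView* tv_acm = addcmul(tv0, tv1, tv2, s);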

// TERNARY OPERATIONS
// where
TORCH_CUDA_API Val* where(Val* c, Val* v1, Val* v2);
TORCH_CUDA_API TensorView* where(TensorView* c, Val* v1, Val* v2);
TORCH_CUDA_API TensorView* where(Val* c, TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* where(Val* c, Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* where(TensorView* c, TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* where(TensorView* c, Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* where(Val* c, TensorView* v1, TensorView* v2);
TORCH_CUDA_API TensorView* where(TensorView* c, TensorView* v1, TensorView* v2);
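
// Usage sketch (illustrative): where performs an elementwise select, roughly
// c ? v1 : v2; "pred" is an assumed predicate TensorView* (e.g. from lt/eq):
//
//   TensorView* tv_w = where(pred, tv0, tv1);
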
// threshold
TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value);
TORCH_CUDA_API TensorView* threshold(TensorView* in, Val* thresh, Val* value);
// clamp
TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val);
TORCH_CUDA_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val);
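
// Usage sketch (illustrative): threshold and clamp are expected to mirror
// their torch counterparts; "thr", "val", "lo", and "hi" are assumed scalar
// Val*:
//
//   TensorView* tv_t = threshold(tv0, thr, val);  // values <= thr become val
//   TensorView* tv_c = clamp(tv0, lo, hi);        // limit values to [lo, hi]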

} // namespace fuser
} // namespace jit
} // namespace torch