1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
|
// In-memory description of all ATen ops, similar to the Caffe2 schema.
// Once C10 exists this can be removed or stubbed out, but we need
// it now to implement correct semantic checking for script.
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/library.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/symbol.h>
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
struct Node;
using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::Symbol;
using OperationCreator = Operation (*)(const Node*);
/*
* Note: JIT relies on Operator instances having static lifetime because,
* for example, it stores a non-owning FunctionSchema* pointer in the Node
* class, which points to the function schema stored in the Operator instance.
* Also, jit::Operator is meant to store more operator-related information,
* such as symbolic derivatives, which also requires instances to have static
* lifetime so that changes to symbolic derivatives are remembered.
*
* Currently, the JIT operator library contains a jit::Operator instance
* with a wrapper for each c10 operator. The c10 operator library registers
* those wrappers using listeners in register_c10_ops.cpp.
* TODO Instead of doing it this way, we should only have pure-jit ops in
* the jit library but have the JIT operator lookup look into the c10 library
* too.
*/
// An Operator is a thin wrapper around either a pure JIT operator (e.g. prim
// ops) or a c10 operator, allowing some common operations and abstracting away
// the concrete operator nature.
//
// The two flavours are stored in a c10::either: the left alternative wraps a
// c10::OperatorHandle, the right alternative is a JIT-only operator that
// carries its own schema (which may still be an unparsed string).
struct TORCH_API Operator {
 private:
  // A c10 operator: the dispatcher handle plus the Operation handed out by
  // getOperation().
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct C10Operator final {
    c10::OperatorHandle handle_;
    Operation op_;
  };

  // A schema that is still in string form; schema() parses it on first use.
  struct UnparsedFunctionSchema final {
    std::string schema_string_;
    mutable c10::optional<c10::AliasAnalysisKind> alias_analysis_;
  };

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct JitOnlyOperator final {
    // The only valid transition for schema_ is from right->left, i.e.
    // when the schema gets parsed.
    mutable c10::either<FunctionSchema, UnparsedFunctionSchema> schema_;

    // Left: a ready-made Operation. Right: a factory that builds the
    // Operation from a Node (invoked by getOperation()).
    c10::either<Operation, OperationCreator> op_;
  };

 public:
  // Wraps a c10 operator handle together with the Operation to run for it.
  Operator(c10::OperatorHandle opHandle, Operation operation)
      : op_(c10::make_left<C10Operator, JitOnlyOperator>(
            C10Operator{std::move(opHandle), std::move(operation)})) {}

  // JIT-only operator described by a schema string. The string is stored
  // unparsed and only parsed when schema() is first called.
  Operator(
      std::string schema,
      Operation op,
      c10::AliasAnalysisKind alias_analysis)
      : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
            c10::make_right<FunctionSchema, UnparsedFunctionSchema>(
                UnparsedFunctionSchema{std::move(schema), alias_analysis}),
            c10::make_left<Operation, OperationCreator>(std::move(op))})) {}

  // JIT-only operator whose schema is assembled directly from its parts
  // (no string parsing). The schema is neither vararg nor varret.
  Operator(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      Operation op,
      c10::AliasAnalysisKind alias_analysis)
      : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
            c10::make_left<FunctionSchema, UnparsedFunctionSchema>(
                // The by-value parameters are sinks: move them along instead
                // of copying two strings and two vectors.
                varArgSchemaWithName(
                    std::move(name),
                    std::move(overload_name),
                    std::move(arguments),
                    std::move(returns),
                    alias_analysis)),
            c10::make_left<Operation, OperationCreator>(std::move(op))})) {}

  // JIT-only operator described by a schema string whose Operation is
  // produced by `op_creator` (see getOperation()).
  Operator(
      std::string schema,
      OperationCreator op_creator,
      c10::AliasAnalysisKind alias_analysis)
      : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
            c10::make_right<FunctionSchema, UnparsedFunctionSchema>(
                UnparsedFunctionSchema{std::move(schema), alias_analysis}),
            c10::make_right<Operation, OperationCreator>(op_creator)})) {}

  // Helper constructor to register `op` to run for _every_ IR Node where
  // n.kind() == name, regardless of arguments.
  // This is accomplished by marking the schema varargs and having no required
  // arguments.
  Operator(
      Symbol name,
      OperationCreator op_creator,
      c10::AliasAnalysisKind alias_analysis)
      : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
            c10::make_left<FunctionSchema, UnparsedFunctionSchema>(
                varArgSchemaWithName(name, alias_analysis)),
            c10::make_right<Operation, OperationCreator>(op_creator)})) {}

  // Returns the Operation for this operator. If the operator was registered
  // with an OperationCreator, the creator is called with `node` (nullptr by
  // default) on every call to build the Operation.
  Operation getOperation(const Node* node = nullptr) const {
    return op_.fold<Operation>(
        [](const C10Operator& op) { return op.op_; },
        [node](const JitOnlyOperator& op) {
          return op.op_.fold<Operation>(
              [](const Operation& op) { return op; },
              [node](const OperationCreator& op_creator) {
                return op_creator(node);
              });
        });
  }

  // Returns an Operation that calls the c10 operator boxed for dispatch key
  // `dk`. Fails a TORCH_CHECK for JIT-only operators.
  Operation getOperationForDispatchKey(c10::DispatchKey dk) const {
    // TODO: some sort of caching mechanism?
    return op_.fold<Operation>(
        [dk](const C10Operator& op) {
          // Only the handle is needed by the returned closure, so capture just
          // that rather than a copy of the whole C10Operator (which would
          // also copy its Operation).
          return [handle = op.handle_, dk](Stack& stack) {
            handle.callBoxedForDispatchKey(dk, stack);
          };
        },
        [](const JitOnlyOperator& op) {
          TORCH_CHECK(
              false,
              "calling a JIT operator for dispatch key is not supported");
          return nullptr;
        });
  }

  // Returns the operator's schema. For c10 operators this is the handle's
  // schema; for JIT-only operators created from a string, the string is
  // parsed on the first call and the parsed schema is stored in its place.
  // NOTE(review): the lazy parse mutates `schema_` from a const method and no
  // synchronization is visible here - confirm callers do not race on it.
  const FunctionSchema& schema() const {
    return op_.fold<const FunctionSchema&>(
        [](const C10Operator& op) -> const FunctionSchema& {
          return op.handle_.schema();
        },
        [](const JitOnlyOperator& op) -> const FunctionSchema& {
          // we lazily parse schema initialized from strings so that
          // we do less work during static operator registration
          if (op.schema_.is_right()) {
            auto& unmaterializedSchema = op.schema_.right();
            FunctionSchema schema =
                parseSchema(unmaterializedSchema.schema_string_);
            if (unmaterializedSchema.alias_analysis_.has_value()) {
              // TODO What if it gets set later?
              schema.setAliasAnalysis(*unmaterializedSchema.alias_analysis_);
            }
            op.schema_ = c10::make_left<FunctionSchema, UnparsedFunctionSchema>(
                std::move(schema));
          }
          return op.schema_.left();
        });
  }

  // Returns the tags of the underlying c10 operator, or an empty list for
  // JIT-only operators.
  c10::ArrayRef<at::Tag> getTags() const {
    return op_.fold<c10::ArrayRef<at::Tag>>(
        [](const C10Operator& op) { return op.handle_.getTags(); },
        [](const JitOnlyOperator& op) {
          // Returns empty list of tags for JitOnlyOperators since it
          // doesn't save c10::OperatorHandle
          return c10::ArrayRef<at::Tag>();
        });
  }

  // True when this Operator wraps a c10 operator handle.
  bool isC10Op() const {
    return op_.is_left();
  }

  // Returns the alias analysis kind recorded on the schema. Fails a
  // TORCH_CHECK if the schema carries alias annotations but the kind is not
  // FROM_SCHEMA.
  c10::AliasAnalysisKind aliasAnalysisKind() const {
    const FunctionSchema& schemaRef = schema();
    c10::AliasAnalysisKind alias_analysis = schemaRef.aliasAnalysis();

    TORCH_CHECK(
        alias_analysis == AliasAnalysisKind::FROM_SCHEMA ||
            !schemaRef.hasAnyAliasInfo(),
        "In operator registration: Tried to register operator ",
        schemaRef,
        " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
    return alias_analysis;
  }

  // True for c10 operators and for JIT-only operators that hold a concrete
  // Operation; false when only an OperationCreator is stored.
  bool hasOperation() const {
    return op_.fold<bool>(
        [](const C10Operator&) { return true; },
        [](const JitOnlyOperator& op) { return op.op_.is_left(); });
  }

 private:
  // Builds a schema named `name` with no declared arguments or returns that
  // is both vararg and varret.
  static FunctionSchema varArgSchemaWithName(
      Symbol name,
      AliasAnalysisKind alias_analysis) {
    auto result = FunctionSchema(
        name,
        "",
        {},
        {},
        /*is_vararg*/ true,
        /*is_varret*/ true);
    result.setAliasAnalysis(alias_analysis);
    return result;
  }

  // Builds a schema from explicit name, overload name, arguments and returns.
  // Despite the function name, this schema is neither vararg nor varret.
  static FunctionSchema varArgSchemaWithName(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      AliasAnalysisKind alias_analysis) {
    // The parameters are owned copies; move them into the schema.
    auto result = FunctionSchema(
        std::move(name),
        std::move(overload_name),
        std::move(arguments),
        std::move(returns),
        /*is_vararg*/ false,
        /*is_varret*/ false);
    result.setAliasAnalysis(alias_analysis);
    return result;
  }

  // Left: c10 operator. Right: JIT-only operator.
  c10::either<C10Operator, JitOnlyOperator> op_;
};
// Produces a string form of `schema`.
// NOTE(review): presumably a canonical spelling used to compare or look up
// schemas - confirm against the definition.
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
// NOTE(review): returns the vector by value, unlike getAllOperatorsFor below,
// which returns a const reference.
TORCH_API const std::vector<std::shared_ptr<Operator>> getAllOperators();
// Returns a reference to the list of operators for the symbol `name`.
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
Symbol name);
// Given an operator name that includes its overload name, find the specific
// operator it refers to. May return nullptr if no such operator exists.
TORCH_API std::shared_ptr<Operator> findOperatorFor(
const c10::OperatorName& full_name);
// NOTE(review): presumably returns symbols whose names resemble `input_op`
// (e.g. for suggestions in error messages) - confirm against the definition.
TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);
// Registers `op`; the operator is taken by rvalue reference.
TORCH_API void registerOperator(Operator&& op);
// NOTE(review): presumably removes the operator matching `schema` from the
// registry - confirm against the definition.
TORCH_API void deregisterOperator(const FunctionSchema& schema);
// XXX: this function is meant to be used with string literals only!
TORCH_API std::shared_ptr<Operator> getOperatorForLiteral(
const char* signature);
// Ensure the thing that registers c10 ops is defined.
// Otherwise, our registry will not have c10 ops. You can run into this
// scenario if you're querying registered ops during static init.
//
// This fn is defined in register_c10_ops.cpp
TORCH_API void ensure_c10_registerer_defined();
// Used to assert that unschematized operators have an analysis method written
TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
// Factory that builds an Operator from a schema string and returns it wrapped
// in an optional. This overload takes a plain C string and always returns an
// engaged optional. Further overloads take a torch::detail::SelectiveStr whose
// bool template argument selects, at compile time, whether an operator is
// generated at all (selective op registration based on the schema string).
template <typename Func>
c10::optional<Operator> OperatorGenerator(
    const char* schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  // Own the schema text, then hand everything over to the Operator.
  std::string schema_text(schema_str);
  Operator generated(
      std::move(schema_text), std::forward<Func>(op), alias_analysis);
  return c10::optional<Operator>(std::move(generated));
}
// SelectiveStr<true> overload: converts `schema_str` to a C string and
// forwards to the `const char*` overload, so an Operator is generated.
template <typename Func>
c10::optional<Operator> OperatorGenerator(
    torch::detail::SelectiveStr<true> schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  const char* const raw_schema = static_cast<const char*>(schema_str);
  return OperatorGenerator(raw_schema, std::forward<Func>(op), alias_analysis);
}
// SelectiveStr<false> overload: ignores its arguments and returns an empty
// optional, so no Operator is constructed.
template <typename Func>
c10::optional<Operator> OperatorGenerator(
    torch::detail::SelectiveStr<false> schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  c10::optional<Operator> no_operator = c10::nullopt;
  return no_operator;
}
// Factory that builds an Operator from explicit schema parts (name, overload
// name, arguments, returns) instead of a schema string, and returns it
// wrapped in an engaged optional.
//
// The schema parts are taken by value and moved into the Operator. They are
// deliberately not `const`: a const by-value parameter cannot be moved from
// and would force a copy of both strings and both vectors. Top-level const is
// not part of the function type, so callers are unaffected.
template <typename Func>
c10::optional<Operator> OperatorGenerator(
    std::string name,
    std::string overload_name,
    std::vector<c10::Argument> arguments,
    std::vector<c10::Argument> returns,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  return c10::optional<Operator>(Operator(
      std::move(name),
      std::move(overload_name),
      std::move(arguments),
      std::move(returns),
      std::forward<Func>(op),
      alias_analysis));
}
} // namespace jit
} // namespace torch
|