1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
|
#pragma once
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Forward declarations of the schema classes for the PyTorch-domain, version 1
// operators. The class names are produced by
// ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME; the classes are only declared here.
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSumFused8BitRowwise);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims);
// Iterate over schema from ai.onnx.pytorch domain opset 1
class OpSet_PyTorch_ver1 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsSumFused8BitRowwise)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsSum)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsWeightedSum)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, BatchGather)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, DotProduct)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, FCTransposed)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, BatchMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, ExpandDims)>());
}
};
inline void RegisterPyTorchOperatorSetSchema() {
RegisterOpSetSchema<OpSet_PyTorch_ver1>();
}
} // namespace ONNX_NAMESPACE
|