1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
|
# Owner(s): ["module: onnx"]
import torch
import torch.nn as nn
class DummyNet(nn.Module):
    """Small feature stack: LeakyReLU -> BatchNorm2d(3) -> AvgPool2d, then flatten.

    ``forward`` returns a 1-D tensor holding every element of the pooled
    feature map.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # NOTE(review): ``num_classes`` is accepted but never used here; it is
        # kept so existing constructor calls keep working.
        layers = [
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(3),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.features(x)
        # Flatten via view(-1, 1) followed by squeeze(1); this exact op
        # sequence is kept as-is (it ends up in any traced/exported graph).
        column = pooled.view(-1, 1)
        return column.squeeze(1)
class ConcatNet(nn.Module):
    """Concatenate a sequence of tensors along dimension 1."""

    # No custom __init__: nn.Module's constructor is all that is needed.

    def forward(self, inputs):
        return torch.cat(inputs, dim=1)
class PermuteNet(nn.Module):
    """Reorder the dimensions of a 4-D input from (0, 1, 2, 3) to (2, 3, 0, 1)."""

    # No custom __init__: nn.Module's constructor is all that is needed.

    def forward(self, input):
        # The parameter name ``input`` shadows the builtin but is kept for
        # callers that pass it by keyword.
        new_order = (2, 3, 0, 1)
        return input.permute(*new_order)
class PReluNet(nn.Module):
    """Apply a PReLU with one learnable slope per channel (3 channels)."""

    def __init__(self):
        super().__init__()
        # Wrapped in a Sequential so the parameter lives at "features.0.weight".
        self.features = nn.Sequential(nn.PReLU(3))

    def forward(self, x):
        return self.features(x)
class FakeQuantNet(nn.Module):
    """Wrap a default ``FakeQuantize`` module whose observer is switched off.

    With the observer disabled, the module's quantization parameters are not
    updated by ``forward``; inputs are only passed through the fake-quant op.
    """

    def __init__(self):
        super().__init__()
        quantizer = torch.ao.quantization.FakeQuantize()
        quantizer.disable_observer()
        self.fake_quant = quantizer

    def forward(self, x):
        return self.fake_quant(x)
|