# Owner(s): ["module: inductor"]
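"""
Tests for the inductor `efficient_conv_bn_eval_fx_passes` option, which
rewrites eval-mode conv -> batch-norm pairs into a cheaper fused computation.
Each test module below declares `expected_optimization_count`, the number of
rewrites the pass should report through
`counters["inductor"]["efficient_conv_bn_eval"]`.
"""
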
import copy
import importlib
import itertools
import os
import sys

import torch
from torch import nn

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU

importlib.import_module("functorch")
importlib.import_module("filelock")

from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    copy_tests,
)
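

# For reference: with batch norm in eval mode, bn(conv(x)) equals a single
# convolution whose weight is conv's weight scaled per output channel by
# gamma / sqrt(running_var + eps), with a correspondingly shifted bias. The
# helper below is a minimal illustrative sketch of that identity for the
# Conv2d/BatchNorm2d case only; it is hypothetical, unused by the tests, and
# not how the FX pass itself is implemented (the pass rewrites the graph so
# that both modules' parameters still receive gradients).
@torch.no_grad()
def _reference_fold_bn_into_conv2d(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    folded = copy.deepcopy(conv)
    # bn(y) = scale * (y - running_mean) + beta, with the scale below.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # Scale each output channel of the conv weight (out, in, kh, kw).
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = (
        conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    )
    folded.bias = nn.Parameter((conv_bias - bn.running_mean) * scale + bn.bias)
    return folded

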
class ConvOp(nn.Module):
    # Number of conv-bn pairs the pass is expected to rewrite, checked against
    # the "efficient_conv_bn_eval" counter in the test below.
    expected_optimization_count = 1

    def __init__(
        self,
        conv_class,
        bn_class,
        use_bias,
        in_channels,
        out_channels,
        device,
        **kwargs,
    ):
        super().__init__()
        self.conv = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to(
            device
        )
        self.bn = bn_class(out_channels).to(device)

    def forward(self, x):
        x = self.conv(x)
        return self.bn(x)


class MultiUserConvOp(nn.Module):
    # One rewrite per optimizable pair in forward(): conv1-bn1, the outer
    # conv2-bn2 call, and the first conv3-bn3 call.
    expected_optimization_count = 3

    def __init__(
        self,
        conv_class,
        bn_class,
        use_bias,
        in_channels,
        out_channels,
        device,
        **kwargs,
    ):
        super().__init__()
        self.conv1 = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to(
            device
        )
        self.bn1 = bn_class(out_channels).to(device)
        self.conv2 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to(
            device
        )
        self.bn2 = bn_class(out_channels).to(device)
        self.conv3 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to(
            device
        )
        self.bn3 = bn_class(out_channels).to(device)

    def forward(self, x):
        # This conv-bn pair can use the efficient_conv_bn_eval feature.
        x = self.bn1(self.conv1(input=x))
        # Only the outer `self.conv2` call feeds `self.bn2`; the nested call's
        # output feeds another conv, so it cannot use the
        # efficient_conv_bn_eval feature.
        x = self.bn2(input=self.conv2(self.conv2(x)))
        # Only the first `self.bn3` call consumes a conv output directly, so
        # only it can use the efficient_conv_bn_eval feature; the second call
        # tests multiple users of one computation node.
        x = self.bn3(input=self.conv3(input=x))
        x = self.bn3(x) + x
        return x
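

# A minimal usage sketch of the feature under test, using only helpers already
# imported in this file (`model` and `inp` are placeholders):
#
#     with inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True}):
#         compiled = torch.compile(model.eval())
#         out = compiled(inp)  # conv-bn pairs run via the efficient rewrite
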

class EfficientConvBNEvalTemplate(TestCase):
    @inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
    def test_basic(self):
        def test_conv_bn_eval(
            test_class, use_bias, module, sync_bn, decompose_nn_module
        ):
            from functorch import make_fx
            from torch._dispatch.python import enable_python_dispatcher

            kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {}
            mod_eager = test_class(
                module[0],
                module[1],
                use_bias,
                3,
                32,
                self.device,
                **kwargs,
            ).eval()
            # Keep an independent copy so the optimized module can be trained
            # and compared against the eager one.
            mod_optimized = copy.deepcopy(mod_eager)
            if sync_bn:
                mod_eager = nn.SyncBatchNorm.convert_sync_batchnorm(mod_eager).eval()
                mod_optimized = nn.SyncBatchNorm.convert_sync_batchnorm(
                    mod_optimized
                ).eval()
            torch._dynamo.reset()

            inps = [4, 3]
            # Conv shapes go from big to small, while ConvTranspose shapes go
            # from small to big, so pick the spatial size accordingly.
            spatial_d = (
                4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96
            )
            if module[0] == nn.Conv1d or module[0] == nn.ConvTranspose1d:
                inps += [spatial_d] * 1
            if module[0] == nn.Conv2d or module[0] == nn.ConvTranspose2d:
                inps += [spatial_d] * 2
            if module[0] == nn.Conv3d or module[0] == nn.ConvTranspose3d:
                inps += [spatial_d] * 3
            inp = torch.rand(inps).to(self.device)

            if decompose_nn_module:
                with enable_python_dispatcher():
                    mod_optimized = make_fx(mod_optimized, pre_dispatch=True)(inp)
            mod_optimized = torch.compile(mod_optimized)

            original_value = counters["inductor"]["efficient_conv_bn_eval"]

            optim_eager = torch.optim.SGD(mod_eager.parameters(), lr=1e-3)
            optim_optimized = torch.optim.SGD(mod_optimized.parameters(), lr=1e-3)
            optim_eager.zero_grad()
            optim_optimized.zero_grad()

            # test forward
            out_eager = mod_eager(inp)
            out_optimized = mod_optimized(inp)
            self.assertEqual(out_optimized, out_eager, atol=3e-04, rtol=1e-5)

            out_eager.mean().backward()
            out_optimized.mean().backward()
            optim_eager.step()
            optim_optimized.step()

            # test backward (by running forward again after one training
            # iteration: the outputs only still match if the gradients and
            # parameter updates matched too)
            inp_bw = torch.rand_like(inp)
            out_eager_bw = mod_eager(inp_bw)
            out_optimized_bw = mod_optimized(inp_bw)
            self.assertEqual(out_eager_bw, out_optimized_bw, atol=3e-04, rtol=1e-5)

            # Every expected conv-bn rewrite should have fired.
            current_value = counters["inductor"]["efficient_conv_bn_eval"]
            self.assertEqual(
                current_value - original_value, test_class.expected_optimization_count
            )

        conv_bias = [True, False]
        modules = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
            (nn.ConvTranspose1d, nn.BatchNorm1d),
            (nn.ConvTranspose2d, nn.BatchNorm2d),
            (nn.ConvTranspose3d, nn.BatchNorm3d),
        ]
        test_classes = [ConvOp, MultiUserConvOp]
        sync_bns = [False, True]
        decompose_nn_modules = [False, True]
        for (
            test_class,
            use_bias,
            module,
            sync_bn,
            decompose_nn_module,
        ) in itertools.product(
            test_classes,
            conv_bias,
            modules,
            sync_bns,
            decompose_nn_modules,
        ):
            test_conv_bn_eval(
                test_class, use_bias, module, sync_bn, decompose_nn_module
            )


if HAS_CPU and not torch.backends.mps.is_available():

    class EfficientConvBNEvalCpuTests(TestCase):
        device = "cpu"

    copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCpuTests, "cpu")

if HAS_GPU and not TEST_WITH_ASAN:

    class EfficientConvBNEvalGpuTests(TestCase):
        device = GPU_TYPE

    copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalGpuTests, GPU_TYPE)

# The template has no `device` of its own; delete it so unittest does not
# collect and run it directly.
del EfficientConvBNEvalTemplate

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")