1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
"""Define some common setup blocks which benchmarks can reuse."""
import enum
from core.api import GroupedSetup
from core.utils import parse_stmts
# Minimal setups that only create a single all-ones tensor ``x`` of rank 2, 3
# or 4 (every dimension has size 4). For each GroupedSetup the Python
# statement is passed first and the equivalent C++ statement second.
_TRIVIAL_2D = GroupedSetup(
    r"x = torch.ones((4, 4))",
    r"auto x = torch::ones({4, 4});",
)

_TRIVIAL_3D = GroupedSetup(
    r"x = torch.ones((4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4});",
)

_TRIVIAL_4D = GroupedSetup(
    r"x = torch.ones((4, 4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4, 4});",
)
# Setup for the ``TRAINING`` benchmarks: inputs ``x`` and ``y`` plus weights
# ``w0``, ``w1`` (shape ``(1,)``) and ``w2`` (shape ``(2,)``), all of which
# have ``requires_grad`` enabled, written side by side in Python and C++.
#
# The table below is passed to ``parse_stmts`` and the result is unpacked
# into ``GroupedSetup``. The dashed separator row fixes each column at 40
# characters, with columns joined by " | ", so the padding is significant:
# keep every ``|`` aligned with the separator when editing a row.
# NOTE(review): assumes ``parse_stmts`` splits rows by column position (as
# the fixed-width separator suggests) -- confirm against ``core.utils``.
_TRAINING = GroupedSetup(*parse_stmts(
    r"""
Python                                   | C++
---------------------------------------- | ----------------------------------------
# Inputs                                 | // Inputs
x = torch.ones((1,))                     | auto x = torch::ones({1});
y = torch.ones((1,))                     | auto y = torch::ones({1});
                                         |
# Weights                                | // Weights
w0 = torch.ones(                         | auto w0 = torch::ones({1});
    (1,), requires_grad=True)            | w0.set_requires_grad(true);
w1 = torch.ones(                         | auto w1 = torch::ones({1});
    (1,), requires_grad=True)            | w1.set_requires_grad(true);
w2 = torch.ones(                         | auto w2 = torch::ones({2});
    (2,), requires_grad=True)            | w2.set_requires_grad(true);
"""
))
class Setup(enum.Enum):
    """Named setup blocks that benchmark definitions can reuse.

    Each member's value is one of the ``GroupedSetup`` objects defined in this
    module, pairing Python setup code with its C++ counterpart.
    """

    # A single all-ones tensor ``x`` of rank 2, 3 and 4 respectively.
    TRIVIAL_2D = _TRIVIAL_2D
    TRIVIAL_3D = _TRIVIAL_3D
    TRIVIAL_4D = _TRIVIAL_4D

    # Inputs ``x``/``y`` plus weight tensors with ``requires_grad`` enabled.
    TRAINING = _TRAINING
|