1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
|
import torch
import argparse
from common import SubTensor, WithTorchFunction, SubWithTorchFunction # noqa: F401
# Default number of torch.add calls; the CLI's --nreps option uses this as its
# default value.
# Expose torch.tensor under the name "Tensor" so the plain, non-subclassed
# baseline can be selected by name from this module's globals, alongside the
# classes imported from `common`.
Tensor = torch.tensor
NUM_REPEATS = 1_000_000
if __name__ == "__main__":
    # Call torch.add on two one-element tensors of a chosen class, repeatedly.
    # The script records no timings itself; presumably it is meant to be run
    # under an external timer/profiler — TODO confirm against the benchmark docs.

    # Names accepted for TensorClass. Each is resolved from this module's
    # globals: "Tensor" is the torch.tensor alias and the rest are imported from
    # `common`. Restricting the argument to these turns a typo into a clear
    # argparse error instead of a KeyError traceback, and stops unrelated
    # globals (e.g. "parser", "torch") from being looked up and called.
    tensor_class_names = (
        "Tensor",
        "SubTensor",
        "WithTorchFunction",
        "SubWithTorchFunction",
    )

    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=tensor_class_names,
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args()
    # range() with a negative count silently runs zero iterations; reject it so
    # a mistyped argument does not masquerade as a (very fast) successful run.
    if args.nreps < 0:
        parser.error("--nreps must be a non-negative integer")

    TensorClass = globals()[args.tensor_class]
    # Rebind the module-level default so NUM_REPEATS is the effective count.
    NUM_REPEATS = args.nreps

    t1 = TensorClass([1.0])
    t2 = TensorClass([2.0])

    # The loop itself is the workload; each result is discarded.
    for _ in range(NUM_REPEATS):
        torch.add(t1, t2)
|