1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
# Sparse benchmarks
# These benchmarks measure sparse matmul performance.
# They exist for comparing the performance of the sparse matrix routines
# `sparse @ vector`, `sparse @ sparse` and `sparse @ dense` across different backends (CPU/CUDA)
# and against other frameworks such as scipy.
import sys
import argparse
import torch
import torch.utils.benchmark as benchmark_utils
from .utils import load_dlmc_dataset
from scipy.sparse import isspmatrix
import os
def scipy_matmul(mat1, mat2):
    """Multiply ``mat1`` by ``mat2`` using ``mat1.dot``.

    When both operands are scipy sparse matrices (per ``isspmatrix``) the
    product is converted with ``.tocoo()``; otherwise the result of ``dot``
    is returned unchanged.
    """
    both_sparse = isspmatrix(mat1) and isspmatrix(mat2)
    product = mat1.dot(mat2)
    return product.tocoo() if both_sparse else product
def matmul_backward(a_dense, b_dense, grad_output):
    """Compute ``a_dense.matmul(b_dense)`` and backpropagate ``grad_output`` through it.

    Returns nothing; gradients are accumulated by autograd on the inputs.
    """
    a_dense.matmul(b_dense).backward(grad_output)
def sparse_matmul_backward(a, b, grad_output):
    """Compute ``torch.sparse.mm(a, b)`` and backpropagate ``grad_output`` through it.

    Returns nothing; gradients are accumulated by autograd on the inputs.
    """
    torch.sparse.mm(a, b).backward(grad_output)
# Maps each benchmarked operation name to the torch callable, written as
# source text, that is formatted into the timed statement for forward runs.
# Key order matters: the first key is the default value of `--operation`.
OPS_MAP = {
    "sparse@sparse": "torch.sparse.mm",
    "sparse@dense": "torch.matmul",
    "sparse@vector": "torch.matmul",
}
# also get the arguments as input from the user using `argparse`
def parse_args():
    """Build the command-line parser for the matmul benchmark.

    Despite its name, this function does not parse anything: it returns the
    configured parser and the caller invokes ``.parse_args()`` on it.

    Returns:
        argparse.ArgumentParser: parser exposing ``--path``, ``--dataset``,
        ``--hidden_size``, ``--backward_test``, ``--operation``,
        ``--with_cuda`` and ``--timer_min_run_time``.
    """
    parser = argparse.ArgumentParser(description='matmul benchmark')
    # Required: the script joins this path with the dataset name, which fails
    # with an opaque TypeError when the option is omitted (value None).
    parser.add_argument('--path', type=str, required=True, help='DLMC dataset path')
    parser.add_argument('--dataset', type=str, default='magnitude_pruning')
    parser.add_argument('--hidden_size', default=2048, type=int)
    parser.add_argument('--backward_test', action="store_true")
    # Restrict to the known operations so a typo is reported by argparse
    # instead of surfacing later as a bare KeyError.
    parser.add_argument(
        '--operation',
        type=str,
        choices=list(OPS_MAP),
        help="|".join(OPS_MAP.keys()),
        default=next(iter(OPS_MAP)),
    )
    parser.add_argument('--with_cuda', action='store_true')
    parser.add_argument('--timer_min_run_time', default=1, type=float)
    return parser
def get_tasks(op, backward_test, device):
    """Return the benchmark tasks for operation ``op``.

    Each task is a ``(test_name, device, sub_label, stmt)`` tuple, where
    ``stmt`` is the source text of the statement to time.

    * Backward mode yields a dense and a sparse backward task.
    * Forward mode yields a dense and a sparse torch task, plus a scipy task
      when ``device == "cpu"``.

    Raises:
        KeyError: if ``op`` is not one of the supported operation names.
    """
    def build_tasks(operation):
        # Labels shared by both modes: the dense baseline and the sparse run.
        dense_label = "torch:" + operation.replace("sparse", "dense")
        sparse_label = "torch:" + operation

        if backward_test:
            name = device + ":matmul-backward"
            return [
                (name, device, dense_label, "matmul_backward(dx, dy, grad_output)"),
                (name, device, sparse_label, "sparse_matmul_backward(x, y, sparse_grad_output)"),
            ]

        name = device + ":matmul-forward"
        torch_fn = OPS_MAP[operation]
        tasks = [
            (name, device, dense_label, "{}(dx, dy)".format(torch_fn)),
            (name, device, sparse_label, "{}(x, y)".format(torch_fn)),
        ]
        # The scipy comparison is only added for CPU runs.
        if device == "cpu":
            tasks.append((name, device, "scipy:" + operation, "scipy_matmul(sx, sy)"))
        return tasks

    supported = ("sparse@sparse", "sparse@dense", "sparse@vector")
    tasks_by_op = {name: build_tasks(name) for name in supported}
    return tasks_by_op[op]
if __name__ == '__main__':
    # Script entry point: parse the options, build one Timer per
    # (sparsity, task, dataset sample), run them, and print a comparison table.
    args = parse_args().parse_args()

    use_cuda = args.with_cuda
    if use_cuda and not torch.cuda.is_available():
        raise RuntimeError("No CUDA available")
    device = 'cuda' if use_cuda else 'cpu'

    dataset_path = os.path.join(args.path, args.dataset)
    tasks = get_tasks(args.operation, args.backward_test, device)

    repeats = 3
    sparsity_levels = [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]

    # Helper functions referenced by name inside the timed statements.
    helper_globals = {
        "scipy_matmul": scipy_matmul,
        "matmul_backward": matmul_backward,
        "sparse_matmul_backward": sparse_matmul_backward,
    }

    timers = []
    for sparsity in sparsity_levels:
        for label, task_device, sub_label, stmt in tasks:
            samples = load_dlmc_dataset(
                dataset_path,
                args.operation,
                args.hidden_size,
                sparsity,
                task_device,
                args.backward_test,
            )
            for variables in samples:
                timers.append(
                    benchmark_utils.Timer(
                        stmt=stmt,
                        # Each sample supplies the variables named in `stmt`.
                        globals={**helper_globals, **variables},
                        label=label,
                        sub_label=sub_label,
                        description=f"{sparsity}",
                        env=task_device,
                    )
                )

    total_runs = len(timers) * repeats
    measurements = []
    # The full timer list is run `repeats` times, one pass after another.
    for run_number, timer in enumerate(timers * repeats, start=1):
        measurement = timer.blocked_autorange(min_run_time=args.timer_min_run_time)
        measurement.metadata = {
            "device": 'cuda' if "cuda" in measurement.task_spec.env else 'cpu'
        }
        measurements.append(measurement)
        # Carriage return keeps the progress counter on a single line.
        print(f"\r{run_number} / {total_runs}", end="", flush=True)
    print()

    comparison = benchmark_utils.Compare(measurements)
    header = "== Results " + "=" * 80
    separator = "/" * 95
    print(header + "\n" + separator + "\n")
    comparison.print()
|