1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
import os
import sys
import unittest
# Put the directory one level above this file's directory on sys.path.
# NOTE(review): presumably that is where `test_utils` (imported below) lives — confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from expecttest import TestCase
from test_utils import read_file_to_string, run_bash # type: ignore[import-not-found]
class TestPadMM(TestCase):
    """Snapshot test for the generated A100 ``pad_mm`` heuristic source file."""

    def test_padmm_a100(self) -> None:
        """Run the helper scripts, then compare the generated file to the inline snapshot.

        The text read from ``_PadMMA100.py`` must match the expected literal
        passed to ``assertExpectedInline`` below.
        """
        # The scripts run in this fixed order. Judging by their names, the first
        # obtains the dataset that the second consumes (assumption — confirm
        # against the scripts themselves).
        for script in ("get_padmm_dataset.sh", "gen_pad_mm_a100.sh"):
            run_bash(script)

        # Relative path. NOTE(review): presumably resolved against the current
        # working directory, so the test likely has to be started from this
        # file's directory — confirm against read_file_to_string.
        artifact_path = "../../../torch/_inductor/autoheuristic/artifacts/_PadMMA100.py"
        generated_source = read_file_to_string(artifact_path)

        # Keep the expected text as a literal at the call site: expecttest
        # rewrites it in place when run with EXPECTTEST_ACCEPT=1.
        self.assertExpectedInline(
            generated_source,
            """\
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicRegression,
)
class PadMMA100(LearnedHeuristicRegression):
def __init__(self) -> None:
pass
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 166912
and str(metadata.device_capa) == "(8, 0)"
)
def get_feedback(self, context: AHContext, choice: Choice) -> float:
context.context_dict[CHOICE_COL] = choice
return self.predict(context)
def get_confidence_threshold(self) -> float:
return 1.7025303314066
def get_name(self) -> str:
return 'pad_mm'
def predict(self, context: AHContext) -> float:
if str(context.get_value('choice')) != 'pad':
if str(context.get_value('using_tf32')) != 'False':
if context.get_value('m*n') <= 4171264.0:
if context.get_value('m*k') <= 3999308.0:
return 1.8751469764071178
else:
if str(context.get_value('n_multiple_32')) != 'True':
return 0.9117231355626345
else:
return 1.1607689608873861
else:
if str(context.get_value('n_multiple_2')) != 'True':
if str(context.get_value('using_tf32')) != 'True':
return 0.7430382200435992
else:
return 0.8531269794448678
else:
if str(context.get_value('k_multiple_2')) != 'True':
return 0.7577181972719917
else:
return 0.8977349440424219
else:
if context.get_value('m*n') <= 1299712.0:
return 1.1669723418995592
else:
if context.get_value('mat2_stride_1') <= 45217.5:
if context.get_value('m*n') <= 55884158.0:
return 1.0262769936909601
else:
return 1.0022677428470845
else:
if context.get_value('m') <= 18478.0:
return 1.1127066261894312
else:
return 1.0337740659894263
else:
if str(context.get_value('mat1_dtype')) != 'torch.float32':
if str(context.get_value('n_multiple_2')) != 'False':
if str(context.get_value('k_multiple_2')) != 'True':
if context.get_value('mat1_stride_0') <= 561.0:
return 1.2900382135142956
else:
return 1.5761737616057887
else:
if context.get_value('num_dims_needs_padding') <= 1.5:
return 1.0472263310239422
else:
return 1.1727673465762514
else:
if context.get_value('k') <= 28238.5:
if context.get_value('k/(m*n)') <= 0.00026227018679492176:
return 1.6770542505397175
else:
return 1.3974785435105923
else:
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
return 1.3952699800111992
else:
return 1.5759286511628336
else:
if str(context.get_value('using_tf32')) != 'False':
if context.get_value('m*n') <= 14119424.0:
return 0.8875772670422478
else:
if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
return 1.1467728924377265
else:
return 1.215842963532998
else:
if context.get_value('arith_intensity') <= 396.8774871826172:
return 0.89940161869551
else:
if context.get_value('mat2_stride_1') <= 45217.5:
return 0.9964328169353532
else:
return 0.9493479238294826
""",
        )
# Run this module's tests through unittest's command-line entry point when
# the file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|