File: test_pad_mm.py

import os
import sys
import unittest


# Make the shared test_utils helpers in the parent directory importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from expecttest import TestCase

from test_utils import read_file_to_string, run_bash  # type: ignore[import-not-found]


class TestPadMM(TestCase):
    def test_padmm_a100(self) -> None:
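        """Regenerate the learned pad_mm heuristic for A100 and check that the
        emitted _PadMMA100.py matches the expected source below."""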
        run_bash("get_padmm_dataset.sh")
        run_bash("gen_pad_mm_a100.sh")
        file_path = "../../../torch/_inductor/autoheuristic/artifacts/_PadMMA100.py"
        a100_heuristic_generated_code = read_file_to_string(file_path)

        self.assertExpectedInline(
            a100_heuristic_generated_code,
            """\
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicRegression,
)


class PadMMA100(LearnedHeuristicRegression):

    def __init__(self) -> None:
        pass

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        context.context_dict[CHOICE_COL] = choice
        return self.predict(context)

    def get_confidence_threshold(self) -> float:
        return 1.7025303314066

    def get_name(self) -> str:
        return 'pad_mm'

    def predict(self, context: AHContext) -> float:
        if str(context.get_value('choice')) != 'pad':
            if str(context.get_value('using_tf32')) != 'False':
                if context.get_value('m*n') <= 4171264.0:
                    if context.get_value('m*k') <= 3999308.0:
                        return 1.8751469764071178
                    else:
                        if str(context.get_value('n_multiple_32')) != 'True':
                            return 0.9117231355626345
                        else:
                            return 1.1607689608873861
                else:
                    if str(context.get_value('n_multiple_2')) != 'True':
                        if str(context.get_value('using_tf32')) != 'True':
                            return 0.7430382200435992
                        else:
                            return 0.8531269794448678
                    else:
                        if str(context.get_value('k_multiple_2')) != 'True':
                            return 0.7577181972719917
                        else:
                            return 0.8977349440424219
            else:
                if context.get_value('m*n') <= 1299712.0:
                    return 1.1669723418995592
                else:
                    if context.get_value('mat2_stride_1') <= 45217.5:
                        if context.get_value('m*n') <= 55884158.0:
                            return 1.0262769936909601
                        else:
                            return 1.0022677428470845
                    else:
                        if context.get_value('m') <= 18478.0:
                            return 1.1127066261894312
                        else:
                            return 1.0337740659894263
        else:
            if str(context.get_value('mat1_dtype')) != 'torch.float32':
                if str(context.get_value('n_multiple_2')) != 'False':
                    if str(context.get_value('k_multiple_2')) != 'True':
                        if context.get_value('mat1_stride_0') <= 561.0:
                            return 1.2900382135142956
                        else:
                            return 1.5761737616057887
                    else:
                        if context.get_value('num_dims_needs_padding') <= 1.5:
                            return 1.0472263310239422
                        else:
                            return 1.1727673465762514
                else:
                    if context.get_value('k') <= 28238.5:
                        if context.get_value('k/(m*n)') <= 0.00026227018679492176:
                            return 1.6770542505397175
                        else:
                            return 1.3974785435105923
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return 1.3952699800111992
                        else:
                            return 1.5759286511628336
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if context.get_value('m*n') <= 14119424.0:
                        return 0.8875772670422478
                    else:
                        if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                            return 1.1467728924377265
                        else:
                            return 1.215842963532998
                else:
                    if context.get_value('arith_intensity') <= 396.8774871826172:
                        return 0.89940161869551
                    else:
                        if context.get_value('mat2_stride_1') <= 45217.5:
                            return 0.9964328169353532
                        else:
                            return 0.9493479238294826
""",
        )


if __name__ == "__main__":
    unittest.main()
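
Usage note: the expected output above is the same _PadMMA100.py that ships under
torch/_inductor/autoheuristic/artifacts/, so once regenerated it can be queried
directly. A minimal sketch follows; the no-argument AHContext() constructor and
get_value() reading from context_dict are assumptions not shown in the file,
while the feature names, the "pad" choice string, and the expected return value
are taken from the decision tree in the expected output.

    from torch._inductor.autoheuristic.artifacts._PadMMA100 import PadMMA100
    from torch._inductor.autoheuristic.autoheuristic_utils import AHContext

    heuristic = PadMMA100()

    context = AHContext()
    # Hypothetical, hand-filled features covering one branch of the tree above;
    # in real runs these are derived from the actual matmul inputs.
    context.context_dict["mat1_dtype"] = "torch.float16"
    context.context_dict["n_multiple_2"] = "True"
    context.context_dict["k_multiple_2"] = "True"
    context.context_dict["num_dims_needs_padding"] = 1.0

    # get_feedback() records the choice under CHOICE_COL and walks the tree.
    print(heuristic.get_feedback(context, "pad"))  # expected: 1.0472263310239422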