File: benchmark_base.py

import csv
import gc
import json
import os
from abc import ABC, abstractmethod

from fbscribelogger import make_scribe_logger

import torch._C._instruction_counter as i_counter
import torch._dynamo.config as config
from torch._dynamo.utils import CompileTimeInstructionCounter


scribe_log_torch_benchmark_compile_time = make_scribe_logger(
    "TorchBenchmarkCompileTime",
    """
struct TorchBenchmarkCompileTimeLogEntry {

  # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA.
  4: optional string commit_sha;

  # The unix timestamp in seconds for the Scuba Time Column override
  6: optional i64 time;
  7: optional i64 instruction_count; # Instruction count of compilation step
  8: optional string name; # Benchmark name

  # Commit date (not author date) of the commit in commit_sha, as a timestamp, e.g., 1724208105.  Increasing when the merge bot is used, though not strictly monotonic; duplicates occur when a stack is landed.
  16: optional i64 commit_date;

  # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID.
  17: optional string github_run_id;

  # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT.
  18: optional string github_run_attempt;

  # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED.
  20: optional bool github_ref_protected;

  # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF.
  21: optional string github_ref;

  # The weight of the record according to current sampling rate
  25: optional i64 weight;

  # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, linux.2xlarge).
  26: optional string github_job;

  # The GitHub user who triggered the job.  Derived from GITHUB_TRIGGERING_ACTOR.
  27: optional string github_triggering_actor;

  # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER.
  28: optional string github_run_number_str;
}
""",  # noqa: B950
)
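
# The generated callable accepts the struct's fields above as keyword
# arguments. An illustrative call (the values are invented; `collect_all`
# below logs exactly these two fields):
#
#   scribe_log_torch_benchmark_compile_time(
#       name="add_loop_eager",
#       instruction_count=579834091,
#   )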


class BenchmarkBase(ABC):
    # Measure the total number of instructions spent in _work().
    # Garbage collection is NOT disabled during _work().
    _enable_instruction_count = False

    # Measure the total number of instructions spent in convert_frame.compile_inner.
    # Garbage collection is disabled during _work() to avoid noise.
    _enable_compile_time_instruction_count = False

    # Number of iterations to run when collecting instruction_count or compile_time_instruction_count.
    _num_iterations = 5

    def __init__(
        self,
        category: str,
        device: str,
        backend: str = "",
        mode: str = "",
        dynamic=None,
    ):
        # These attributes are stored individually so that the dashboard can
        # filter on each of them later.
        self._category = category
        self._device = device
        self._backend = backend
        self._mode = mode  # Training or inference
        self._dynamic = dynamic

    def with_iterations(self, value):
        self._num_iterations = value
        return self

    def enable_instruction_count(self):
        self._enable_instruction_count = True
        return self

    def enable_compile_time_instruction_count(self):
        self._enable_compile_time_instruction_count = True
        return self
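
    # The setters above return `self` so calls can be chained fluently, e.g.
    # (illustrative, assuming a concrete subclass `MyBenchmark`):
    #
    #   MyBenchmark().with_iterations(10).enable_compile_time_instruction_count().collect_all()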

    def name(self):
        return ""

    def backend(self):
        return self._backend

    def mode(self):
        return self._mode

    def category(self):
        return self._category

    def device(self):
        return self._device

    def is_dynamic(self):
        return self._dynamic

    def description(self):
        return ""

    @abstractmethod
    def _prepare(self):
        pass

    @abstractmethod
    def _work(self):
        pass

    def _prepare_once(self):  # noqa: B027
        pass

    def _count_instructions(self):
        print(f"collecting instruction count for {self.name()}")
        results = []
        for i in range(self._num_iterations):
            self._prepare()
            counter_id = i_counter.start()
            self._work()
            count = i_counter.end(counter_id)
            print(f"instruction count for iteration {i} is {count}")
            results.append(count)
        return min(results)

    def _count_compile_time_instructions(self):
        gc.disable()

        try:
            print(f"collecting compile time instruction count for {self.name()}")
            config.record_compile_time_instruction_count = True

            results = []
            for i in range(self._num_iterations):
                self._prepare()
                gc.collect()
                # CompileTimeInstructionCounter.record is only called from
                # convert_frame._compile_inner, so this counts only the
                # instructions spent inside compile_inner.
                CompileTimeInstructionCounter.clear()
                self._work()
                count = CompileTimeInstructionCounter.value()
                if count == 0:
                    raise RuntimeError(
                        "compile time instruction count is 0, please check your benchmarks"
                    )
                print(f"compile time instruction count for iteration {i} is {count}")
                results.append(count)

            return min(results)
        finally:
            # Reset the flag in the finally block so it does not stay set if
            # an iteration raises.
            config.record_compile_time_instruction_count = False
            gc.enable()

    def _write_to_json(self, output_dir: str):
        """
        Write the result into JSON format, so that it can be uploaded to the benchmark database
        to be displayed on OSS dashboard. The JSON format is defined at
        https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
        """
        records = []
        for entry in self.results:
            metric_name = entry[1]
            value = entry[2]

            if not metric_name or value is None:
                continue

            records.append(
                {
                    "benchmark": {
                        "name": "pr_time_benchmarks",
                        "mode": self.mode(),
                        "extra_info": {
                            "is_dynamic": self.is_dynamic(),
                            "device": self.device(),
                            "description": self.description(),
                        },
                    },
                    "model": {
                        "name": self.name(),
                        "type": self.category(),
                        "backend": self.backend(),
                    },
                    "metric": {
                        "name": metric_name,
                        "benchmark_values": [value],
                    },
                }
            )

        with open(os.path.join(output_dir, f"{self.name()}.json"), "w") as f:
            json.dump(records, f)

    def append_results(self, path):
        with open(path, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            # Append each collected result row to the CSV file.
            for entry in self.results:
                writer.writerow(entry)

        # TODO (huydhn) This requires a path to write to, so it needs to live in
        # the same place as the CSV writer for now
        self._write_to_json(os.path.dirname(os.path.abspath(path)))

    def print(self):
        for entry in self.results:
            print(f"{entry[0]},{entry[1]},{entry[2]}")

    def collect_all(self):
        self._prepare_once()
        self.results = []
        if (
            self._enable_instruction_count
            and self._enable_compile_time_instruction_count
        ):
            raise RuntimeError(
                "Collecting both counts is not supported until the logger is "
                "updated; both currently log to the same field."
            )

        if self._enable_instruction_count:
            r = self._count_instructions()
            self.results.append((self.name(), "instruction_count", r))
            scribe_log_torch_benchmark_compile_time(
                name=self.name(),
                instruction_count=r,
            )
        if self._enable_compile_time_instruction_count:
            r = self._count_compile_time_instructions()

            self.results.append(
                (
                    self.name(),
                    "compile_time_instruction_count",
                    r,
                )
            )
            # TODO add a new field compile_time_instruction_count to the logger.
            scribe_log_torch_benchmark_compile_time(
                name=self.name(),
                instruction_count=r,
            )
        return self
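

# Illustrative usage (a sketch, not part of the benchmark suite): a minimal
# concrete subclass plus a guarded entry point. The subclass, the compiled
# function, and the output path are all invented for this example.
if __name__ == "__main__":
    import torch  # the aliased imports above do not bind the `torch` name

    class ExampleAddBenchmark(BenchmarkBase):
        def __init__(self):
            super().__init__(category="example_add", device="cpu", backend="inductor")

        def name(self):
            return f"{self.category()}_{self.device()}"

        def _prepare_once(self):
            self._a = torch.ones(8)
            self._b = torch.ones(8)

        def _prepare(self):
            torch._dynamo.reset()  # force a fresh compile on every iteration

        def _work(self):
            @torch.compile(backend=self.backend())
            def add(a, b):
                return a + b

            add(self._a, self._b)

    # Writes CSV rows to the given path and a JSON file next to it.
    ExampleAddBenchmark().enable_compile_time_instruction_count().collect_all().append_results(
        "example_results.csv"
    )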