File: device_interface.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (126 lines) | stat: -rw-r--r-- 3,414 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import time

import torch
from torch._dynamo import device_interface  # noqa: PLC2701 import-private-name


class DeviceProperties:
    """Fixed set of device properties reported by this backend.

    Every instance carries the same three hard-coded attributes:
    ``major``, ``max_threads_per_multi_processor`` and
    ``multi_processor_count``.
    """

    def __init__(self) -> None:
        fixed_values = {
            # TODO: bypass check for H100 in triton_heuristics.py
            "major": 8,
            "max_threads_per_multi_processor": 1,
            "multi_processor_count": 80,
        }
        # Assign in declaration order so the instance attributes are created
        # in the same sequence as they are listed above.
        for attr_name, attr_value in fixed_values.items():
            setattr(self, attr_name, attr_value)


class DeviceInterface(device_interface.DeviceInterface):
    class Event(torch.Event):
        def __init__(
            self,
            enable_timing: bool = False,
            blocking: bool = False,
            interprocess: bool = False,
        ) -> None:
            self.enable_timing = enable_timing
            self.recorded_time: int | None = None

        def record(self, stream) -> None:
            if not self.enable_timing:
                return
            assert self.recorded_time is None
            self.recorded_time = time.perf_counter_ns()

        def elapsed_time(self, end_event: DeviceInterface.Event) -> float:
            assert self.recorded_time
            assert end_event.recorded_time
            # convert to ms
            return (end_event.recorded_time - self.recorded_time) / 1000000

        def wait(self, stream) -> None:
            pass

        def query(self) -> None:
            pass

        def synchronize(self) -> None:
            pass

    class device:  # noqa: N801 invalid-class-name # pyright: ignore [reportIncompatibleVariableOverride]
        def __init__(self, device) -> None:
            self.device = device

    class Worker(device_interface.DeviceInterface.Worker):
        @staticmethod
        def set_device(device: int) -> None:
            # No device index for our backend
            pass

        @staticmethod
        def current_device() -> int:
            # No device index for our backend
            return 0

        @staticmethod
        def get_device_properties(
            device=None,
        ) -> DeviceProperties:
            return DeviceProperties()

    @staticmethod
    def current_device() -> int:
        return 0

    @staticmethod
    def set_device(device) -> None:
        pass

    @staticmethod
    def device_count() -> int:
        raise NotImplementedError

    @staticmethod
    def maybe_exchange_device(device: int) -> int:
        assert (
            device == 0
        ), f"Only device index 0 is supported, tried to set index to {device}"
        return 0  # previous device is always 0

    @staticmethod
    def exchange_device(device: int) -> int:
        assert (
            device == 0
        ), f"Only device index 0 is supported, tried to set index to {device}"
        return 0  # previous device is always 0

    @staticmethod
    def current_stream():
        raise NotImplementedError

    @staticmethod
    def set_stream(stream) -> None:
        raise NotImplementedError

    @staticmethod
    def get_raw_stream(device_index: int):
        return None

    @staticmethod
    def synchronize(device) -> None:
        pass

    @staticmethod
    def get_device_properties(device) -> DeviceProperties:
        raise NotImplementedError

    # Can be mock patched by @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return True

    @staticmethod
    def get_compute_capability(device) -> int:
        return 0

    @staticmethod
    def triton_supported() -> bool:
        return True