File: __init__.py

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (97 lines) | stat: -rw-r--r-- 3,442 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
import torch


def is_built():
    r"""Report whether this PyTorch binary was compiled with CUDA support.

    A true result does not by itself mean that CUDA is usable right now; it
    only means that, if this binary were run on a machine with working CUDA
    drivers and devices, PyTorch would be able to use them.
    """
    compiled_with_cuda = torch._C.has_cuda
    return compiled_with_cuda


class cuFFTPlanCacheAttrContextProp(object):
    r"""Descriptor that routes attribute access through per-device accessors.

    Like regular ContextProp, but uses the `.device_index` attribute from the
    calling object as the first argument to the getter and setter.

    Args:
        getter: callable invoked as ``getter(device_index)``; its result is
            the attribute value.
        setter: callable invoked as ``setter(device_index, value)``, or a
            string. A string marks the attribute as read-only and is used as
            the message of the ``RuntimeError`` raised on assignment.
    """
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype=None):
        # Access through the owning class (no instance) has no device index
        # to forward; follow the usual descriptor convention and hand back
        # the descriptor itself instead of failing on `None.device_index`.
        if obj is None:
            return self
        return self.getter(obj.device_index)

    def __set__(self, obj, val):
        # A string in the setter slot means "read-only": report it as the error.
        if isinstance(self.setter, str):
            raise RuntimeError(self.setter)
        self.setter(obj.device_index, val)


class cuFFTPlanCache(object):
    r"""
    Handle on the cuFFT plan cache of one device, identified by
    `device_index`. The attributes `size` and `max_size`, and the method
    `clear`, can fetch and/or change properties of the C++ cuFFT plan cache.
    """
    def __init__(self, device_index):
        self.device_index = device_index

    # Read-only attribute: the setter slot carries the error message that is
    # raised when someone tries to assign to `.size`.
    size = cuFFTPlanCacheAttrContextProp(
        getter=torch._cufft_get_plan_cache_size,
        setter=(
            '.size is a read-only property showing the number of plans currently in the '
            'cache. To change the cache capacity, set cufft_plan_cache.max_size.'
        ),
    )

    # Read/write attribute backed by the matching get/set functions.
    max_size = cuFFTPlanCacheAttrContextProp(
        getter=torch._cufft_get_plan_cache_max_size,
        setter=torch._cufft_set_plan_cache_max_size,
    )

    def clear(self):
        r"""Clear this device's plan cache via `torch._cufft_clear_plan_cache`
        and return whatever that call returns."""
        cleared = torch._cufft_clear_plan_cache(self.device_index)
        return cleared


class cuFFTPlanCacheManager(object):
    r"""
    Entry point to all cuFFT plan caches. Indexing this object with a device
    object/index yields the `cuFFTPlanCache` belonging to that device.

    In addition, when this object is used directly as if it were a
    `cuFFTPlanCache` (e.g., reading or setting `.max_size`), the access is
    forwarded to the current device's cuFFT plan cache.
    """

    # Class-level default. `__init__` shadows it on the instance once setup is
    # finished, which switches `__setattr__` into forwarding mode.
    __initialized = False

    def __init__(self):
        # Both assignments happen while the flag still reads False, so they
        # are stored on the manager itself rather than forwarded.
        self.caches = []
        self.__initialized = True

    def __getitem__(self, device):
        r"""Return the per-device cache for `device`.

        Raises:
            RuntimeError: if the resolved device index is outside
                ``[0, torch.cuda.device_count())``.
        """
        device_id = torch.cuda._utils._get_device_index(device)
        device_total = torch.cuda.device_count()
        if not (0 <= device_id < device_total):
            raise RuntimeError(
                ("cufft_plan_cache: expected 0 <= device index < {}, but got "
                 "device with index {}").format(device_total, device_id))
        if not self.caches:
            # Populate lazily on first valid lookup: one cache object per
            # device. `append` is used (not `+=`) so that no attribute
            # assignment goes through the forwarding `__setattr__`.
            for ordinal in range(device_total):
                self.caches.append(cuFFTPlanCache(ordinal))
        return self.caches[device_id]

    def __getattr__(self, name):
        # Only reached for names not found on the manager itself: delegate to
        # the cache of the current device.
        current_cache = self[torch.cuda.current_device()]
        return getattr(current_cache, name)

    def __setattr__(self, name, value):
        if not self.__initialized:
            # Still inside `__init__`: store on the manager as usual.
            return super(cuFFTPlanCacheManager, self).__setattr__(name, value)
        current_cache = self[torch.cuda.current_device()]
        return setattr(current_cache, name, value)


class _cuBLASUnknownAttributeError(AttributeError, AssertionError):
    r"""Raised by `cuBLASModule` for any attribute other than `allow_tf32`.

    It is an `AttributeError` so that the attribute protocol (`hasattr`,
    `getattr` with a default) behaves normally, and also an `AssertionError`
    so that callers written against the earlier `assert`-based check keep
    catching it.
    """


class cuBLASModule:
    r"""Attribute-style access to the cuBLAS TF32 switch.

    The only supported attribute is `allow_tf32`: reading it returns
    ``torch._C._get_cublas_allow_tf32()`` and assigning to it calls
    ``torch._C._set_cublas_allow_tf32(value)``. Any other attribute name
    raises an error with the message ``"Unknown attribute <name>"``.
    """
    def __getattr__(self, name):
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would make every attribute name read the flag.
        if name != "allow_tf32":
            raise _cuBLASUnknownAttributeError("Unknown attribute " + name)
        return torch._C._get_cublas_allow_tf32()

    def __setattr__(self, name, value):
        # Same reasoning as above: without a real check, `-O` would let a
        # misspelled attribute silently change the TF32 setting.
        if name != "allow_tf32":
            raise _cuBLASUnknownAttributeError("Unknown attribute " + name)
        return torch._C._set_cublas_allow_tf32(value)


# Module-level singletons created at import time.
# `cufft_plan_cache` can be indexed by device, or used directly to reach the
# current device's plan cache; `matmul` exposes the `allow_tf32` attribute.
cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()