File: test_hash.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (112 lines) | stat: -rw-r--r-- 3,797 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Owner(s): ["oncall: jit"]

import os
import sys

import torch

from typing import Tuple, List

# Make the helper files in test/ importable: resolve this file's real
# location, step up two directory levels, and append that directory to
# the module search path.
_containing_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_containing_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# Refuse direct execution: the error text points the user at
# test/test_jit.py as the intended entry point.
if __name__ == "__main__":
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)

class TestHash(JitTestCase):
    """Exercises the builtin ``hash`` on the value types used in these tests.

    Nearly every test defines a tiny function that compares the hashes of
    its two arguments and passes it, together with a sequence of argument
    tuples, to ``self.checkScript`` (inherited from ``JitTestCase``).
    """

    def test_hash_tuple(self):
        # Cases: identical tuples, unrelated tuples, same elements swapped.
        def same_hash(lhs: Tuple[int, int], rhs: Tuple[int, int]) -> bool:
            return hash(lhs) == hash(rhs)

        tuple_pairs = (
            ((1, 2), (1, 2)),
            ((1, 2), (3, 4)),
            ((1, 2), (2, 1)),
        )
        for operands in tuple_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_tuple_nested_unhashable_type(self):
        # A tuple can hold an unhashable member such as a `list`; hashing
        # such a tuple is expected to raise a RuntimeError mentioning
        # "unhashable", with "hash" highlighted.
        @torch.jit.script
        def fn_unhashable(pair: Tuple[int, List[int]]):
            return hash(pair)

        tuple_with_list = (1, [1])
        with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
            fn_unhashable(tuple_with_list)

    def test_hash_tensor(self):
        """Tensors should hash by identity"""
        def same_hash(lhs, rhs):
            return hash(lhs) == hash(rhs)

        one = torch.tensor(1)
        one_again = torch.tensor(1)
        two = torch.tensor(2)

        # Same object twice, equal-valued but distinct objects, and
        # different-valued objects.
        tensor_pairs = (
            (one, one),
            (one, one_again),
            (one, two),
        )
        for operands in tensor_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_none(self):
        # Two separately bound None values, compared inside the function.
        def same_hash():
            first = None
            second = None
            return hash(first) == hash(second)

        no_arguments = ()
        self.checkScript(same_hash, no_arguments)

    def test_hash_bool(self):
        # All four combinations of two booleans.
        def same_hash(lhs: bool, rhs: bool):
            return hash(lhs) == hash(rhs)

        bool_pairs = (
            (True, False),
            (True, True),
            (False, True),
            (False, False),
        )
        for operands in bool_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_float(self):
        # Finite values, infinities and NaNs.
        def same_hash(lhs: float, rhs: float):
            return hash(lhs) == hash(rhs)

        float_pairs = [
            (1.2345, 1.2345),
            (1.2345, 6.789),
            (1.2345, float("inf")),
            (float("inf"), float("inf")),
            (1.2345, float("nan")),
        ]
        if sys.version_info < (3, 10):
            # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
            # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
            float_pairs.append((float("nan"), float("nan")))
        float_pairs.append((float("nan"), float("inf")))

        for operands in float_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_int(self):
        # Distinct, equal, sign-flipped, both-negative and zero operands.
        def same_hash(lhs: int, rhs: int):
            return hash(lhs) == hash(rhs)

        int_pairs = (
            (123, 456),
            (123, 123),
            (123, -123),
            (-123, -123),
            (123, 0),
        )
        for operands in int_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_string(self):
        # Equal strings, different strings, and a string against the empty one.
        def same_hash(lhs: str, rhs: str):
            return hash(lhs) == hash(rhs)

        string_pairs = (
            ("foo", "foo"),
            ("foo", "bar"),
            ("foo", ""),
        )
        for operands in string_pairs:
            self.checkScript(same_hash, operands)

    def test_hash_device(self):
        # Same device, same type with different index, and different types.
        def same_hash(lhs: torch.device, rhs: torch.device):
            return hash(lhs) == hash(rhs)

        first_gpu = torch.device("cuda:0")
        second_gpu = torch.device("cuda:1")
        host = torch.device("cpu")

        device_pairs = (
            (first_gpu, first_gpu),
            (first_gpu, second_gpu),
            (first_gpu, host),
            (host, host),
        )
        for operands in device_pairs:
            self.checkScript(same_hash, operands)