File: test_aten_pow.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (92 lines) | stat: -rw-r--r-- 4,381 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Owner(s): ["oncall: jit"]

import torch
from torch.testing._internal.common_utils import TestCase

class TestAtenPow(TestCase):
    def test_aten_pow_zero_negative_exponent(self):
        '''
        1. Testing a = int, b = int
        '''
        @torch.jit.script
        def fn_int_int(a: int, b: int):
            return a ** b
        # Existing correct behaviors of aten::pow
        self.assertEqual(fn_int_int(2, 1), 2 ** 1)
        self.assertEqual(fn_int_int(2, 0), 2 ** 0)
        self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
        self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
        self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
        self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
        self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
        self.assertEqual(fn_int_int(0, 2), 0 ** 1)
        self.assertEqual(fn_int_int(0, 0), 0 ** 0)
        # zero base and negative exponent case that should trigger RunTimeError
        self.assertRaises(RuntimeError, fn_int_int, 0, -2)

        '''
        2. Testing a = int, b = float
        '''
        @torch.jit.script
        def fn_int_float(a: int, b: float):
            return a ** b
        # Existing correct behaviors of aten::pow
        self.assertEqual(fn_int_float(2, 2.5), 2 ** 2.5)
        self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
        self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
        self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
        self.assertEqual(fn_int_float(-2, 2.0), (-2) ** 2.0)
        self.assertEqual(fn_int_float(-2, -2.0), (-2) ** (-2.0))
        self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
        self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
        self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
        self.assertEqual(fn_int_float(0, 2.0), 0 ** 2.0)
        self.assertEqual(fn_int_float(0, 0.5), 0 ** 0.5)
        self.assertEqual(fn_int_float(0, 0.0), 0 ** 0.0)
        self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
        # zero base and negative exponent case that should trigger RunTimeError
        self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)

        '''
        3. Testing a = float, b = int
        '''
        @torch.jit.script
        def fn_float_int(a: float, b: int):
            return a ** b
        # Existing correct behaviors of aten::pow
        self.assertEqual(fn_float_int(2.5, 2), 2.5 ** 2)
        self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
        self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
        self.assertEqual(fn_float_int(2.5, 0), 2.5 ** 0)
        self.assertEqual(fn_float_int(-2.5, 2), 2.5 ** 2)
        self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
        self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
        self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
        self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
        self.assertEqual(fn_float_int(0.0, 2), 0 ** 2)
        self.assertEqual(fn_float_int(0.0, 0), 0 ** 0)
        self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
        # zero base and negative exponent case that should trigger RunTimeError
        self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)

        '''
        4. Testing a = float, b = float
        '''
        @torch.jit.script
        def fn_float_float(a: float, b: float):
            return a ** b
        # Existing correct behaviors of aten::pow
        self.assertEqual(fn_float_float(2.5, 2.0), 2.5 ** 2.0)
        self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
        self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
        self.assertEqual(fn_float_float(2.5, 0.0), 2.5 ** 0.0)
        self.assertEqual(fn_float_float(-2.5, 2.0), 2.5 ** 2.0)
        self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
        self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
        self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
        self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
        self.assertEqual(fn_float_float(0.0, 2.0), 0.0 ** 2.0)
        self.assertEqual(fn_float_float(0.0, 0.0), 0.0 ** 0.0)
        self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
        # zero base and negative exponent case that should trigger RunTimeError
        self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)