File: onnxfns_test.py

package info (click to toggle)
onnxscript 0.2.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (69 lines) | stat: -rw-r--r-- 2,031 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

from tests.common import onnx_script_test_case
from tests.models import onnxfns1


class TestOnnxFns(onnx_script_test_case.OnnxScriptTestCase):
    """Run each function from ``tests.models.onnxfns1`` through ``run_onnx_test``.

    Ops that take attributes receive them as explicit keyword arguments, set
    to the default values given in the ONNX operator specification.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Override the relative tolerance after the base-class setup has run;
        # presumably read by the base class when comparing outputs — TODO confirm.
        cls.rtol = 1e-5

    def test_onnxfns_relu(self):
        self.run_onnx_test(onnxfns1.Relu)

    def test_onnxfns_selu(self):
        selu_defaults = {
            "alpha": 1.67326319217681884765625,
            "gamma": 1.05070102214813232421875,
        }
        self.run_onnx_test(onnxfns1.Selu, **selu_defaults)

    def test_onnxfns_elu(self):
        self.run_onnx_test(onnxfns1.Elu, alpha=1.0)

    def test_onnxfns_thresholded_relu(self):
        self.run_onnx_test(onnxfns1.ThresholdedRelu, alpha=1.0)

    def test_onnxfns_leaky_relu(self):
        self.run_onnx_test(onnxfns1.LeakyRelu, alpha=0.01)

    def test_onnxfns_prelu(self):
        self.run_onnx_test(onnxfns1.PRelu)

    def test_onnxfns_hard_sigmoid(self):
        hard_sigmoid_defaults = {"alpha": 0.2, "beta": 0.5}
        self.run_onnx_test(onnxfns1.HardSigmoid, **hard_sigmoid_defaults)

    # NOTE: the ``hard_`` prefix on the next four test names does not match the
    # functions they exercise (Shrink, Softplus, Softsign, Clip). The names are
    # left unchanged so that existing test IDs stay stable.

    def test_onnxfns_hard_shrink(self):
        shrink_defaults = {"bias": 0.0, "lambd": 0.5}
        self.run_onnx_test(onnxfns1.Shrink, **shrink_defaults)

    def test_onnxfns_hard_softplus(self):
        self.run_onnx_test(onnxfns1.Softplus)

    def test_onnxfns_hard_softsign(self):
        self.run_onnx_test(onnxfns1.Softsign)

    def test_onnxfns_hard_clip(self):
        # Eager-mode checking is presumably disabled by ``skip_eager_test``, and
        # three int8 cases are excluded by name; the reason is not recorded
        # here — TODO confirm.
        excluded_int8_cases = [
            "test_clip_default_int8_min",
            "test_clip_default_int8_max",
            "test_clip_default_int8_inbounds",
        ]
        self.run_onnx_test(
            onnxfns1.Clip,
            skip_eager_test=True,
            skip_test_names=excluded_int8_cases,
        )


if __name__ == "__main__":
    # Allow running this module directly (e.g. `python onnxfns_test.py`)
    # in addition to discovery by a test runner.
    unittest.main()