File: gemmgelu_test.py

package info (click to toggle)
onnxscript 0.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (65 lines) | stat: -rw-r--r-- 1,667 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import numpy as np

from tests.common import onnx_script_test_case
from tests.functions import gemmgelu


class TestGemmGelu(onnx_script_test_case.OnnxScriptTestCase):
    """Exercise ``gemmgelu.gemmgelu`` against reference values from ORT."""

    def test_gemmgelu(self):
        """Run the converter and eager tests on fixed-seed random inputs."""
        # The seed and the order of the three draws below (a, w, b) pin the
        # inputs that the hard-coded reference output was produced from.
        np.random.seed(0)
        rows, inner, cols = 2, 4, 8

        # Each matrix is drawn with its axes swapped and then transposed, so
        # the operands end up shaped (rows, inner) and (inner, cols).
        lhs = np.random.rand(inner, rows).astype("float32").T
        weight = np.random.rand(cols, inner).astype("float32").T
        # The bias is 1-D with ``cols`` entries; the trailing transpose is
        # kept exactly as in the original construction.
        bias = np.random.rand(cols).astype("float32").T

        # FIXME(liqunfu): these reference values come from an ORT evaluation.
        # A numpy-based implementation should provide them instead.
        reference = np.array(
            [
                [
                    1.6088762, 1.2583977, 1.868434, 1.530172,
                    1.5025945, 1.5770031, 0.93028706, 1.4389044,
                ],
                [
                    2.2128997, 1.3670988, 2.4269097, 2.1586964,
                    1.9926084, 2.0960782, 1.2971772, 2.0846245,
                ],
            ],
            dtype=np.float32,
        )

        params = onnx_script_test_case.FunctionTestParams(
            gemmgelu.gemmgelu, [lhs, weight, bias], [reference]
        )
        # Converter test first (with an explicit relative tolerance), then
        # the eager-mode test with its default settings.
        self.run_converter_test(params, rtol=1e-6)
        self.run_eager_test(params)


# Allow running this test module directly as a script; unittest.main()
# discovers and runs the test cases defined in this file.
if __name__ == "__main__":
    unittest.main()