1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
from tests.common import onnx_script_test_case
from tests.functions import gemmgelu
class TestGemmGelu(onnx_script_test_case.OnnxScriptTestCase):
    """Regression test for ``gemmgelu.gemmgelu`` on seeded random inputs."""

    def test_gemmgelu(self):
        """Compare ``gemmgelu`` against recorded reference values.

        Inputs are a (2, 4) activation matrix, a (4, 8) weight matrix and a
        length-8 bias, all float32. The single test case is passed first to
        ``run_converter_test`` and then to ``run_eager_test``.
        """
        # The hard-coded reference values are only valid for this seed and
        # for this exact order of draws (a, then w, then b).
        np.random.seed(0)
        rows, inner, cols = 2, 4, 8

        # Each matrix is sampled in transposed layout and then flipped back,
        # so the operands are transposed views rather than fresh arrays.
        a = np.random.rand(inner, rows).astype("float32").T
        w = np.random.rand(cols, inner).astype("float32").T
        b = np.random.rand(cols).astype("float32").T

        # FIXME(liqunfu): expected values come from an ort evaluation;
        # numpy oxs should provide the expected output instead.
        reference_rows = [
            [
                1.6088762, 1.2583977, 1.868434, 1.530172,
                1.5025945, 1.5770031, 0.93028706, 1.4389044,
            ],
            [
                2.2128997, 1.3670988, 2.4269097, 2.1586964,
                1.9926084, 2.0960782, 1.2971772, 2.0846245,
            ],
        ]
        expected = np.array(reference_rows, dtype=np.float32)

        params = onnx_script_test_case.FunctionTestParams(
            gemmgelu.gemmgelu, [a, w, b], [expected]
        )
        # Only the converter run is given an explicit relative tolerance.
        self.run_converter_test(params, rtol=1e-6)
        self.run_eager_test(params)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|