File: test_pytorch_export_helpers.py

package info (click to toggle)
onnxruntime 1.23.2+dfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 340,756 kB
  • sloc: cpp: 3,222,136; python: 188,267; ansic: 114,318; asm: 37,927; cs: 36,849; java: 10,962; javascript: 6,811; pascal: 4,126; sh: 2,996; xml: 705; objc: 281; makefile: 67
file content (42 lines) | stat: -rw-r--r-- 1,520 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest

import torch

from ..pytorch_export_helpers import infer_input_info

# Example usage, run from <ort root>/tools/python:
#   python -m unittest util/test/test_pytorch_export_helpers.py
# NOTE: at least on Windows, that directory must be the working directory for the relative imports to resolve.


class TestModel(torch.nn.Module):
    """Small two-layer linear model used as the target for ``infer_input_info`` tests.

    ``forward`` takes optional ``min``/``max`` arguments so the tests can check how positional
    and keyword inputs are named and ordered. The parameter names ``x``, ``min`` and ``max``
    are asserted by the tests, so they must not be renamed even though they shadow builtins.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Created in this order (D_in -> H, then H -> D_out); the creation order also fixes the
        # order in which the layers' random initial weights are drawn.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x, min=0, max=1):  # noqa: A002 - names are part of the tested signature
        """Apply linear1, clamp the result to [min, max], then apply linear2."""
        hidden = self.linear1(x)
        bounded = torch.clamp(hidden, min=min, max=max)
        return self.linear2(bounded)


class TestInferInputs(unittest.TestCase):
    """Tests for ``infer_input_info`` against ``TestModel.forward(x, min=0, max=1)``."""

    @classmethod
    def setUpClass(cls):
        # Built once per class and shared by the tests below, none of which reassign them.
        cls._model = TestModel(1000, 100, 10)
        cls._input = torch.randn(1, 1000)

    def test_positional(self):
        # test we can infer the input names from the forward method when positional args are used,
        # and that the input values come back in the same order they were passed
        input_names, inputs_as_tuple = infer_input_info(self._model, self._input, 0, 1)
        self.assertEqual(input_names, ["x", "min", "max"])
        # Tuple equality checks identity before ==, so comparing the tensor element works as long as
        # infer_input_info returns the original tensor object (the keyword test relies on the same).
        self.assertEqual(inputs_as_tuple, (self._input, 0, 1))

    def test_keywords(self):
        # test that keyword args (names and values) are reordered to match the forward signature,
        # regardless of the order in which they were passed
        input_names, inputs_as_tuple = infer_input_info(self._model, self._input, max=1, min=0)
        self.assertEqual(input_names, ["x", "min", "max"])
        self.assertEqual(inputs_as_tuple, (self._input, 0, 1))