File: test_tree_utils.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (49 lines) | stat: -rw-r--r-- 1,865 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Owner(s): ["oncall: export"]
from collections import OrderedDict

import torch
from torch._dynamo.test_case import TestCase
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.testing._internal.common_utils import run_tests
from torch.utils._pytree import tree_structure


class TestTreeUtils(TestCase):
    def test_reorder_kwargs(self):
        original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)}
        user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)}
        orig_spec = tree_structure(((), original_kwargs))

        reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec)

        # Key ordering should be the same
        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),

    def test_equivalence_check(self):
        tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None}
        tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None)
        spec1 = tree_structure(tree1)
        spec2 = tree_structure(tree2)

        def dict_ordered_dict_eq(type1, context1, type2, context2):
            if type1 is None or type2 is None:
                return type1 is type2 and context1 == context2

            if issubclass(type1, (dict, OrderedDict)) and issubclass(
                type2, (dict, OrderedDict)
            ):
                return context1 == context2

            return type1 is type2 and context1 == context2

        self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq))

        # Wrong ordering should still fail
        tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0))
        spec3 = tree_structure(tree3)
        self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq))


# When this file is executed directly (rather than imported), hand control to
# ``run_tests`` from ``torch.testing._internal.common_utils``.
if __name__ == "__main__":
    run_tests()