File: executor_test.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (103 lines) | stat: -rw-r--r-- 3,039 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103




from caffe2.python import core, workspace
from caffe2.python.test.executor_test_util import (
    build_conv_model,
    build_resnet50_dataparallel_model,
    run_resnet50_epoch,
    ExecutorTestBase,
    executor_test_settings,
    executor_test_model_names)

from caffe2.python.test_util import TestCase

from hypothesis import given
import hypothesis.strategies as st

import unittest


# Executor types exercised by the tests below; each one is passed as
# ``test_executor`` alongside the "simple" reference executor.
EXECUTORS = [
    "parallel",
    "async_scheduling",
]

# Iteration count handed to workspace.RunNet by the conv-net test.
ITERATIONS = 1


class ExecutorCPUConvNetTest(ExecutorTestBase):
    """Runs the executor test models through each executor in EXECUTORS.

    Every hypothesis-drawn combination is handed to ``compare_executors``
    with "simple" as the reference executor.
    """

    @given(
        executor=st.sampled_from(EXECUTORS),
        model_name=st.sampled_from(executor_test_model_names()),
        batch_size=st.sampled_from([1]),
        num_workers=st.sampled_from([8]),
    )
    @executor_test_settings
    def test_executor(self, executor, model_name, batch_size, num_workers):
        # Build the requested model and set its worker count before the
        # comparison runs it.
        conv_model = build_conv_model(model_name, batch_size)
        conv_model.Proto().num_workers = num_workers

        def run_model():
            # avoid numeric instability with MLP gradients
            num_iterations = 1 if model_name == "MLP" else ITERATIONS
            workspace.RunNet(conv_model.net, num_iterations)

        self.compare_executors(
            conv_model,
            ref_executor="simple",
            test_executor=executor,
            model_run_func=run_model,
        )


@unittest.skipIf(not workspace.has_gpu_support, "no gpu")
class ExecutorGPUResNetTest(ExecutorTestBase):
    """Runs a data-parallel ResNet50 model through each executor in EXECUTORS.

    Skipped when ``workspace.has_gpu_support`` is false. Each drawn executor
    is handed to ``compare_executors`` with "simple" as the reference.
    """

    @given(
        executor=st.sampled_from(EXECUTORS),
        num_workers=st.sampled_from([8]),
    )
    @executor_test_settings
    def test_executor(self, executor, num_workers):
        # The same batch/epoch sizes are used to build the model and to run it.
        size_kwargs = dict(batch_size=8, epoch_size=8)

        resnet = build_resnet50_dataparallel_model(
            num_gpus=workspace.NumGpuDevices(), **size_kwargs)
        resnet.Proto().num_workers = num_workers

        def run_model():
            run_resnet50_epoch(resnet, **size_kwargs)

        self.compare_executors(
            resnet,
            ref_executor="simple",
            test_executor=executor,
            model_run_func=run_model,
        )


class ExecutorFailingOpTest(TestCase):
    """Checks how RunNet reports a failing op in an "async_scheduling" net."""

    def test_failing_op(self):
        def create_failing_net(throw_exception):
            # Single-op net: ThrowException when requested, otherwise Fail.
            failing = core.Net("failing_net")
            add_op = failing.ThrowException if throw_exception else failing.Fail
            add_op([], [])
            failing.Proto().type = "async_scheduling"
            return failing

        # ThrowException op: RunNet raises both with and without allow_fail.
        workspace.ResetWorkspace()
        throwing_net = create_failing_net(throw_exception=True)
        workspace.CreateNet(throwing_net)
        for run_kwargs in ({}, {"allow_fail": True}):
            with self.assertRaises(RuntimeError):
                workspace.RunNet(throwing_net, **run_kwargs)

        # Fail op: RunNet raises by default, while allow_fail=True makes it
        # return a falsy result instead.
        workspace.ResetWorkspace()
        soft_fail_net = create_failing_net(throw_exception=False)
        workspace.CreateNet(soft_fail_net)

        with self.assertRaises(RuntimeError):
            workspace.RunNet(soft_fail_net)

        self.assertFalse(workspace.RunNet(soft_fail_net, allow_fail=True))


# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()