File: parallel_case.py

package info (click to toggle)
python3.14 3.14.0~rc1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 126,824 kB
  • sloc: python: 745,274; ansic: 713,752; xml: 31,250; sh: 5,822; cpp: 4,063; makefile: 1,988; objc: 787; lisp: 502; javascript: 136; asm: 75; csh: 12
file content (78 lines) | stat: -rw-r--r-- 2,778 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Run a test case multiple times in parallel threads."""

import copy
import threading
import unittest

from unittest import TestCase


class ParallelTestCase(TestCase):
    def __init__(self, test_case: TestCase, num_threads: int):
        self.test_case = test_case
        self.num_threads = num_threads
        self._testMethodName = test_case._testMethodName
        self._testMethodDoc = test_case._testMethodDoc

    def __str__(self):
        return f"{str(self.test_case)} [threads={self.num_threads}]"

    def run_worker(self, test_case: TestCase, result: unittest.TestResult,
                   barrier: threading.Barrier):
        barrier.wait()
        test_case.run(result)

    def run(self, result=None):
        if result is None:
            result = test_case.defaultTestResult()
            startTestRun = getattr(result, 'startTestRun', None)
            stopTestRun = getattr(result, 'stopTestRun', None)
            if startTestRun is not None:
                startTestRun()
        else:
            stopTestRun = None

        # Called at the beginning of each test. See TestCase.run.
        result.startTest(self)

        cases = [copy.copy(self.test_case) for _ in range(self.num_threads)]
        results = [unittest.TestResult() for _ in range(self.num_threads)]

        barrier = threading.Barrier(self.num_threads)
        threads = []
        for i, (case, r) in enumerate(zip(cases, results)):
            thread = threading.Thread(target=self.run_worker,
                                      args=(case, r, barrier),
                                      name=f"{str(self.test_case)}-{i}",
                                      daemon=True)
            threads.append(thread)

        for thread in threads:
            thread.start()

        for threads in threads:
            threads.join()

        # Aggregate test results
        if all(r.wasSuccessful() for r in results):
            result.addSuccess(self)

        # Note: We can't call result.addError, result.addFailure, etc. because
        # we no longer have the original exception, just the string format.
        for r in results:
            if len(r.errors) > 0 or len(r.failures) > 0:
                result._mirrorOutput = True
            result.errors.extend(r.errors)
            result.failures.extend(r.failures)
            result.skipped.extend(r.skipped)
            result.expectedFailures.extend(r.expectedFailures)
            result.unexpectedSuccesses.extend(r.unexpectedSuccesses)
            result.collectedDurations.extend(r.collectedDurations)

        if any(r.shouldStop for r in results):
            result.stop()

        # Test has finished running
        result.stopTest(self)
        if stopTestRun is not None:
            stopTestRun()