File: test_fr_analysis.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (116 lines) | stat: -rw-r--r-- 3,426 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Owner(s): ["oncall: distributed"]

import pathlib
import sys


REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent

sys.path.insert(0, str(REPO_ROOT))
from tools.flight_recorder.components.types import MatchState
from tools.flight_recorder.components.utils import match_one_event


# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))

from torch.testing._internal.common_utils import run_tests, TestCase


def create_one_event(
    collectcive_name,
    pg_info,
    input_sizes,
    output_sizes,
    state="scheduled",
    collective_seq_id=0,
    p2p_seq_id=0,
    output_dtypes="float32",
    input_dtypes="float32",
):
    """Build one event dict for the tests in this file.

    Args:
        collectcive_name: Collective name without the ``nccl:`` prefix; the
            prefix is added here. The misspelled parameter name is kept so
            existing keyword callers keep working.
        pg_info: Stored unchanged under ``"process_group"``.
        input_sizes: Stored unchanged under ``"input_sizes"``.
        output_sizes: Stored unchanged under ``"output_sizes"``.
        state: Stored unchanged under ``"state"``.
        collective_seq_id: Stored as ``str(collective_seq_id)``.
        p2p_seq_id: Stored as ``str(p2p_seq_id)``.
        output_dtypes: Stored unchanged under ``"output_dtypes"``.
        input_dtypes: Stored unchanged under ``"input_dtypes"``. Defaults to
            ``"float32"``, the value that used to be hard-coded.

    Returns:
        A new dict holding the fields above plus ``"time_created_ns"`` set to
        ``0`` and ``"frames"`` set to a fresh empty list.
    """
    return {
        "profiling_name": f"nccl:{collectcive_name}",
        "state": state,
        "process_group": pg_info,
        "input_sizes": input_sizes,
        "output_sizes": output_sizes,
        "input_dtypes": input_dtypes,
        "output_dtypes": output_dtypes,
        # Sequence ids are stored in string form regardless of the type passed in.
        "collective_seq_id": str(collective_seq_id),
        "p2p_seq_id": str(p2p_seq_id),
        "time_created_ns": 0,
        "frames": [],
    }


class FlightRecorderEventTest(TestCase):
    """Checks the MatchState that match_one_event reports for pairs of events."""

    def test_match_one_event(self):
        """Compare a baseline all_reduce event against variants of itself."""
        pg = ("0", "default")
        membership = {"0": {0, 1}}

        def expect(lhs, rhs, expected_state):
            # Every comparison in this test uses the same membership map and
            # the same "0" argument.
            self.assertEqual(
                match_one_event(lhs, rhs, membership, "0"), expected_state
            )

        # Baseline: scheduled all_reduce with equal input and output sizes.
        baseline = create_one_event(
            "all_reduce",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        expect(baseline, baseline, MatchState.FULLY_MATCHED)

        # Same fields, different collective name.
        gather = create_one_event(
            "all_gather",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        expect(baseline, gather, MatchState.COLLECTIVE_TYPE_MISMATCH)

        # Two separately built but identical all_to_all events.
        a2a_first = create_one_event(
            "all_to_all",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        a2a_second = create_one_event(
            "all_to_all",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        expect(a2a_first, a2a_second, MatchState.UNDECIDED)

        # Input sizes differ from the baseline.
        other_input = create_one_event(
            "all_reduce",
            pg,
            [[5, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
            p2p_seq_id=1,
        )
        expect(baseline, other_input, MatchState.SIZE_OR_SYNTAX_MISMATCH)

        # Output sizes differ from the baseline.
        other_output = create_one_event(
            "all_reduce",
            pg,
            [[4, 4]],
            [[5, 4]],
            state="scheduled",
            collective_seq_id=1,
            p2p_seq_id=2,
        )
        expect(baseline, other_output, MatchState.SIZE_OR_SYNTAX_MISMATCH)

        # An all_reduce whose own input and output sizes disagree is expected
        # to be reported as a mismatch even when compared with itself.
        self_inconsistent = create_one_event(
            "all_reduce",
            pg,
            [[4, 4]],
            [[5, 4]],
            state="scheduled",
            collective_seq_id=2,
        )
        expect(
            self_inconsistent,
            self_inconsistent,
            MatchState.SIZE_OR_SYNTAX_MISMATCH,
        )

        # Same as the baseline except for the state.
        completed = create_one_event(
            "all_reduce",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="completed",
            collective_seq_id=1,
        )
        expect(baseline, completed, MatchState.COLLECTIVE_STATE_MISMATCH)

        # Same as `completed` except for the output dtype.
        completed_fp16 = create_one_event(
            "all_reduce",
            pg,
            [[4, 4]],
            [[4, 4]],
            state="completed",
            collective_seq_id=1,
            output_dtypes="float16",
        )
        expect(completed_fp16, completed, MatchState.COLLECTIVE_DTYPE_MISMATCH)


# When this file is executed as a script, hand control to run_tests
# (imported above from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()