File: coordinator.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (143 lines) | stat: -rw-r--r-- 5,377 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import time

import numpy as np
from agent import AgentBase
from observer import ObserverBase

import torch
import torch.distributed.rpc as rpc


COORDINATOR_NAME = "coordinator"
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"

EPISODE_STEPS = 100


class CoordinatorBase:
    def __init__(self, batch_size, batch, state_size, nlayers, out_features):
        r"""
        Coordinator object to run on worker.  Only one coordinator exists.  Responsible
        for facilitating communication between agent and observers and recording benchmark
        throughput and latency data.
        Args:
            batch_size (int): Number of observer requests to process in a batch
            batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
        """
        self.batch_size = batch_size
        self.batch = batch

        self.agent_rref = None  # Agent RRef
        self.ob_rrefs = []  # Observer RRef

        agent_info = rpc.get_worker_info(AGENT_NAME)
        self.agent_rref = rpc.remote(agent_info, AgentBase)

        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))
            ob_ref = rpc.remote(ob_info, ObserverBase)
            self.ob_rrefs.append(ob_ref)

            ob_ref.rpc_sync().set_state(state_size, batch)

        self.agent_rref.rpc_sync().set_world(
            batch_size, state_size, nlayers, out_features, self.batch
        )

    def run_coordinator(self, episodes, episode_steps, queue):
        r"""
        Runs n benchmark episodes.  Each episode is started by coordinator telling each
        observer to contact the agent.  Each episode is concluded by coordinator telling agent
        to finish the episode, and then the coordinator records benchmark data
        Args:
            episodes (int): Number of episodes to run
            episode_steps (int): Number steps to be run in each episdoe by each observer
            queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() for
                                 saving benchmark run results to
        """

        agent_latency_final = []
        agent_throughput_final = []

        observer_latency_final = []
        observer_throughput_final = []

        for ep in range(episodes):
            ep_start_time = time.time()

            print(f"Episode {ep} - ", end="")

            n_steps = episode_steps

            futs = []
            for ob_rref in self.ob_rrefs:
                futs.append(
                    ob_rref.rpc_async().run_ob_episode(self.agent_rref, n_steps)
                )

            rets = torch.futures.wait_all(futs)
            agent_latency, agent_throughput = self.agent_rref.rpc_sync().finish_episode(
                rets
            )

            self.agent_rref.rpc_sync().reset_metrics()

            agent_latency_final += agent_latency
            agent_throughput_final += agent_throughput

            observer_latency_final += [ret[2] for ret in rets]
            observer_throughput_final += [ret[3] for ret in rets]

            ep_end_time = time.time()
            episode_time = ep_end_time - ep_start_time
            print(round(episode_time, 3))

        observer_latency_final = [t for s in observer_latency_final for t in s]
        observer_throughput_final = [t for s in observer_throughput_final for t in s]

        benchmark_metrics = {
            "agent latency (seconds)": {},
            "agent throughput": {},
            "observer latency (seconds)": {},
            "observer throughput": {},
        }

        print(f"For batch size {self.batch_size}")
        print("\nAgent Latency - ", len(agent_latency_final))
        agent_latency_final = sorted(agent_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["agent latency (seconds)"][p] = round(v, 3)

        print("\nAgent Throughput - ", len(agent_throughput_final))
        agent_throughput_final = sorted(agent_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["agent throughput"][p] = int(v)

        print("\nObserver Latency - ", len(observer_latency_final))
        observer_latency_final = sorted(observer_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["observer latency (seconds)"][p] = round(v, 3)

        print("\nObserver Throughput - ", len(observer_throughput_final))
        observer_throughput_final = sorted(observer_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["observer throughput"][p] = int(v)

        if queue:
            queue.put(benchmark_metrics)