import functools
import time
from abc import ABC, abstractmethod

import torch

from metrics.MetricsLogger import MetricsLogger


class TrainerBase(ABC):
BATCH_LEVEL_METRIC = "batch_level_metric"
BATCH_ALL = "batch_all"
FORWARD_METRIC = "forward_metric"
FORWARD_PASS = "forward_pass"
BACKWARD_METRIC = "backward_metric"
BACKWARD = "backward"

    def __init__(self, rank):
r"""
        Initializes the TrainerBase class.
Args:
rank (int): worker rank
"""
self.__metrics_logger = MetricsLogger(rank)

    @abstractmethod
def train(self):
r"""
        A method to be implemented by a child class that trains a neural network.
"""
return

    def record_start(self, type, key, name, cuda=True):
r"""
A method that records the start event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
name (str): description of the metric
cuda (bool): indicator to determine if this is a CUDA metric
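
        A sketch of typical usage (the group and key names below are
        illustrative, not fixed constants)::

            self.record_start("my_group", "epoch_0_batch_0", "fwd_bwd")
            ...  # timed region
            self.record_end("my_group", "epoch_0_batch_0")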
"""
self.__metrics_logger.record_start(
type,
key,
name,
cuda
)

    def record_end(self, type, key):
r"""
A method that records the end event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
type,
key
)

    def record_batch_start(self, key, cuda=True):
r"""
A helper method that records a batch metric for the
given key. A user should call this at the start of an
iteration step during training.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
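
        A sketch of typical usage inside a training loop (the key format is
        illustrative; any string unique within the group works, and the
        forward/backward helpers below follow the same start/end pattern)::

            key = f"{epoch},{index}"
            self.record_batch_start(key)
            ...  # forward, backward, optimizer step
            self.record_batch_end(key)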
"""
self.__metrics_logger.record_start(
self.BATCH_LEVEL_METRIC,
key,
self.BATCH_ALL,
cuda
)

    def record_batch_end(self, key):
r"""
A helper method that records a batch metric for the
given key. A user should call this at the end of an
iteration step during training.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.BATCH_LEVEL_METRIC,
key
)

    def record_forward_start(self, key, cuda=True):
r"""
A helper method that records a forward metric
for the given key. A user should call this before
their neural network forward.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.FORWARD_METRIC,
key,
self.FORWARD_PASS,
cuda
)

    def record_forward_end(self, key):
r"""
A helper method that records a forward metric
for the given key. A user should call this after their
neural network forward.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.FORWARD_METRIC,
key
)

    def record_backward_start(self, key, cuda=True):
r"""
A helper method that records a backward metric
for the given key. A user should call this before
their .backward() call.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.BACKWARD_METRIC,
key,
self.BACKWARD,
cuda
)

    def record_backward_end(self, key):
r"""
A helper method that records a backward metric
for the given key. A user should call this after
.backward().
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.BACKWARD_METRIC,
key
)

    @staticmethod
def methodmetric(name, type="method_metric", cuda=True):
r"""
A decorator that records a metric for the decorated method.
Args:
name (str): description of the metric
type (str): group id for metric
cuda (bool): indicator to determine if this is a CUDA metric
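
        A sketch of typical usage (the subclass and method name are
        hypothetical)::

            class MyTrainer(TrainerBase):
                @TrainerBase.methodmetric("calc_loss_metric")
                def calc_loss(self, out, target):
                    ...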
"""
def decorator(function):
@functools.wraps(function)
            def wrapper(self, *args, **kwargs):
                # the wall-clock timestamp doubles as a unique key for
                # this invocation within the metric group
                key = time.time()
                self.__metrics_logger.record_start(type, key, name, cuda)
                result = function(self, *args, **kwargs)
                self.__metrics_logger.record_end(type, key)
                return result
return wrapper
return decorator

    def get_metrics(self):
r"""
A method that returns metrics captured by the __metrics_logger.
"""
return self.__metrics_logger.get_processed_metrics()

    def clear_metrics(self):
r"""
A method that clears __metrics_logger recorded metrics.
"""
return self.__metrics_logger.clear_metrics()


class DdpTrainer(TrainerBase):
def __init__(
self,
process_group,
use_cuda_rpc,
server_rref,
backend,
epochs,
preprocess_data,
create_criterion,
create_ddp_model,
hook_state_class,
hook,
iteration_step
):
r"""
        A trainer that implements a DDP training algorithm using a simple
        communication hook that performs allreduce through the process_group
        implementation.
Args:
process_group (ProcessGroup): distributed process group
use_cuda_rpc (bool): indicator for CUDA RPC
server_rref (RRef): remote reference to the server
backend (str): distributed communication backend
epochs (int): epoch count for training
preprocess_data (function): preprocesses data passed
to the trainer before starting training
create_criterion (function): creates a criterion to calculate loss
create_ddp_model (function): creates a ddp model for the trainer
            hook_state_class (class): class used to keep track of state
                during training
            hook (function): ddp communication hook
            iteration_step (function): performs one step of training
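
        A sketch of typical construction (all ``my_*`` callables and
        ``MyHookState`` are hypothetical user-provided objects)::

            trainer = DdpTrainer(
                process_group, False, server_rref, "gloo", 10,
                my_preprocess_data, my_create_criterion, my_create_ddp_model,
                MyHookState, my_hook, my_iteration_step,
            )
            trainer.train(model, data)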
"""
super().__init__(process_group.rank())
self.process_group = process_group
self.use_cuda_rpc = use_cuda_rpc
self.server_rref = server_rref
self.backend = backend
self.epochs = epochs
self.preprocess_data = preprocess_data
self.create_criterion = create_criterion
self.create_ddp_model = create_ddp_model
self.hook_state_class = hook_state_class
self.hook = hook
self.iteration_step = iteration_step
self.rank = process_group.rank()
self.trainer_count = process_group.size()

    def epoch_key(self, epoch, index):
r"""
A method that returns an encoded key that represents the current epoch and
iteration index.
Args:
epoch (int): epoch index
index (int): iteration index
"""
return f"{epoch},{index}"

    def train(self, model, data):
r"""
A method that implements the training algorithm.
Args:
model (nn.Module): neural network model
data (list): training examples
"""
model = model.cuda(self.rank)
data = self.preprocess_data(self.rank, data)
criterion = self.create_criterion(self.rank)
ddp_model, hook_state = self.create_ddp_model(
self, self.rank, model, self.process_group, self.hook_state_class, self.hook
)
optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
for epoch in range(self.epochs):
if epoch % 5 == 0 and self.rank == 0:
print(f"train epoch={epoch}")
for index, batch in enumerate(data):
self.iteration_step(
self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
)
torch.cuda.synchronize(self.rank)
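

# A minimal sketch (an assumption, not part of the original API) of an
# ``iteration_step`` callable compatible with DdpTrainer.train above.
# The batch layout ``(inputs, targets)`` is assumed; ``hook_state`` is
# threaded through for custom communication hooks and is unused here.
def example_iteration_step(
    trainer, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    key = trainer.epoch_key(epoch, index)
    trainer.record_batch_start(key)
    optimizer.zero_grad()
    trainer.record_forward_start(key)
    out = ddp_model(batch[0])  # assumed: batch[0] holds the inputs
    trainer.record_forward_end(key)
    loss = criterion(out, batch[1])  # assumed: batch[1] holds the targets
    trainer.record_backward_start(key)
    loss.backward()
    trainer.record_backward_end(key)
    optimizer.step()
    trainer.record_batch_end(key)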