1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
def basic_iteration_step(self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch):
    r"""
    Perform one training iteration and record timing markers around it.

    The step advances the hook state, zeroes the gradients, runs the
    forward pass and loss computation, runs the backward pass, and applies
    the optimizer update. The batch, forward and backward phases are each
    bracketed by a ``record_*_start`` / ``record_*_end`` call on ``self``
    using the key produced by ``self.epoch_key(epoch, index)``.

    Args:
        ddp_model (nn.Module): distributed data parallel model
        criterion (nn.Module): loss function; called as
            ``criterion(ddp_model(batch[0]), batch[1])``
        optimizer (optim.Optimizer): updates model parameters
        hook_state (object): ddp communication hook state object; its
            ``next_batch()`` is called once at the start of the step
        epoch (int): index of pass through the data
        index (int): zero-based iteration number (iteration number - 1)
        batch (list): training examples; ``batch[0]`` is the model input
            and ``batch[1]`` is the second argument given to ``criterion``
    """
    # Build the metric key once instead of once per marker, so that key
    # construction does not run inside the timed forward/backward regions.
    # NOTE(review): assumes epoch_key is a pure function of (epoch, index)
    # -- confirm against its definition.
    key = self.epoch_key(epoch, index)

    hook_state.next_batch()
    self.record_batch_start(key)
    optimizer.zero_grad()

    self.record_forward_start(key)
    loss = criterion(ddp_model(batch[0]), batch[1])
    self.record_forward_end(key)

    self.record_backward_start(key)
    loss.backward()
    self.record_backward_end(key)

    optimizer.step()
    self.record_batch_end(key)
|