1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
from collections import namedtuple
from typing import Callable, Optional
import lightning.pytorch as pl
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
# One dataloader batch. `inputs` is unpacked positionally into the model call
# and `labels` is unpacked positionally into the loss function call.
Batch = namedtuple("Batch", "inputs labels")
class SSLPretrainModule(pl.LightningModule):
    """Lightning module that runs a model and a loss function on `Batch` tuples.

    The training step rescales the loss by ``world_size / total_frames`` where
    the frame counts are gathered from every process; the validation step
    returns the loss unchanged. Metric logging is delegated to the
    overridable :meth:`log_metrics` hook.

    Args:
        model (nn.Module): Model called as ``model(*batch.inputs)``.
        loss_fn (Callable): Called as ``loss_fn(*model_output, *batch.labels)``;
            must return a ``(loss, num_frame)`` pair.
        optimizer (Optimizer): Stored on ``self.optimizer``.
        lr_scheduler (_LRScheduler or None, optional): Stored on
            ``self.lr_scheduler``. (Default: ``None``)

    Note:
        ``optimizer`` and ``lr_scheduler`` are only stored here; no method in
        this class consumes them.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Callable,
        optimizer: Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def log_metrics(self, batch: Batch, output, loss, step_type):
        """Log useful information to TensorBoard. Users are expected to
        write their customized `log_metrics` method to log information
        such as loss values, metric scores, etc.

        The default implementation does nothing.

        Args:
            batch (Batch): Batch tuple from the dataloader.
            output: Output generated by the model.
            loss (Tensor): Loss value returned by the loss function, before
                any frame-count normalization.
            step_type (str): Type of step. Choices are "train", "val", and "test".
        """
        pass

    def log_metric(self, batch: Batch, output, loss, step_type):
        """Forward to :meth:`log_metrics`.

        The step methods call the hook under this name while the documented
        override point is `log_metrics`. Delegating here makes both spellings
        work: subclasses overriding `log_metrics` are reached through this
        method, and subclasses overriding `log_metric` keep working as before.

        Args:
            batch (Batch): Batch tuple from the dataloader.
            output: Output generated by the model.
            loss (Tensor): Loss value returned by the loss function.
            step_type (str): Type of step. Choices are "train", "val", and "test".
        """
        self.log_metrics(batch, output, loss, step_type)

    def training_step(self, batch: Batch, batch_idx):
        """Run one training step and return the frame-normalized loss.

        Args:
            batch (Batch): Batch tuple from the dataloader.
            batch_idx: Index of the batch (unused).

        Returns:
            Tensor: ``loss * world_size / sum(num_frame over all processes)``.
        """
        out = self.model(*batch.inputs)
        loss, num_frame = self.loss_fn(*out, *batch.labels)
        self.log_metric(batch, out, loss, "train")

        # normalize the loss based on the sum of num_frame across all GPUs
        num_frames = self.all_gather(num_frame)
        self.log(
            "Gathered number of frames",
            num_frames.float().sum(),
            on_step=True,
            on_epoch=True,
        )
        # Take the world size from the trainer rather than from the gathered
        # tensor's leading dimension: `all_gather` adds no leading dimension
        # when only one process is running, so `num_frames.size(0)` would be
        # wrong (or raise on a 0-dim tensor) in that case.
        world_size = self.trainer.world_size
        # Out-of-place multiply so the tensor already handed to the logging
        # hook is not mutated afterwards.
        # NOTE(review): the world_size factor presumably offsets gradient
        # averaging across processes - confirm against the training strategy.
        return loss * (world_size / num_frames.sum())

    def validation_step(self, batch, batch_idx):
        """Run one validation step and return the un-normalized loss.

        Args:
            batch (Batch): Batch tuple from the dataloader.
            batch_idx: Index of the batch (unused).

        Returns:
            Tensor: Loss value returned by the loss function.
        """
        out = self.model(*batch.inputs)
        loss, _ = self.loss_fn(*out, *batch.labels)
        self.log_metric(batch, out, loss, "val")
        return loss
|