# mypy: allow-untyped-defs
import torchvision
import torch
from torch.distributed._tools import MemoryTracker
def run_one_model(net: torch.nn.Module, input: torch.Tensor) -> None:
    """Run one training iteration of ``net`` on ``input`` while tracking
    GPU memory usage at operator level.

    The model and input are moved to the default CUDA device, a single
    forward/backward pass is executed under a ``MemoryTracker``, then the
    collected memory statistics are printed and the traces plotted.

    Args:
        net: Model to profile; moved to CUDA in place.
        input: Input batch; copied to CUDA.
    """
    net.cuda()
    input = input.cuda()

    # Create the memory tracker.
    mem_tracker = MemoryTracker()
    # start_monitor must be called before the training iteration starts.
    mem_tracker.start_monitor(net)

    # Run one training iteration. Use the keyword form consistently
    # (the original passed `True` positionally for set_to_none here).
    net.zero_grad(set_to_none=True)
    loss = net(input)
    # Some torchvision models (e.g. segmentation) return a dict; the
    # output tensor lives under the "out" key.
    if isinstance(loss, dict):
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # Stop monitoring after the training iteration ends.
    mem_tracker.stop()
    # Print the memory stats summary.
    mem_tracker.summary()
    # Plot the memory traces at operator level.
    mem_tracker.show_traces()
if __name__ == "__main__":
    # Guarded entry point: profile one training iteration of ResNet-34 on a
    # random batch. The guard prevents the CUDA run from firing on import.
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device="cuda"),
    )