# mypy: allow-untyped-defs
import torchvision

import torch
from torch.distributed._tools import MemoryTracker


def run_one_model(net: torch.nn.Module, input: torch.Tensor):
    """Run one training iteration of ``net`` under a ``MemoryTracker`` and report it.

    The model and input are moved to the current CUDA device. Then a single
    forward/backward pass runs while the tracker monitors ``net``. Afterwards
    the memory stats summary is printed and operator-level memory traces are
    plotted.

    Args:
        net: Module to profile; moved to CUDA in place via ``net.cuda()``.
        input: Batch fed to ``net``; a CUDA copy is used for the iteration.
            If ``net`` returns a dict, its ``"out"`` entry is treated as the output.
    """
    net.cuda()
    input = input.cuda()

    # Create the memory tracker and start monitoring before the iteration begins.
    mem_tracker = MemoryTracker()
    mem_tracker.start_monitor(net)

    try:
        # Run one training iteration.
        net.zero_grad(set_to_none=True)
        loss = net(input)
        if isinstance(loss, dict):
            # Models that return a dict of outputs are reduced via their "out" entry.
            loss = loss["out"]
        loss.sum().backward()
        net.zero_grad(set_to_none=True)
    finally:
        # Always stop monitoring, even if the iteration raised, so the tracker's
        # instrumentation does not outlive this call.
        mem_tracker.stop()

    # Print the memory stats summary.
    mem_tracker.summary()
    # Plot the memory traces at operator level.
    mem_tracker.show_traces()


if __name__ == "__main__":
    # Guarded so that importing this module does not build a model or touch CUDA.
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device="cuda"),
    )