# mypy: allow-untyped-defs
import torchvision

import torch
from torch.distributed._tools import MemoryTracker


def run_one_model(net: torch.nn.Module, input: torch.Tensor):
    net.cuda()
    input = input.cuda()

    # Create the memory tracker
    mem_tracker = MemoryTracker()
    # Start monitoring before the training iteration starts
    mem_tracker.start_monitor(net)

    # Run one training iteration
    net.zero_grad(True)
    loss = net(input)
    # Some torchvision models (e.g. the segmentation models) return a dict of
    # outputs; the main prediction is stored under the "out" key.
    if isinstance(loss, dict):
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # Stop monitoring after the training iteration ends
    mem_tracker.stop()
    # Print the memory stats summary
    mem_tracker.summary()
    # Plot the memory traces at operator level
    mem_tracker.show_traces()


run_one_model(torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda"))
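
# A sketch, not part of the original example: the isinstance(loss, dict) branch
# in run_one_model exists because torchvision's segmentation models return a
# dict of outputs keyed by "out" (and optionally "aux") instead of a single
# tensor. Assuming a CUDA GPU with enough free memory, the same tracker can be
# exercised on one of those dict-returning models:
run_one_model(
    torchvision.models.segmentation.fcn_resnet50(),
    torch.rand(8, 3, 224, 224, device="cuda"),
)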