1# Owner(s): ["module: dynamo"] 2 3import unittest 4 5import torch 6import torch._dynamo as torchdynamo 7from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase 8 9 10try: 11 import tabulate # noqa: F401 # type: ignore[import] 12 13 from torch.utils.benchmark.utils.compile import bench_all 14 15 HAS_TABULATE = True 16except ImportError: 17 HAS_TABULATE = False 18 19 20@unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 21@unittest.skipIf(not HAS_TABULATE, "tabulate not available") 22class TestCompileBenchmarkUtil(TestCase): 23 def test_training_and_inference(self): 24 class ToyModel(torch.nn.Module): 25 def __init__(self) -> None: 26 super().__init__() 27 self.weight = torch.nn.Parameter(torch.Tensor(2, 2)) 28 29 def forward(self, x): 30 return x * self.weight 31 32 torchdynamo.reset() 33 model = ToyModel().cuda() 34 35 inference_table = bench_all(model, torch.ones(1024, 2, 2).cuda(), 5) 36 self.assertTrue( 37 "Inference" in inference_table 38 and "Eager" in inference_table 39 and "-" in inference_table 40 ) 41 42 training_table = bench_all( 43 model, 44 torch.ones(1024, 2, 2).cuda(), 45 5, 46 optimizer=torch.optim.SGD(model.parameters(), lr=0.01), 47 ) 48 self.assertTrue( 49 "Train" in training_table 50 and "Eager" in training_table 51 and "-" in training_table 52 ) 53 54 55if __name__ == "__main__": 56 run_tests() 57