1# Owner(s): ["module: unknown"] 2import unittest 3from dataclasses import dataclass 4from typing import Any, Callable, cast, Tuple, Union 5 6import torch 7from torch import nn, optim 8from torch._subclasses.fake_tensor import FakeTensorMode 9from torch.distributed._tools.runtime_estimator import RuntimeEstimator 10from torch.testing._internal.common_cuda import TEST_CUDA 11from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 12from torch.testing._internal.distributed._tensor.common_dtensor import ( 13 ModelArgs, 14 Transformer, 15) 16 17 18@dataclass 19class ConvArgs: 20 image_size: int 21 num_classes: int 22 23 24class SimpleCNN(nn.Module): 25 def __init__(self, conv_args: ConvArgs): 26 super().__init__() 27 image_size = conv_args.image_size 28 num_classes = conv_args.num_classes 29 self.image_size = image_size 30 self.conv1 = nn.Conv2d(3, 32, kernel_size=5) 31 self.pool = nn.MaxPool2d(2, 2) 32 self.conv2 = nn.Conv2d(32, 64, kernel_size=5) 33 self.conv3 = nn.Conv2d(64, 128, kernel_size=3) 34 self.conv4 = nn.Conv2d(128, 256, kernel_size=3) 35 self.fc1_size = self._calculate_fc1_size() 36 self.fc1 = nn.Linear(self.fc1_size, 512) 37 self.fc2 = nn.Linear(512, 256) 38 self.fc3 = nn.Linear(256, num_classes) 39 40 def _calculate_fc1_size(self): 41 size = self.image_size 42 size = (size - 5 + 1) // 2 # conv1 and pool 43 size = (size - 5 + 1) // 2 # conv2 and pool 44 size = size - 3 + 1 # conv3 45 size = (size - 3 + 1) // 2 # conv4 and pool 46 return 512 * size * size 47 48 def forward(self, x): 49 x = self.pool(nn.functional.relu(self.conv1(x))) 50 x = self.pool(nn.functional.relu(self.conv2(x))) 51 x = nn.functional.relu(self.conv3(x)) 52 x = self.pool(nn.functional.relu(self.conv4(x))) 53 x = x.view(-1, self.fc1_size) 54 x = nn.functional.relu(self.fc1(x)) 55 x = nn.functional.relu(self.fc2(x)) 56 x = self.fc3(x) 57 return x 58 59 60class TestRuntimeEstimator(TestCase): 61 def _train_step( 62 self, 63 model: nn.Module, 64 optimizer: optim.Optimizer, 65 inp: torch.Tensor, 66 ): 67 out = model(inp) 68 loss = out.sum() 69 loss.backward() 70 optimizer.step() 71 optimizer.zero_grad() 72 73 def _measure_actual_cuda_time( 74 self, 75 func: Callable, 76 args: Tuple[Any, ...], 77 ) -> float: 78 warmup_iters, actual_iters = 2, 5 79 start_event = torch.cuda.Event(enable_timing=True) 80 end_event = torch.cuda.Event(enable_timing=True) 81 for _ in range(warmup_iters): 82 func(*args) 83 start_event.record() 84 for _ in range(actual_iters): 85 func(*args) 86 end_event.record() 87 torch.cuda.synchronize() 88 measured_time = start_event.elapsed_time(end_event) / actual_iters 89 return measured_time 90 91 def _runtime_estimate( 92 self, 93 estimate_mode: str, 94 func: Callable, 95 args: Tuple[Any, ...], 96 ) -> float: 97 # Optimizer init step 98 func(*args) 99 runtime_estimator = RuntimeEstimator() 100 with runtime_estimator(estimate_mode_type=estimate_mode): 101 func(*args) 102 return runtime_estimator.total_runtime 103 104 def _init_model_and_args( 105 self, 106 model_type: str, 107 model_args: Union[ConvArgs, ModelArgs], 108 bsz: int, 109 ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]: 110 dev = torch.cuda.current_device() 111 if model_type == "Transformer": 112 model_args = cast(ModelArgs, model_args) 113 with torch.device(dev): 114 model = Transformer(model_args) 115 optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True) 116 inp = torch.randint( 117 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev 118 ) 119 elif model_type == "CNN": 120 model_args = cast(ConvArgs, model_args) 121 with torch.device(dev): 122 model = SimpleCNN(model_args) 123 optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True) 124 inp = torch.randn( 125 bsz, 3, model_args.image_size, model_args.image_size, device=dev 126 ) 127 else: 128 raise NotImplementedError("Only Transformer and CNN is supported") 129 return (model, optimizer, inp) 130 131 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") 132 @unittest.skipIf(not TEST_CUDA, "CUDA not available") 133 def test_transformer_runtime( 134 self, 135 ): 136 """Runs a basic GPT-2 model""" 137 vocab_size = 8192 138 bsz, seq_len = 8, 1024 139 model_args = ModelArgs( 140 n_layers=4, 141 n_heads=12, 142 vocab_size=vocab_size, 143 max_seq_len=seq_len, 144 dim=768, 145 dropout_p=0.1, 146 ) 147 148 args = self._init_model_and_args("Transformer", model_args, bsz) 149 actual_runtime = self._measure_actual_cuda_time(self._train_step, args) 150 with FakeTensorMode(): 151 fake_args = self._init_model_and_args("Transformer", model_args, bsz) 152 benchmark_estimate = self._runtime_estimate( 153 "operator-level-benchmark", self._train_step, fake_args 154 ) 155 roofline_estimate = self._runtime_estimate( 156 "operator-level-cost-model", self._train_step, fake_args 157 ) 158 benchmark_accuracy = actual_runtime / benchmark_estimate 159 roofline_accuracy = actual_runtime / roofline_estimate 160 print( 161 f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" 162 f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" 163 ) 164 self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) 165 self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) 166 167 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") 168 @unittest.skipIf(not TEST_CUDA, "CUDA not available") 169 def test_conv_model_runtime( 170 self, 171 ): 172 """Runs a simple CNN model""" 173 num_classes = 100 174 bsz, img_sz = 256, 128 175 model_args = ConvArgs(img_sz, num_classes) 176 args = self._init_model_and_args("CNN", model_args, bsz) 177 actual_runtime = self._measure_actual_cuda_time(self._train_step, args) 178 with FakeTensorMode(): 179 fake_args = self._init_model_and_args("CNN", model_args, bsz) 180 benchmark_estimate = self._runtime_estimate( 181 "operator-level-benchmark", self._train_step, fake_args 182 ) 183 roofline_estimate = self._runtime_estimate( 184 "operator-level-cost-model", self._train_step, fake_args 185 ) 186 benchmark_accuracy = actual_runtime / benchmark_estimate 187 roofline_accuracy = actual_runtime / roofline_estimate 188 print( 189 f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n" 190 f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" 191 ) 192 self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) 193 self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) 194 195 196if __name__ == "__main__": 197 run_tests() 198