# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Tuple, Union

import torch
from torch import nn, optim
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


@dataclass
class ConvArgs:
    image_size: int
    num_classes: int


class SimpleCNN(nn.Module):
    def __init__(self, conv_args: ConvArgs):
        super().__init__()
        image_size = conv_args.image_size
        num_classes = conv_args.num_classes
        self.image_size = image_size
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3)
        self.fc1_size = self._calculate_fc1_size()
        self.fc1 = nn.Linear(self.fc1_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

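    # Flattened feature size after the conv/pool stack. For the input used in
    # test_conv_model_runtime (image_size=128), the spatial size works out to
    # 12x12, so fc1 consumes 256 * 12 * 12 features from conv4's 256 channels.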
    def _calculate_fc1_size(self):
        size = self.image_size
        size = (size - 5 + 1) // 2  # conv1 and pool
        size = (size - 5 + 1) // 2  # conv2 and pool
        size = size - 3 + 1  # conv3
        size = (size - 3 + 1) // 2  # conv4 and pool
        return 256 * size * size  # 256 output channels from conv4

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = nn.functional.relu(self.conv3(x))
        x = self.pool(nn.functional.relu(self.conv4(x)))
        x = x.view(-1, self.fc1_size)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


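# Compares the measured CUDA wall-clock time of a full train step against
# RuntimeEstimator's two estimation modes, which replay the same step on
# fake tensors.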
class TestRuntimeEstimator(TestCase):
    def _train_step(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        inp: torch.Tensor,
    ):
        out = model(inp)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

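    # Measures the average GPU time of one call to `func` using CUDA events,
    # after a few warmup iterations to exclude one-time setup costs.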
    def _measure_actual_cuda_time(
        self,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        warmup_iters, actual_iters = 2, 5
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for _ in range(warmup_iters):
            func(*args)
        start_event.record()
        for _ in range(actual_iters):
            func(*args)
        end_event.record()
        torch.cuda.synchronize()
        measured_time = start_event.elapsed_time(end_event) / actual_iters
        return measured_time

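    # Estimates the runtime of one call to `func` under the given
    # RuntimeEstimator mode ("operator-level-benchmark" or
    # "operator-level-cost-model"). The preliminary call initializes optimizer
    # state so the estimated step matches a steady-state training step.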
    def _runtime_estimate(
        self,
        estimate_mode: str,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        # Optimizer init step
        func(*args)
        runtime_estimator = RuntimeEstimator()
        with runtime_estimator(estimate_mode_type=estimate_mode):
            func(*args)
        return runtime_estimator.total_runtime

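    # Builds the model, optimizer, and a random input batch on the current CUDA
    # device; when called under FakeTensorMode, everything is created as fake
    # tensors, so no real GPU memory is allocated.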
    def _init_model_and_args(
        self,
        model_type: str,
        model_args: Union[ConvArgs, ModelArgs],
        bsz: int,
    ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
        dev = torch.cuda.current_device()
        if model_type == "Transformer":
            model_args = cast(ModelArgs, model_args)
            with torch.device(dev):
                model = Transformer(model_args)
            optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randint(
                0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev
            )
        elif model_type == "CNN":
            model_args = cast(ConvArgs, model_args)
            with torch.device(dev):
                model = SimpleCNN(model_args)
            optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randn(
                bsz, 3, model_args.image_size, model_args.image_size, device=dev
            )
        else:
            raise NotImplementedError("Only Transformer and CNN are supported")
        return (model, optimizer, inp)

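    # Each test computes accuracy as actual / estimated runtime and asserts it
    # stays close to 1.0; the roofline cost model is given a wider tolerance
    # than the operator-level benchmark mode.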
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_transformer_runtime(
        self,
    ):
        """Runs a basic GPT-2 model"""
        vocab_size = 8192
        bsz, seq_len = 8, 1024
        model_args = ModelArgs(
            n_layers=4,
            n_heads=12,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dim=768,
            dropout_p=0.1,
        )

        args = self._init_model_and_args("Transformer", model_args, bsz)
        actual_runtime = self._measure_actual_cuda_time(self._train_step, args)
        with FakeTensorMode():
            fake_args = self._init_model_and_args("Transformer", model_args, bsz)
            benchmark_estimate = self._runtime_estimate(
                "operator-level-benchmark", self._train_step, fake_args
            )
            roofline_estimate = self._runtime_estimate(
                "operator-level-cost-model", self._train_step, fake_args
            )
        benchmark_accuracy = actual_runtime / benchmark_estimate
        roofline_accuracy = actual_runtime / roofline_estimate
        print(
            f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n"
            f"Actual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}"
        )
        self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
        self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_conv_model_runtime(
        self,
    ):
        """Runs a simple CNN model"""
        num_classes = 100
        bsz, img_sz = 256, 128
        model_args = ConvArgs(img_sz, num_classes)
        args = self._init_model_and_args("CNN", model_args, bsz)
        actual_runtime = self._measure_actual_cuda_time(self._train_step, args)
        with FakeTensorMode():
            fake_args = self._init_model_and_args("CNN", model_args, bsz)
            benchmark_estimate = self._runtime_estimate(
                "operator-level-benchmark", self._train_step, fake_args
            )
            roofline_estimate = self._runtime_estimate(
                "operator-level-cost-model", self._train_step, fake_args
            )
        benchmark_accuracy = actual_runtime / benchmark_estimate
        roofline_accuracy = actual_runtime / roofline_estimate
        print(
            f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n"
            f"Actual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}"
        )
        self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
        self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4)


if __name__ == "__main__":
    run_tests()