# Owner(s): ["module: inductor"]
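"""
Tests for the AOTInductor packaging flow: each model is exported with
torch.export, AOT-compiled into a package, reloaded with load_package,
and its outputs are compared against eager execution.
"""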
import copy
import sys
import unittest

import torch
from torch._inductor import config
from torch._inductor.package import load_package
from torch._inductor.test_case import TestCase
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.triton_utils import HAS_CUDA


try:
    try:
        from .test_torchinductor import copy_tests
    except ImportError:
        from test_torchinductor import copy_tests
except (unittest.SkipTest, ImportError) as e:
    if __name__ == "__main__":
        sys.exit(0)
    raise


def compile(model, example_inputs, dynamic_shapes, options, device):
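    """Export ``model`` with torch.export, AOT-compile the exported module
    into a package, and load it back as a callable runnable on ``device``.
    """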
    ep = torch.export.export(
        model,
        example_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )
    gm = ep.module()
    package_path = torch._inductor.aot_compile(gm, example_inputs, options=options)  # type: ignore[arg-type]
    compiled_model = load_package(package_path, device)
    return compiled_model


def check_model(
    self: TestCase,
    model,
    example_inputs,
    options=None,
    dynamic_shapes=None,
    disable_constraint_solver=False,
    atol=None,
    rtol=None,
):
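    """Run ``model`` both eagerly and through the AOTInductor packaging path
    (enabled via the ``aot_inductor.package`` config patch below) on deep
    copies of ``example_inputs``, then assert the outputs match within
    ``atol``/``rtol``.
    """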
    with torch.no_grad(), config.patch(
        {
            "aot_inductor.package": True,
            # TODO: "aot_inductor.force_mmap_weights": True,
        }
    ):
        torch.manual_seed(0)
        model = model.to(self.device)
        ref_model = copy.deepcopy(model)
        ref_inputs = copy.deepcopy(example_inputs)
        expected = ref_model(*ref_inputs)

        torch.manual_seed(0)
        compiled_model = compile(
            model,
            example_inputs,
            dynamic_shapes,
            options,
            self.device,
        )

        actual = compiled_model(*example_inputs)

    self.assertEqual(actual, expected, atol=atol, rtol=rtol)


class AOTInductorTestsTemplate:
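    """Device-agnostic test bodies.

    Not a TestCase itself: copy_tests() below copies each test onto the
    device-specific TestCase subclasses, suffixed per device.
    """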
    def test_add(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_linear(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)


common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)


@unittest.skipIf(sys.platform == "darwin" or IS_FBCODE, "No CUDA on macOS; skipped in fbcode")
class AOTInductorTestPackagedABICompatibleCuda(TestCase):
    device = "cuda"
    check_model = check_model

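# copy_tests stamps every test from AOTInductorTestsTemplate onto the class
# above, appending the given suffix (e.g. test_add_packaged_abi_compatible_cuda).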
copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestPackagedABICompatibleCuda,
    "packaged_abi_compatible_cuda",
)


@unittest.skipIf(IS_FBCODE, "This is for OSS only")
class AOTInductorTestPackagedABICompatibleCpu(TestCase):
    device = "cpu"
    check_model = check_model


copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestPackagedABICompatibleCpu,
    "packaged_abi_compatible_cpu",
)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # cpp_extension is not available in fbcode, so only run these tests when
    # CUDA is available or on macOS
    if HAS_CUDA or sys.platform == "darwin":
        run_tests(needs="filelock")