# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

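# Functional activations that have in-place variants; both conversion-pass tests
# below iterate over this list.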
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]


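# Tests for the functional_to_inplace_activation pass, which rewrites functional
# activation calls (aten::op) into their in-place counterparts (aten::op_) when
# it is safe to do so.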
class TestFunctionalToInplaceActivation(JitTestCase):
    def test_check_no_type_promotion(self):
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators to
        # whether they allow type conversion. This check guards that mapping:
        # if a later change breaks the assumption, the mapping needs to be
        # updated accordingly.
        for activation, dtype in product(activations, dtypes):
            inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
            try:
                out = activation(inp)
                self.assertEqual(dtype, out.dtype)
            except RuntimeError:
                # Skip dtypes for which the activation is not implemented
                pass

    def test_functional_to_inplace_activation(self):
        for activation in activations:

            def test_basic(x):
                y = x + 1
                z = activation(y)
                return z

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
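            # After the pass, the functional call should be rewritten to its in-place variant.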
            self.run_pass("functional_to_inplace_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        # In-place conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # In-place conversion should not happen because y aliases
        # the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # In-place conversion should not happen because self.x lives
        # outside the graph (it is a module attribute)
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
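        # Running the pass on a frozen ResNet-18 should leave the model's outputs unchanged.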
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("functional_to_inplace_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))


class TestInplaceToFunctionalActivation(JitTestCase):
    def test_inplace_to_functional_activation(self):
        for activation in activations:

            def test_basic(x):
                y = x + 1
                activation(y, inplace=True)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
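            # After the pass, the in-place call should be rewritten to its functional form.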
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)

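        # These in-place ops are exposed directly on torch (there is no
        # inplace= keyword), so they are called as-is.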
        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:

            def test_basic(x):
                y = x + 1
                activation(y)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)

            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
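        # The reverse pass should likewise leave the frozen model's outputs unchanged.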
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("inplace_to_functional_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))