1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import unittest 6from itertools import product 7 8import torch 9import torch.nn as nn 10import torch.nn.functional as F 11from torch.testing import FileCheck 12 13 14try: 15 import torchvision 16 17 HAS_TORCHVISION = True 18except ImportError: 19 HAS_TORCHVISION = False 20skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 21 22# Make the helper files in test/ importable 23pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 24sys.path.append(pytorch_test_dir) 25from torch.testing._internal.jit_utils import JitTestCase 26 27 28if __name__ == "__main__": 29 raise RuntimeError( 30 "This test file is not meant to be run directly, use:\n\n" 31 "\tpython test/test_jit.py TESTNAME\n\n" 32 "instead." 33 ) 34 35activations = [ 36 F.celu, 37 F.elu, 38 F.hardsigmoid, 39 F.hardswish, 40 F.hardtanh, 41 F.leaky_relu, 42 F.relu, 43 F.relu6, 44 F.rrelu, 45 F.selu, 46 F.silu, 47] 48 49 50class TestFunctionalToInplaceActivation(JitTestCase): 51 def test_check_no_type_promotion(self): 52 dtypes = [ 53 torch.bool, 54 torch.int8, 55 torch.int16, 56 torch.int32, 57 torch.int64, 58 torch.float32, 59 torch.float64, 60 ] 61 # restore_mutation.h contains a mapping from activation operators 62 # to whether they allow type conversion. Use this checking to 63 # guard the mapping, and if any later change breaks the assumption 64 # we need to update the mapping correspondingly. 65 for activation, dtype in product(activations, dtypes): 66 inp = torch.normal(0, 5, size=(4, 4)).to(dtype) 67 try: 68 out = activation(inp) 69 self.assertEqual(dtype, out.dtype) 70 except RuntimeError: 71 # Skip the not implemented error 72 pass 73 74 def test_functional_to_inplace_activation(self): 75 for activation in activations: 76 77 def test_basic(x): 78 y = x + 1 79 z = activation(y) 80 return z 81 82 fn = torch.jit.script(test_basic) 83 self.run_pass("inline", fn.graph) 84 self.run_pass("constant_propagation", fn.graph) 85 FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph) 86 self.run_pass("functional_to_inplace_activation", fn.graph) 87 FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph) 88 FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph) 89 inp = torch.rand([2, 2]) 90 self.assertEqual(fn(inp), test_basic(inp)) 91 92 def test_no_functional_to_inplace(self): 93 # inplace conversion should not happen because sigmoid may 94 # perform type conversion 95 def test1(): 96 y = torch.ones([2, 2]) 97 z = torch.sigmoid(y) 98 return z 99 100 fn = torch.jit.script(test1) 101 self.run_pass("functional_to_inplace_activation", fn.graph) 102 FileCheck().check_not("aten::sigmoid_").run(fn.graph) 103 104 # inplace conversion should not happen because y is alias 105 # the input x 106 def test2(x): 107 y = x[0] 108 z = torch.relu(y) 109 return z 110 111 fn = torch.jit.script(test2) 112 self.run_pass("functional_to_inplace_activation", fn.graph) 113 FileCheck().check_not("aten::relu_").run(fn.graph) 114 115 # inplace conversion should not happen because self.x is 116 # at the global scope 117 class Test3(nn.Module): 118 def __init__(self, x): 119 super().__init__() 120 self.x = x 121 122 def forward(self): 123 y = torch.relu(self.x) 124 return y 125 126 fn = torch.jit.script(Test3(torch.rand([2, 2])).eval()) 127 self.run_pass("functional_to_inplace_activation", fn.graph) 128 FileCheck().check_not("aten::relu_").run(fn.graph) 129 130 @skipIfNoTorchVision 131 def test_resnet18_correctness(self): 132 model = torchvision.models.resnet18() 133 frozen_model = torch.jit.freeze(torch.jit.script(model.eval())) 134 ( 135 N, 136 C, 137 H, 138 W, 139 ) = ( 140 10, 141 3, 142 224, 143 224, 144 ) 145 inp = torch.randn(N, C, H, W) 146 self.run_pass("functional_to_inplace_activation", frozen_model.graph) 147 self.assertEqual(model(inp), frozen_model(inp)) 148 149 150class TestInplaceToFunctionalActivation(JitTestCase): 151 def test_inplace_to_functional_activation(self): 152 for activation in activations: 153 154 def test_basic(x): 155 y = x + 1 156 activation(y, inplace=True) 157 return y 158 159 fn = torch.jit.script(test_basic) 160 self.run_pass("inline", fn.graph) 161 self.run_pass("constant_propagation", fn.graph) 162 FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph) 163 self.run_pass("inplace_to_functional_activation", fn.graph) 164 FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph) 165 FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph) 166 167 for activation in [ 168 torch.relu_, 169 torch.sigmoid_, 170 torch.tanh_, 171 ]: 172 173 def test_basic(x): 174 y = x + 1 175 activation(y) 176 return y 177 178 fn = torch.jit.script(test_basic) 179 self.run_pass("inline", fn.graph) 180 self.run_pass("constant_propagation", fn.graph) 181 FileCheck().check(f"aten::{activation.__name__}").run(fn.graph) 182 self.run_pass("inplace_to_functional_activation", fn.graph) 183 FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph) 184 FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph) 185 186 inp = torch.rand([2, 2]) 187 self.assertEqual(fn(inp), test_basic(inp)) 188 189 @skipIfNoTorchVision 190 def test_resnet18_correctness(self): 191 model = torchvision.models.resnet18() 192 frozen_model = torch.jit.freeze(torch.jit.script(model.eval())) 193 ( 194 N, 195 C, 196 H, 197 W, 198 ) = ( 199 10, 200 3, 201 224, 202 224, 203 ) 204 inp = torch.randn(N, C, H, W) 205 self.run_pass("inplace_to_functional_activation", frozen_model.graph) 206 self.assertEqual(model(inp), frozen_model(inp)) 207