# Owner(s): ["oncall: quantization"]

# torch
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase


class TestFusionPasses(QuantizationTestCase):
    def test_quantized_add_relu_fusion(self):
        class MAdd(torch.nn.Module):
            def forward(self, x, y):
                a = torch.ops.quantized.add(x, y, 1.0, 0)
                relu_out = torch.relu(a)
                return relu_out

        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(
            A, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        qB = torch.quantize_per_tensor(
            B, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )

        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # The graph must be inlined before running the fusion pass.
        # In this test it does not matter, since we call the quantized
        # ops directly, but when calling nn modules the graph has to be
        # inlined first.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check("quantized::add_relu").run(
            scripted_m.graph
        )
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

        class MAddOut(torch.nn.Module):
            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add_out + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # As above, the graph must be inlined before running the fusion
        # pass; it does not matter here since we call the ops directly.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_out").check(
            "quantized::add_relu_out"
        ).run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

        class MAddScalar(torch.nn.Module):
            def forward(self, x, y: float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add_scalar + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        # Note the trailing "(" in the check_not pattern: it keeps
        # "quantized::add_scalar" from matching the fused
        # "quantized::add_scalar_relu" op.
        FileCheck().check_not("aten::relu").check_not("quantized::add_scalar(").check(
            "quantized::add_scalar_relu"
        ).run(scripted_m.graph)
        output = scripted_m(qA, 3.0)
        self.assertEqual(ref_output, output)

        class MAddScalarOut(torch.nn.Module):
            def forward(self, x, y: float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add_scalar_out + relu fusion
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0, qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not(
            "quantized::add_scalar_out"
        ).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
        output = scripted_m(qA, 3.0, qC)
        self.assertEqual(ref_output, output)
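

# The trailing guard below is an assumption, mirroring the convention used by
# PyTorch quantization test files, which are meant to be run through
# test/test_quantization.py rather than invoked directly.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )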