1# Owner(s): ["oncall: mobile"] 2 3import torch 4import torch.ao.nn.quantized as nnq 5import torch.nn as nn 6import torch.utils.bundled_inputs 7from torch.ao.quantization import default_qconfig, float_qparams_weight_only_qconfig 8 9# graph mode quantization based on fx 10from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx 11from torch.testing._internal.common_quantization import ( 12 LinearModelWithSubmodule, 13 NodeSpec as ns, 14 QuantizationLiteTestCase, 15) 16 17 18class TestLiteFuseFx(QuantizationLiteTestCase): 19 # Tests from: 20 # ./caffe2/test/quantization/fx/test_quantize_fx.py 21 22 def test_embedding(self): 23 class M(torch.nn.Module): 24 def __init__(self) -> None: 25 super().__init__() 26 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) 27 28 def forward(self, indices): 29 return self.emb(indices) 30 31 model = M().eval() 32 indices = torch.randint(low=0, high=10, size=(20,)) 33 34 quantized_node = ns.call_module(nnq.Embedding) 35 configs = [ 36 (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)), 37 (None, ns.call_module(nn.Embedding)), 38 (default_qconfig, ns.call_module(nn.Embedding)), 39 ] 40 41 for qconfig, node in configs: 42 qconfig_dict = {"": qconfig} 43 m = prepare_fx( 44 model, 45 qconfig_dict, 46 example_inputs=torch.randint(low=0, high=10, size=(20,)), 47 ) 48 m = convert_fx(m) 49 self._compare_script_and_mobile(m, input=indices) 50 51 def test_conv2d(self): 52 class M(torch.nn.Module): 53 def __init__(self) -> None: 54 super().__init__() 55 self.conv1 = nn.Conv2d(1, 1, 1) 56 self.conv2 = nn.Conv2d(1, 1, 1) 57 58 def forward(self, x): 59 x = self.conv1(x) 60 x = self.conv2(x) 61 return x 62 63 m = M().eval() 64 qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]} 65 m = prepare_fx(m, qconfig_dict, example_inputs=torch.randn(1, 1, 1, 1)) 66 data = torch.randn(1, 1, 1, 1) 67 m = convert_fx(m) 68 # first conv is quantized, second conv is not quantized 69 self._compare_script_and_mobile(m, input=data) 70 71 def test_submodule(self): 72 # test quantizing complete module, submodule and linear layer 73 configs = [ 74 {}, 75 {"module_name": [("subm", None)]}, 76 {"module_name": [("fc", None)]}, 77 ] 78 for config in configs: 79 model = LinearModelWithSubmodule().eval() 80 qconfig_dict = { 81 "": torch.ao.quantization.get_default_qconfig("qnnpack"), 82 **config, 83 } 84 model = prepare_fx( 85 model, 86 qconfig_dict, 87 example_inputs=torch.randn(5, 5), 88 ) 89 quant = convert_fx(model) 90 91 x = torch.randn(5, 5) 92 self._compare_script_and_mobile(quant, input=x) 93 94 95if __name__ == "__main__": 96 run_tests() # noqa: F821 97