# Owner(s): ["oncall: mobile"]

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
import torch.utils.bundled_inputs
from torch.ao.quantization import default_qconfig, float_qparams_weight_only_qconfig

# graph mode quantization based on fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    NodeSpec as ns,
    QuantizationLiteTestCase,
)
from torch.testing._internal.common_utils import run_tests


class TestLiteFuseFx(QuantizationLiteTestCase):
    # Tests from:
    # ./caffe2/test/quantization/fx/test_quantize_fx.py
    # Each test quantizes a small model with FX graph mode quantization and
    # compares the scripted model against its lite interpreter counterpart
    # via _compare_script_and_mobile.

    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        model = M().eval()
        indices = torch.randint(low=0, high=10, size=(20,))

        quantized_node = ns.call_module(nnq.Embedding)
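        # Each entry pairs a qconfig with the embedding node expected after convert_fx:
        # only the float-qparams weight-only qconfig yields a quantized nnq.Embedding,
        # while no qconfig or the default static qconfig leaves the float nn.Embedding.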
        configs = [
            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
            (None, ns.call_module(nn.Embedding)),
            (default_qconfig, ns.call_module(nn.Embedding)),
        ]

        for qconfig, node in configs:
            qconfig_dict = {"": qconfig}
            m = prepare_fx(
                model,
                qconfig_dict,
                example_inputs=(torch.randint(low=0, high=10, size=(20,)),),
            )
            m = convert_fx(m)
            self._compare_script_and_mobile(m, input=indices)

    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]}
        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
        data = torch.randn(1, 1, 1, 1)
        m = convert_fx(m)
        # first conv is not quantized (its qconfig is None), second conv is quantized
        self._compare_script_and_mobile(m, input=data)

    def test_submodule(self):
        # test three configs: quantize the whole model, then skip the submodule,
        # then skip the final linear layer
        configs = [
            {},
            {"module_name": [("subm", None)]},
            {"module_name": [("fc", None)]},
        ]
        for config in configs:
            model = LinearModelWithSubmodule().eval()
            qconfig_dict = {
                "": torch.ao.quantization.get_default_qconfig("qnnpack"),
                **config,
            }
            model = prepare_fx(
                model,
                qconfig_dict,
                example_inputs=(torch.randn(5, 5),),
            )
            quant = convert_fx(model)

            x = torch.randn(5, 5)
            self._compare_script_and_mobile(quant, input=x)


if __name__ == "__main__":
    run_tests()