#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import tempfile

import torch
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
from torch.testing._internal.common_quantized import (
    override_cpu_allocator_for_qnnpack,
    override_quantized_engine,
    qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import TestCase


class TestQlinearPackedParams(TestCase):
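    # Shared checks: pack a sparse quantized weight into a dynamic Linear, then
    # verify serialization, unpacking, and deserialization of its packed params
    # under the currently selected quantized engine.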
    def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
        # copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
        # so row/col block indices match that example, but with blocks and
        # scaled rows
        weight_fp32 = torch.Tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
                [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        row_block_size = 1
        col_block_size = 4
        out_features = weight_fp32.shape[0]
        in_features = weight_fp32.shape[1]

        scales = [2.0, 6.0, 12.0]
        zero_points = [
            ((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features)
        ]
        dtype = torch.qint8

        wide_weight_fp32 = torch.zeros((3, 4008))  # 4000 is tile width for Fbgemm
        wide_weight_fp32[0][0] = 4
        wide_weight_fp32[0][4004] = 6
        wide_weight_fp32[1][0] = 8

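        # Each case below is a tuple of:
        #   (quantized weight, is_per_tensor_quantized,
        #    expected row block indices, expected col block indices,
        #    expected non-zero block values after packing with 1x4 blocks)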
        per_tensor_small = (
            torch.quantize_per_tensor(weight_fp32, scales[0], zero_points[0], dtype),
            True,
            [0, 1, 3, 3],
            [2, 0, 1],
            [
                x + (1 if allow_non_zero_zero_points else 0)
                for x in [1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6]
            ],
        )

        per_channel_small = (
            torch.quantize_per_channel(
                weight_fp32,
                torch.Tensor(scales),
                torch.Tensor(zero_points).to(torch.int),
                0,  # axis = 0
                dtype,
            ),
            False,
            [0, 1, 3, 3],
            [2, 0, 1],
            [
                x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0)
                for (i, x) in enumerate([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2])
            ],
        )

        per_tensor_large = (
            torch.quantize_per_tensor(
                wide_weight_fp32,
                scales[0],
                zero_points[0],
                dtype,
            ),
            True,
            [0, 2, 3, 3],
            [0, 1001, 0],
            [
                x + (1 if allow_non_zero_zero_points else 0)
                for x in [2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]
            ],
        )

        for (
            weight,
            is_per_tensor_quantized,
            expected_row_block_indices,
            expected_col_block_indices,
            expected_weights,
        ) in [per_tensor_small, per_channel_small, per_tensor_large]:
            lin = Linear(
                out_features=weight.shape[0],
                in_features=weight.shape[1],
                row_block_size=row_block_size,
                col_block_size=col_block_size,
                bias=True,
                dtype=dtype,
            )

            bias = torch.ones(size=(weight.shape[0],))

            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)

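            # __getstate__() returns the packed-params serialization; its first
            # element is the tuple of fields unpacked below.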
            serialized = lin._packed_params._packed_params.__getstate__()

            (
                _,  # version
                bias_,
                out_features_block_size_,
                in_features_block_size_,
                weight_scales_,
                weight_zero_points_,
                quantization_scheme_,
                row_block_indices_,
                col_block_indices_,
                weights_,
                output_channels_,
                input_channels_,
            ) = serialized[0]

            # Test Serialization
            self.assertEqual(bias_, bias)
            self.assertEqual(out_features_block_size_, row_block_size)
            self.assertEqual(in_features_block_size_, col_block_size)
            self.assertEqual(
                weight_scales_, [scales[0]] if is_per_tensor_quantized else scales
            )
            self.assertEqual(
                weight_zero_points_,
                [zero_points[0]] if is_per_tensor_quantized else zero_points,
            )
            self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
            self.assertEqual(row_block_indices_, expected_row_block_indices)
            self.assertEqual(col_block_indices_, expected_col_block_indices)
            self.assertEqual(
                weights_.tolist(), [v + 128 for v in expected_weights]
            )  # weights are serialized as +128
            self.assertEqual(output_channels_, weight.shape[0])
            self.assertEqual(input_channels_, weight.shape[1])

            # Test Unpacking
            (
                weights_,
                bias_,
                out_features_block_size_,
                in_features_block_size_,
            ) = lin._weight_bias()
            self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
            self.assertEqual(bias_, bias)
            self.assertEqual(out_features_block_size_, row_block_size)
            self.assertEqual(in_features_block_size_, col_block_size)

            # Test Deserialization
            with tempfile.TemporaryFile() as file_buff:
                torch.save(lin, file_buff)
                file_buff.seek(0)
                lin2 = torch.load(file_buff)
                self.assertEqual(lin._weight_bias(), lin2._weight_bias())
                # Serialize -> Deserialize -> Serialize should match Serialize
                self.assertEqual(
                    serialized, lin2._packed_params._packed_params.__getstate__()
                )

                # Test that op output is preserved by serialize -> deserialize
                if qengine_is_qnnpack():
                    x = torch.rand(size=(1, weight.shape[1]))
                    y1 = lin(x)
                    y2 = lin2(x)
                    self.assertEqual(y1, y2)

    @skipIfNoFBGEMM
    def test_qlinear_packed_params_fbgemm(self):
        torch.manual_seed(0)
        with override_quantized_engine("fbgemm"):
            self.qlinear_packed_params_test(allow_non_zero_zero_points=False)

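    # Unlike the FBGEMM variant above, the QNNPACK path also exercises non-zero
    # per-channel zero points.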
    @skipIfNoQNNPACK
    def test_qlinear_packed_params_qnnpack(self):
        torch.manual_seed(0)
        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                self.qlinear_packed_params_test(allow_non_zero_zero_points=True)

    @skipIfNoFBGEMM
    @skipIfNoQNNPACK
    def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
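        # Pack and save under one engine, load under the other, and check that
        # the packed params and weight/bias round-trip, after normalizing the
        # block-index dtype (see packed_params_data_with_int32_indices below).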
        torch.manual_seed(0)

        weight_fp32 = torch.Tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
                [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        row_block_size = 1
        col_block_size = 4
        out_features = weight_fp32.shape[0]
        in_features = weight_fp32.shape[1]

        scales = [2.0, 3.0, 7.0]
        zero_points = [0 for _ in range(out_features)]
        dtype = torch.qint8

        x = torch.rand(size=(1, weight_fp32.shape[1]))

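        # Builds a packed sparse quantized Linear under the current engine and
        # returns ((packed-params state, weight/bias), temp file with the
        # torch.save'd module); load_get_state_weight_bias closes the file.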
        def make_lin_get_state_weight_bias_and_save():
            weight = torch.quantize_per_tensor(
                weight_fp32,
                scales[0],
                zero_points[0],
                dtype,
            )
            lin = Linear(
                out_features=weight.shape[0],
                in_features=weight.shape[1],
                row_block_size=row_block_size,
                col_block_size=col_block_size,
                bias=True,
                dtype=dtype,
            )
            bias = torch.ones(size=(weight.shape[0],))
            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)

            state = lin._packed_params._packed_params.__getstate__()
            weight_bias = lin._weight_bias()

            file_buff = tempfile.TemporaryFile()
            torch.save(lin, file_buff)
            file_buff.seek(0)

            return ((state, weight_bias), file_buff)

        def load_get_state_weight_bias(f_b):
            lin2 = torch.load(f_b)
            state = lin2._packed_params._packed_params.__getstate__()
            weight_bias = lin2._weight_bias()
            f_b.close()
            return (state, weight_bias)

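        # The two engines may serialize the row/col block indices with
        # different integer widths, so cast both sides to int32 before
        # comparing.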
        def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
            (st, weight_bias) = data_as_state_and_weight_bias
            (s0, s1) = st
            s0_updated = tuple(
                [
                    # 7 and 8 are row and col block indices respectively
                    v if (i != 7 and i != 8) else v.to(torch.int32)
                    for (i, v) in enumerate(list(s0))
                ]
            )
            return ((s0_updated, s1), weight_bias)

        # Test Fbgemm -> Qnnpack
        with override_quantized_engine("fbgemm"):
            (
                packed_params_data_1a,
                file_buff_1,
            ) = make_lin_get_state_weight_bias_and_save()

        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                packed_params_data_1b = load_get_state_weight_bias(file_buff_1)

        self.assertEqual(
            packed_params_data_with_int32_indices(packed_params_data_1a),
            packed_params_data_with_int32_indices(packed_params_data_1b),
        )

        # Test Qnnpack -> Fbgemm
        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                (
                    packed_params_data_2a,
                    file_buff_2,
                ) = make_lin_get_state_weight_bias_and_save()

        with override_quantized_engine("fbgemm"):
            packed_params_data_2b = load_get_state_weight_bias(file_buff_2)

        self.assertEqual(
            packed_params_data_with_int32_indices(packed_params_data_2a),
            packed_params_data_with_int32_indices(packed_params_data_2b),
        )
