#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import tempfile

import torch
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
from torch.testing._internal.common_quantized import (
    override_cpu_allocator_for_qnnpack,
    override_quantized_engine,
    qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import TestCase


class TestQlinearPackedParams(TestCase):
    def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
        # copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
        # so row/col block indices match that example, but with blocks and
        # scaled rows
        weight_fp32 = torch.Tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
                [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        row_block_size = 1
        col_block_size = 4
        out_features = weight_fp32.shape[0]
        in_features = weight_fp32.shape[1]

        scales = [2.0, 6.0, 12.0]
        zero_points = [
            ((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features)
        ]
        dtype = torch.qint8

        wide_weight_fp32 = torch.zeros((3, 4008))  # 4000 is tile width for Fbgemm
        wide_weight_fp32[0][0] = 4
        wide_weight_fp32[0][4004] = 6
        wide_weight_fp32[1][0] = 8

        per_tensor_small = (
            torch.quantize_per_tensor(weight_fp32, scales[0], zero_points[0], dtype),
            True,
            [0, 1, 3, 3],
            [2, 0, 1],
            [
                x + (1 if allow_non_zero_zero_points else 0)
                for x in [1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6]
            ],
        )

        per_channel_small = (
            torch.quantize_per_channel(
                weight_fp32,
                torch.Tensor(scales),
                torch.Tensor(zero_points).to(torch.int),
                0,  # axis = 0
                dtype,
            ),
            False,
            [0, 1, 3, 3],
            [2, 0, 1],
            [
                x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0)
                for (i, x) in enumerate([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2])
            ],
        )

        per_tensor_large = (
            torch.quantize_per_tensor(
                wide_weight_fp32,
                scales[0],
                zero_points[0],
                dtype,
            ),
            True,
            [0, 2, 3, 3],
            [0, 1001, 0],
            [
                x + (1 if allow_non_zero_zero_points else 0)
                for x in [2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]
            ],
        )

        for (
            weight,
            is_per_tensor_quantized,
            expected_row_block_indices,
            expected_col_block_indices,
            expected_weights,
        ) in [per_tensor_small, per_channel_small, per_tensor_large]:
            lin = Linear(
                out_features=weight.shape[0],
                in_features=weight.shape[1],
                row_block_size=row_block_size,
                col_block_size=col_block_size,
                bias=True,
                dtype=dtype,
            )

            bias = torch.ones(size=(weight.shape[0],))

            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)

            serialized = lin._packed_params._packed_params.__getstate__()

            (
                _,  # version
                bias_,
                out_features_block_size_,
                in_features_block_size_,
                weight_scales_,
                weight_zero_points_,
                quantization_scheme_,
                row_block_indices_,
                col_block_indices_,
                weights_,
                output_channels_,
                input_channels_,
            ) = serialized[0]

            # Test Serialization
            self.assertEqual(bias_, bias)
            self.assertEqual(out_features_block_size_, row_block_size)
            self.assertEqual(in_features_block_size_, col_block_size)
            self.assertEqual(
                weight_scales_, [scales[0]] if is_per_tensor_quantized else scales
            )
            self.assertEqual(
                weight_zero_points_,
                [zero_points[0]] if is_per_tensor_quantized else zero_points,
            )
            self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
            self.assertEqual(row_block_indices_, expected_row_block_indices)
            self.assertEqual(col_block_indices_, expected_col_block_indices)
            self.assertEqual(
                weights_.tolist(), [v + 128 for v in expected_weights]
            )  # weights are serialized as +128
            self.assertEqual(output_channels_, weight.shape[0])
            self.assertEqual(input_channels_, weight.shape[1])

            # Test Unpacking
            (
                weights_,
                bias_,
                out_features_block_size_,
                in_features_block_size_,
            ) = lin._weight_bias()
            self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
            self.assertEqual(bias_, bias)
            self.assertEqual(out_features_block_size_, row_block_size)
            self.assertEqual(in_features_block_size_, col_block_size)

            # Test Deserialization
            with tempfile.TemporaryFile() as file_buff:
                torch.save(lin, file_buff)
                file_buff.seek(0)
                lin2 = torch.load(file_buff)
                self.assertEqual(lin._weight_bias(), lin2._weight_bias())
                # Serialize -> Deserialize -> Serialize should match Serialize
                self.assertEqual(
                    serialized, lin2._packed_params._packed_params.__getstate__()
                )

                # Test that op output is preserved by serialize -> deserialize
                if qengine_is_qnnpack():
                    x = torch.rand(size=(1, weight.shape[1]))
                    y1 = lin(x)
                    y2 = lin2(x)
                    self.assertEqual(y1, y2)

    @skipIfNoFBGEMM
    def test_qlinear_packed_params_fbgemm(self):
        torch.manual_seed(0)
        with override_quantized_engine("fbgemm"):
            self.qlinear_packed_params_test(allow_non_zero_zero_points=False)

    @skipIfNoQNNPACK
    def test_qlinear_packed_params_qnnpack(self):
        torch.manual_seed(0)
        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                self.qlinear_packed_params_test(allow_non_zero_zero_points=True)

    def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
        torch.manual_seed(0)

        weight_fp32 = torch.Tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
                [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        row_block_size = 1
        col_block_size = 4
        out_features = weight_fp32.shape[0]
        in_features = weight_fp32.shape[1]

        scales = [2.0, 3.0, 7.0]
        zero_points = [0 for _ in range(out_features)]
        dtype = torch.qint8

        x = torch.rand(size=(1, weight_fp32.shape[1]))

        def make_lin_get_state_weight_bias_and_save():
            weight = torch.quantize_per_tensor(
                weight_fp32,
                scales[0],
                zero_points[0],
                dtype,
            )
            lin = Linear(
                out_features=weight.shape[0],
                in_features=weight.shape[1],
                row_block_size=row_block_size,
                col_block_size=col_block_size,
                bias=True,
                dtype=dtype,
            )
            bias = torch.ones(size=(weight.shape[0],))
            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)

            state = lin._packed_params._packed_params.__getstate__()
            weight_bias = lin._weight_bias()

            file_buff = tempfile.TemporaryFile()
            torch.save(lin, file_buff)
            file_buff.seek(0)

            return ((state, weight_bias), file_buff)

        def load_get_state_weight_bias(f_b):
            lin2 = torch.load(f_b)
            state = lin2._packed_params._packed_params.__getstate__()
            weight_bias = lin2._weight_bias()
            f_b.close()
            return (state, weight_bias)

        def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
            (st, weight_bias) = data_as_state_and_weight_bias
            (s0, s1) = st
            s0_updated = tuple(
                [
                    # 7 and 8 are row and col block indices respectively
                    v if (i != 7 and i != 8) else v.to(torch.int32)
                    for (i, v) in enumerate(list(s0))
                ]
            )
            return ((s0_updated, s1), weight_bias)

        # Test Fbgemm -> Qnnpack
        with override_quantized_engine("fbgemm"):
            (
                packed_params_data_1a,
                file_buff_1,
            ) = make_lin_get_state_weight_bias_and_save()

        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                packed_params_data_1b = load_get_state_weight_bias(file_buff_1)

        self.assertEqual(
            packed_params_data_with_int32_indices(packed_params_data_1a),
            packed_params_data_with_int32_indices(packed_params_data_1b),
        )

        # Test Qnnpack -> Fbgemm
        with override_quantized_engine("qnnpack"):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
                (
                    packed_params_data_2a,
                    file_buff_2,
                ) = make_lin_get_state_weight_bias_and_save()

        with override_quantized_engine("fbgemm"):
            packed_params_data_2b = load_get_state_weight_bias(file_buff_2)

        self.assertEqual(
            packed_params_data_with_int32_indices(packed_params_data_2a),
            packed_params_data_with_int32_indices(packed_params_data_2b),
        )