# Owner(s): ["oncall: quantization"] import copy import math import torch import torch.ao.nn.intrinsic.qat as nniqat import torch.ao.nn.qat as nnqat import torch.ao.nn.qat.dynamic as nnqatd import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.dynamic as nnqd import torch.backends.mkldnn import torch.nn as nn import torch.testing._internal.hypothesis_utils as hu from hypothesis import given, strategies as st from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.ao.quantization import ( convert, default_embedding_qat_qconfig, default_qat_qconfig, default_qconfig, default_symmetric_qnnpack_qat_qconfig, DeQuantStub, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, get_default_qat_qconfig, get_embedding_qat_module_mappings, get_embedding_static_quant_module_mappings, NoopObserver, prepare, prepare_qat, quantize_qat, QuantStub, ) from torch.ao.quantization.qconfig import qconfig_equals from torch.nn import BatchNorm2d, Conv2d, init, ReLU from torch.nn.modules.utils import _pair from torch.testing._internal.common_quantization import ( DeFusedEmbeddingBagLinear, ManualConvLinearQATModel, ManualConvLinearSymmQATModel, ManualDropoutQATModel, ManualEmbeddingBagLinear, ManualLinearDynamicQATModel, ManualLinearQATModel, QuantizationTestCase, QuantStubModel, test_only_eval_fn, test_only_train_fn, TwoLayerLinearModel, ) from torch.testing._internal.common_quantized import ( override_qengines, override_quantized_engine, supported_qengines, ) from torch.testing._internal.common_utils import skipIfNoXNNPACK hu.assert_deadline_disabled() from functools import reduce class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd): """ Conv-BN fusion implemented with explicit folding. Useful to verify numerical equivalency with non-folded version. """ def __init__(self, # ConvNd args in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, # BatchNormNd args # num_features: out_channels eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn=False, qconfig=None): nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, False, padding_mode) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.eps = eps self.momentum = momentum self.freeze_bn = freeze_bn if self.training else True self.num_features = out_channels self.gamma = nn.Parameter(torch.empty(out_channels)) self.beta = nn.Parameter(torch.empty(out_channels)) self.affine = True self.track_running_stats = True self.running_mean = nn.Buffer(torch.zeros(out_channels)) self.running_var = nn.Buffer(torch.ones(out_channels)) self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long)) self.activation_post_process = self.qconfig.activation() self.weight_fake_quant = self.qconfig.weight() if bias: self.bias = nn.Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_bn_parameters() def reset_running_stats(self): self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_bn_parameters(self): self.reset_running_stats() init.uniform_(self.gamma) init.zeros_(self.beta) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def reset_parameters(self): super().reset_parameters() # A hack to avoid resetting on undefined parameters if hasattr(self, 'gamma'): self.reset_bn_parameters() def update_bn_stats(self): self.freeze_bn = False return self def freeze_bn_stats(self): self.freeze_bn = True return self def _forward(self, input): # exponential_average_factor is self.momentum set to # (when it is available) only so that if gets updated # in ONNX graph when this node is exported to ONNX. if self.momentum is None: exponential_average_factor = 0.0 else: exponential_average_factor = self.momentum if self.training and not self.freeze_bn and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum # we use running statistics from the previous batch, so this is an # approximation of the approach mentioned in the whitepaper, but we only # need to do one convolution in this case instead of two running_std = torch.sqrt(self.running_var + self.eps) scale_factor = self.gamma / running_std scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1]) if self.bias is not None: zero_bias = torch.zeros_like(self.bias, dtype=input.dtype) else: zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype) conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias) if self.training and not self.freeze_bn: # recovering original conv to get original batch_mean and batch_var if self.bias is not None: conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1]) else: conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) batch_mean = torch.mean(conv_orig, dim=[0, 2, 3]) batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False) n = float(conv_orig.numel() / conv_orig.size()[1]) unbiased_batch_var = batch_var * (n / (n - 1)) batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps) conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \ (self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1]) self.running_mean = exponential_average_factor * batch_mean.detach() + \ (1 - exponential_average_factor) * self.running_mean self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \ (1 - exponential_average_factor) * self.running_var else: if self.bias is None: conv = conv + (self.beta - self.gamma * self.running_mean / running_std).reshape([1, -1, 1, 1]) else: conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1]) return conv def extra_repr(self): # TODO(jerryzh): extend return super().extra_repr() def forward(self, input): return self.activation_post_process(self._forward(input)) @classmethod def from_float(cls, mod, qconfig=None): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.ao.quantization utilities or directly from user """ assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ if not qconfig: assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' assert mod.qconfig, 'Input float module must have a valid qconfig' qconfig = mod.qconfig conv, bn = mod[0], mod[1] qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode, bn.eps, bn.momentum, False, qconfig) qat_convbn.weight = conv.weight qat_convbn.bias = conv.bias qat_convbn.gamma = bn.weight qat_convbn.beta = bn.bias qat_convbn.running_mean = bn.running_mean qat_convbn.running_var = bn.running_var qat_convbn.num_batches_tracked = bn.num_batches_tracked return qat_convbn class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d): _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d def __init__(self, # ConvNd args in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', # BatchNorm2d args # num_features: out_channels eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn=False, qconfig=None): kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) _ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode, eps, momentum, freeze_bn, qconfig) class TestQuantizeEagerQAT(QuantizationTestCase): def setUp(self): super().setUp() self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long), torch.randn((12, 1), dtype=torch.float)] for _ in range(2)] self.embed_data = [[torch.randint(0, 10, (12, 1))]] def test_manual(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualLinearQATModel(qengine) model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.train_data) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.fc2), nnq.Linear) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) self.checkNoQconfig(model) checkQuantized(model) model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]) checkQuantized(model) def test_dropout(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualDropoutQATModel(qengine) model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.train_data) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.dropout), nnq.Dropout) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) self.checkNoQconfig(model) checkQuantized(model) model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn, [self.train_data]) checkQuantized(model) def test_eval_only_fake_quant(self): r"""Using FakeQuant in evaluation only mode, this is useful for estimating accuracy loss when we quantize the network """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualLinearQATModel(qengine) model = prepare_qat(model) self.checkObservers(model) model.eval() test_only_eval_fn(model, self.calib_data) def test_conv_linear(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualConvLinearQATModel() model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.img_data_2d_train) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.conv), nnq.Conv2d) self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.fc2), nnq.Linear) test_only_eval_fn(model, self.img_data_2d) self.checkScriptable(model, self.img_data_2d) self.checkNoQconfig(model) checkQuantized(model) model = ManualConvLinearQATModel() model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) checkQuantized(model) @skipIfNoXNNPACK def test_conv_linear_symm(self): r"""Same as test_conv_linear but with Symmetric quantization. Supported only with qengine=qnnpack, which uses symmetric kernels from xnnpack library.""" for qengine in supported_qengines: if qengine != 'qnnpack': continue with override_quantized_engine(qengine): model = ManualConvLinearSymmQATModel() model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.img_data_2d_train) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.conv), nnq.Conv2d) self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.fc2), nnq.Linear) test_only_eval_fn(model, self.img_data_2d) self.checkScriptable(model, self.img_data_2d) self.checkNoQconfig(model) checkQuantized(model) model = ManualConvLinearSymmQATModel() model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) checkQuantized(model) def test_dynamic_qat_linear(self): for qengine in supported_qengines: with override_quantized_engine(qengine): # Dynamic QAT without memoryless observers should fail with self.assertRaisesRegex(ValueError, "Dynamic QAT requires a memoryless observer." + "This means a MovingAverage observer with averaging constant equal to 1" ): model = ManualLinearDynamicQATModel(default_qat_qconfig) model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear}) model = ManualLinearDynamicQATModel() model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear}) self.assertEqual(type(model.fc1), nnqatd.Linear) self.assertEqual(type(model.fc2), nnqatd.Linear) self.checkObservers(model) test_only_train_fn(model, self.train_data) model = convert(model, mapping={nnqatd.Linear: nnqd.Linear}) self.assertEqual(type(model.fc1), nnqd.Linear) self.assertEqual(type(model.fc2), nnqd.Linear) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) self.checkNoQconfig(model) def test_defused_embedding_bag_linear(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = DeFusedEmbeddingBagLinear().train() model = prepare_qat(model, mapping=get_embedding_qat_module_mappings()) self.checkObservers(model) test_only_train_fn(model, self.embed_linear_data_train) # make sure activation_post_process is inserted after Linear. self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize) # make sure that Embedding has a noop for activation. self.assertEqual(type(model.emb.activation_post_process), NoopObserver) # make sure that FakeQuant zero_points are correct dtype self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) model = convert(model, mapping=get_embedding_static_quant_module_mappings()) def checkQuantized(model): # make sure Embedding is now a QuantizedEmbedding self.assertEqual(type(model.emb), nn.quantized.Embedding) # make sure Linear is now a QuantizedLinear self.assertEqual(type(model.linear), nn.quantized.Linear) test_only_eval_fn(model, self.embed_data) self.checkScriptable(model, self.embed_data) self.checkNoQconfig(model) checkQuantized(model) def test_embedding_bag_linear(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualEmbeddingBagLinear().train() model = prepare_qat(model, mapping=get_embedding_qat_module_mappings()) self.checkObservers(model) test_only_train_fn(model, self.embed_linear_data_train) # make sure not activation_post_process is inserted for EmbeddingBag self.assertFalse(hasattr(model, "activation_post_process")) # make sure that FakeQuant zero_points are correct dtype self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) model = convert(model, mapping=get_embedding_static_quant_module_mappings()) def checkQuantized(model): # Make sure EmbeddingBag is now a quantized EmbeddingBag. self.assertTrue(type(model.emb), nn.quantized.EmbeddingBag) # Also test that Linear has been quantized. self.assertTrue(type(model.linear), nnq.Linear) test_only_eval_fn(model, self.embed_data) self.checkScriptable(model, self.embed_data) self.checkNoQconfig(model) checkQuantized(model) model = ManualEmbeddingBagLinear() def test_train_save_load_eval(self): r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict During eval, we first call prepare_qat and conver on the model and then load the state_dict and compare results against original model """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = TwoLayerLinearModel() model = torch.ao.quantization.QuantWrapper(model) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) model = prepare_qat(model) fq_state_dict = model.state_dict() test_only_train_fn(model, self.train_data) model = convert(model) quant_state_dict = model.state_dict() x = torch.rand(2, 5, dtype=torch.float) ref = model(x) # Create model again for eval. Check result using quantized state_dict model = TwoLayerLinearModel() model = torch.ao.quantization.QuantWrapper(model) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) torch.ao.quantization.prepare_qat(model, inplace=True) new_state_dict = model.state_dict() # Check to make sure the model after prepare_qat has the same state_dict as original. self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys())) torch.ao.quantization.convert(model, inplace=True) model.eval() model.load_state_dict(quant_state_dict) out = model(x) self.assertEqual(ref, out) # Check model created using prepare has same state dict as quantized state_dict model = TwoLayerLinearModel() model.eval() model = torch.ao.quantization.QuantWrapper(model) model.qconfig = torch.ao.quantization.get_default_qconfig(qengine) torch.ao.quantization.prepare(model, inplace=True) torch.ao.quantization.convert(model, inplace=True) self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) model.eval() model.load_state_dict(quant_state_dict) out = model(x) self.assertEqual(ref, out) @override_qengines def test_forward_hooks_preserved(self): r"""Test QAT on preserving pre forward and post forward hooks of original model """ qengine = torch.backends.quantized.engine model = QuantStubModel() counter = { 'pre_forwards': 0, 'forwards': 0, } def fw_pre_hook(h_module, input): counter['pre_forwards'] += 1 def fw_hook(h_module, input, output): counter['forwards'] += 1 model.fc.register_forward_pre_hook(fw_pre_hook) model.fc.register_forward_hook(fw_hook) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) model = prepare_qat(model) def checkHooksIsPresent(model, before_convert=True): forward_hooks = 1 if before_convert: self.assertEqual(len(model.quant._forward_hooks.values()), 1, "Quantization observer hook has disappeared") forward_hooks = 2 self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, "Extra pre forward hooks have appeared on a layer") self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks, "Extra post forward hooks have appeared on a layer") checkHooksIsPresent(model, True) x = torch.rand(2, 5, dtype=torch.float) model(x) torch.ao.quantization.convert(model, inplace=True) checkHooksIsPresent(model, False) def test_add_scalar_uses_input_qparams(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.quant = torch.ao.quantization.QuantStub() self.ff = torch.ao.nn.quantized.FloatFunctional() def forward(self, x): x = self.quant(x) x = self.ff.add_scalar(x, 1.0) return x m = M() m.qconfig = torch.ao.quantization.default_qconfig mp = torch.ao.quantization.prepare_qat(m) mp(torch.randn(4, 4)) mq = torch.ao.quantization.convert(mp) res = mq(torch.randn(4, 4)) eps = 1e-5 self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps) def test_mul_scalar_uses_input_qparams(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.quant = torch.ao.quantization.QuantStub() self.ff = torch.ao.nn.quantized.FloatFunctional() def forward(self, x): x = self.quant(x) x = self.ff.mul_scalar(x, 2.0) return x m = M() m.qconfig = torch.ao.quantization.default_qconfig mp = torch.ao.quantization.prepare_qat(m) mp(torch.randn(4, 4)) mq = torch.ao.quantization.convert(mp) res = mq(torch.randn(4, 4)) eps = 1e-5 self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) @override_qengines def test_qat_embedding_bag_errors(self): default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) # Test constructor parameters checks here. with self.assertRaisesRegex(AssertionError, "qconfig must be provided for QAT module"): nnqat.EmbeddingBag(10, 5, qconfig=None) with self.assertRaisesRegex(AssertionError, "Embedding Bag weights requires a qscheme of " + "torch.per_channel_affine_float_qparams"): nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig) # Test from_float checks here. embed = nn.Embedding(10, 5) with self.assertRaisesRegex(AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"): nnqat.EmbeddingBag.from_float(embed) embed_bag = nn.EmbeddingBag(10, 5) with self.assertRaisesRegex(AssertionError, "Input float module must have qconfig defined"): nnqat.EmbeddingBag.from_float(embed_bag) embed_bag.qconfig = None with self.assertRaisesRegex(AssertionError, "Input float module must have a valid qconfig"): nnqat.EmbeddingBag.from_float(embed_bag) embed_bag.qconfig = default_qat_qconfig with self.assertRaisesRegex(AssertionError, "Embedding Bag weights requires a qscheme of " + "torch.per_channel_affine_float_qparams"): nnqat.EmbeddingBag.from_float(embed_bag) def test_embedding_qat_qconfig_equal(self): # Embedding QAT uses a NoopObserver class for activation, # and a FakeQuant for weight, make sure that qconfig comparison # functions properly for a mix of partial function and class in # qconfig. model = ManualEmbeddingBagLinear().train() model = prepare_qat(model) self.assertTrue(qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)) class TestQuantizeEagerQATNumerics(QuantizationTestCase): def _test_activation_convert_numerics_impl(self, Act, data): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.act = Act() self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.act(x) x = self.dequant(x) return x m = M().train() m.qconfig = default_qat_qconfig m = prepare_qat(m) before_convert = m(data) m = convert(m) after_convert = m(data) self.assertEqual(before_convert, after_convert) def test_fixed_qparam_ops(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sigmoid = torch.nn.Sigmoid() self.hardsigmoid = torch.nn.Hardsigmoid() self.tanh = torch.nn.Tanh() self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.sigmoid(x) x = self.hardsigmoid(x) x = self.tanh(x) x = self.dequant(x) return x m = M().train() m.qconfig = default_qat_qconfig m = prepare_qat(m) for attr in ['sigmoid', 'hardsigmoid', 'tanh']: self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) data = torch.randn(1, 3, 2, 4) before_convert = m(data) m = convert(m) after_convert = m(data) self.assertEqual(before_convert, after_convert) # make sure activation post process is removed for attr in ['sigmoid', 'hardsigmoid', 'tanh']: # verify fake quant module is removd self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) # verify that hooks are removed self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) # make sure no fake quantize module is inserted for eval mode def checkNoFQModule(m): for attr in ['sigmoid', 'hardsigmoid', 'tanh']: self.assertFalse(hasattr(getattr(m, attr), "activation_post_process")) self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) m = M().eval() m.qconfig = default_qconfig m = prepare(m) checkNoFQModule(m) m = convert(m) checkNoFQModule(m) def test_leaky_relu(self): data = torch.randn(1, 3, 2, 4) self._test_activation_convert_numerics_impl(nn.LeakyReLU, data) def test_relu(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() def forward(self, x): x = self.relu(x) return x m = M().train() m.qconfig = default_qconfig m = prepare_qat(m) # make sure no activation_post_process is inserted for relu self.assertFalse(hasattr(m, "activation_post_process")) m = convert(m) # make sure ReLU module is not changed self.assertTrue(type(m.relu), nn.ReLU) @given(batch_size=st.integers(2, 4), input_channels_per_group=st.sampled_from([2, 3, 4]), height=st.integers(5, 10), width=st.integers(5, 10), output_channels_per_group=st.sampled_from([2, 3]), groups=st.integers(1, 3), kernel_h=st.integers(1, 3), kernel_w=st.integers(1, 3), stride_h=st.integers(1, 2), stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), dilation=st.integers(1, 1), padding_mode=st.sampled_from(['zeros', 'circular']), use_relu=st.booleans(), eps=st.sampled_from([1e-5, 1e-4, 1e-3]), momentum=st.sampled_from([0.1, 0.2, 0.3]), freeze_bn=st.booleans(), zero_gamma=st.booleans(), has_bias=st.booleans(), use_slow_fusion=st.booleans()) def test_conv_bn_relu( self, batch_size, input_channels_per_group, height, width, output_channels_per_group, groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation, padding_mode, use_relu, eps, momentum, freeze_bn, zero_gamma, has_bias, use_slow_fusion, ): input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups dilation_h = dilation_w = dilation conv_op = Conv2d( input_channels, output_channels, (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w), groups, has_bias, padding_mode ).to(dtype=torch.double) bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double) relu_op = ReLU() cls = ConvBnReLU2d if use_relu else ConvBn2d qat_op = cls( input_channels, output_channels, (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w), groups, has_bias, padding_mode, eps, momentum, freeze_bn=True, qconfig=default_qat_qconfig ).to(dtype=torch.double) qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion # the approximate fusion will not work if bn.weight has 0 if zero_gamma and use_slow_fusion: torch.nn.init.zeros_(qat_op.bn.weight) qat_op.apply(torch.ao.quantization.disable_fake_quant) if freeze_bn: qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats) else: qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats) # align inputs and internal parameters input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True) conv_op.weight = torch.nn.Parameter(qat_op.weight.detach()) if has_bias: conv_op.bias = torch.nn.Parameter(qat_op.bias.detach()) bn_op.running_mean = qat_op.bn.running_mean.clone() bn_op.running_var = qat_op.bn.running_var.clone() bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach()) bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach()) def compose(functions): # functions are reversed for natural reading order return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x) if not use_relu: def relu_op(x): # noqa: F811 return x if freeze_bn: def ref_op(x): x = conv_op(x) x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \ (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \ .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1]) x = relu_op(x) return x else: ref_op = compose([conv_op, bn_op, relu_op]) input_clone = input.clone().detach().requires_grad_() for i in range(2): result_ref = ref_op(input) result_actual = qat_op(input_clone) self.assertEqual(result_ref, result_actual) # backward dout = torch.randn(result_ref.size(), dtype=torch.double) loss = (result_ref - dout).sum() loss.backward() input_grad_ref = input.grad.cpu() weight_grad_ref = conv_op.weight.grad.cpu() gamma_grad_ref = bn_op.weight.grad.cpu() beta_grad_ref = bn_op.bias.grad.cpu() running_mean_ref = bn_op.running_mean running_var_ref = bn_op.running_var num_batches_tracked_ref = bn_op.num_batches_tracked loss = (result_actual - dout).sum() loss.backward() input_grad_actual = input_clone.grad.cpu() weight_grad_actual = qat_op.weight.grad.cpu() gamma_grad_actual = qat_op.bn.weight.grad.cpu() beta_grad_actual = qat_op.bn.bias.grad.cpu() running_mean_actual = qat_op.bn.running_mean running_var_actual = qat_op.bn.running_var num_batches_tracked_actual = qat_op.bn.num_batches_tracked precision = 1e-10 self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0) self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0) self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0) self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0) self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0) self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0) self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0) @given(batch_size=st.integers(2, 4), input_channels_per_group=st.sampled_from([2, 3, 4]), height=st.integers(5, 10), width=st.integers(5, 10), output_channels_per_group=st.sampled_from([2, 3]), groups=st.integers(1, 3), kernel_h=st.integers(1, 3), kernel_w=st.integers(1, 3), stride_h=st.integers(1, 2), stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), dilation=st.integers(1, 1), padding_mode=st.sampled_from(['zeros', 'circular']), eps=st.sampled_from([1e-5, 1e-4, 1e-3]), momentum=st.sampled_from([0.1, 0.2, 0.3]), freeze_bn=st.booleans(), bias=st.booleans()) def test_conv_bn_folded_vs_unfolded( self, batch_size, input_channels_per_group, height, width, output_channels_per_group, groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation, padding_mode, eps, momentum, freeze_bn, bias, ): input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups dilation_h = dilation_w = dilation qat_op = ConvBn2d( input_channels, output_channels, (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w), groups, bias, # bias padding_mode, eps, momentum, freeze_bn=freeze_bn, qconfig=default_qat_qconfig ).to(dtype=torch.double) qat_ref_op = _ReferenceConvBn2d( input_channels, output_channels, (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w), groups, bias, # bias padding_mode, eps, momentum, freeze_bn=freeze_bn, qconfig=default_qat_qconfig ).to(dtype=torch.double) qat_op.apply(torch.ao.quantization.disable_fake_quant) qat_ref_op.apply(torch.ao.quantization.disable_fake_quant) # align inputs and internal parameters qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone()) qat_ref_op.running_mean = qat_op.bn.running_mean.clone() qat_ref_op.running_var = qat_op.bn.running_var.clone() qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone()) qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone()) if qat_op.bias is not None: qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone()) lr = 0.01 qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr) qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr) for i in range(5): # make sure that calling model.train() does not override the # bn freeze setting qat_op.train() qat_ref_op.train() qat_op_optim.zero_grad() qat_ref_op_optim.zero_grad() input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True) input_clone = input.clone().detach().requires_grad_() if i > 2: qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats) qat_ref_op.freeze_bn_stats() if i > 3: qat_op.apply(torch.ao.quantization.disable_observer) qat_ref_op.apply(torch.ao.quantization.disable_observer) result_ref = qat_ref_op(input) result_actual = qat_op(input_clone) self.assertEqual(result_ref, result_actual) # backward dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0 loss = (result_ref - dout).sum() loss.backward() input_grad_ref = input.grad.cpu() weight_grad_ref = qat_ref_op.weight.grad.cpu() gamma_grad_ref = qat_ref_op.gamma.grad.cpu() beta_grad_ref = qat_ref_op.beta.grad.cpu() running_mean_ref = qat_ref_op.running_mean running_var_ref = qat_ref_op.running_var num_batches_tracked_ref = qat_ref_op.num_batches_tracked loss = (result_actual - dout).sum() loss.backward() input_grad_actual = input_clone.grad.cpu() weight_grad_actual = qat_op.weight.grad.cpu() gamma_grad_actual = qat_op.bn.weight.grad.cpu() beta_grad_actual = qat_op.bn.bias.grad.cpu() running_mean_actual = qat_op.bn.running_mean running_var_actual = qat_op.bn.running_var num_batches_tracked_actual = qat_op.bn.num_batches_tracked precision = 1e-5 self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0) self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0) self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0) self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0) self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0) self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0) self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0) qat_op_optim.step() qat_ref_op_optim.step() @override_qengines def test_linear_bn_numerics(self): qengine = torch.backends.quantized.engine m_ref = nn.Sequential( nn.Linear(4, 4), nn.BatchNorm1d(4), ) m_ref_copy = copy.deepcopy(m_ref) m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']]) qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) m_ref_copy[0].qconfig = qconfig m = nniqat.LinearBn1d.from_float(m_ref_copy[0]) # without fake_quants, fused QAT module should match fp32 module m.apply(torch.ao.quantization.disable_fake_quant) data = torch.randn(4, 4) r1 = m_ref(data) r2 = m(data) self.assertTrue(torch.allclose(r1, r2)) @skipIfNoXNNPACK @override_qengines def test_linear_bn_symm_numerics(self): qengine = torch.backends.quantized.engine if qengine != "qnnpack": return # Only qnnpack support symmetric quantization m_ref = nn.Sequential( nn.Linear(4, 4), nn.BatchNorm1d(4), ) m_ref_copy = copy.deepcopy(m_ref) m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']]) qconfig = default_symmetric_qnnpack_qat_qconfig m_ref_copy[0].qconfig = qconfig m = nniqat.LinearBn1d.from_float(m_ref_copy[0]) # without fake_quants, fused QAT module should match fp32 module m.apply(torch.ao.quantization.disable_fake_quant) data = torch.randn(4, 4) r1 = m_ref(data) r2 = m(data) self.assertTrue(torch.allclose(r1, r2)) @override_qengines def test_linear_bn_workflow(self): qengine = torch.backends.quantized.engine m = nn.Sequential( QuantStub(), nn.Linear(4, 4), nn.BatchNorm1d(4), ) data = torch.randn(4, 4) m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']]) mp = prepare_qat(m) mp(data) mq = convert(mp) self.assertTrue(type(mq[1]) == nnq.Linear) self.assertTrue(type(mq[2]) == nn.Identity) @skipIfNoXNNPACK @override_qengines def test_linear_precomputed_fake_quant(self): qengine = torch.backends.quantized.engine if qengine != "qnnpack": return # Only qnnpack support symmetric quantization m_ref = nn.Linear(4, 4) m_ref_copy = copy.deepcopy(m_ref) qconfig = default_qconfig m_ref_copy.qconfig = qconfig weight_post_process = copy.deepcopy(qconfig.weight()) activation = copy.deepcopy(qconfig.activation()) activation(torch.randn(4, 4)) m_ref_copy.activation_post_process = activation m_ref_copy = nnq.Linear.from_float(m_ref_copy) weight_post_process = qconfig.weight() weight_post_process.min_val = torch.tensor(-1) weight_post_process.max_val = torch.tensor(1) m_ref.weight_post_process = weight_post_process m_ref.activation_post_process = activation m_ref.qconfig = qconfig m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True) self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" "instead.")