# Owner(s): ["oncall: quantization"] import copy import operator import unittest from typing import Any, Optional, Tuple, Type import torch from torch._export import capture_pre_autograd_graph from torch._utils_internal import capture_pre_autograd_graph_using_training_ir from torch.ao.quantization import ( default_fake_quant, FusedMovingAvgObsFakeQuantize, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, QConfigMapping, ) from torch.ao.quantization.backend_config import get_qnnpack_backend_config from torch.ao.quantization.qconfig import ( default_per_channel_symmetric_qnnpack_qat_qconfig, default_symmetric_qnnpack_qat_qconfig, ) from torch.ao.quantization.quantize_fx import prepare_qat_fx from torch.ao.quantization.quantize_pt2e import ( _convert_to_reference_decomposed_fx, convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) from torch.ao.quantization.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, Quantizer, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine class PT2EQATTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2E QAT with some helper methods. """ class _BaseConvBnModel(torch.nn.Module): def __init__( self, conv_class: Type[torch.nn.Module], bn_class: Type[torch.nn.Module], has_conv_bias: bool, has_bn: bool, has_relu: bool, **conv_kwargs, ): super().__init__() conv_kwargs.setdefault("in_channels", 3) conv_kwargs.setdefault("out_channels", 3) conv_kwargs.setdefault("kernel_size", 3) conv_kwargs.setdefault("bias", has_conv_bias) self.conv = conv_class(**conv_kwargs) self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None self.relu = torch.nn.ReLU() if has_relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x def _get_conv_bn_model( self, has_conv_bias: bool = True, has_bn: bool = True, has_relu: bool = False, transpose: bool = False, **conv_kwargs, ): """ Return an instance of a simple test model containing the conv[-bn][-relu] pattern. By default, this returns a conv-bn model with conv bias. """ return self._BaseConvBnModel( self.conv_transpose_class if transpose else self.conv_class, self.bn_class, has_conv_bias, has_bn, has_relu, **conv_kwargs, ) def _verify_symmetric_xnnpack_qat_numerics( self, model: torch.nn.Module, example_inputs: Tuple[Any, ...], ): self._verify_symmetric_xnnpack_qat_numerics_helper( model, example_inputs, is_per_channel=True, ) self._verify_symmetric_xnnpack_qat_numerics_helper( model, example_inputs, is_per_channel=False, ) def _verify_symmetric_xnnpack_qat_numerics_helper( self, model: torch.nn.Module, example_inputs: Tuple[Any, ...], is_per_channel: bool, verify_convert: bool = True, ): """ Helper method to verify that the QAT numerics for PT2E quantization match those of FX graph mode quantization for symmetric qnnpack. 
""" # resetting dynamo cache torch._dynamo.reset() MANUAL_SEED = 100 # PT2 export model_pt2e = copy.deepcopy(model) quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config( is_per_channel=is_per_channel, is_qat=True ) ) model_pt2e = capture_pre_autograd_graph( model_pt2e, example_inputs, ) model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) torch.manual_seed(MANUAL_SEED) after_prepare_result_pt2e = model_pt2e(*example_inputs) model_fx = copy.deepcopy(model) if is_per_channel: default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig else: default_qconfig = default_symmetric_qnnpack_qat_qconfig qconfig_mapping = QConfigMapping().set_global(default_qconfig) backend_config = get_qnnpack_backend_config() model_fx = prepare_qat_fx( model_fx, qconfig_mapping, example_inputs, backend_config=backend_config ) torch.manual_seed(MANUAL_SEED) after_prepare_result_fx = model_fx(*example_inputs) # Verify that numerics match self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx) if verify_convert: # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e torch.ao.quantization.move_exported_model_to_eval(model_pt2e) model_pt2e = convert_pt2e(model_pt2e) quant_result_pt2e = model_pt2e(*example_inputs) model_fx.eval() model_fx = _convert_to_reference_decomposed_fx( model_fx, backend_config=backend_config, ) quant_result_fx = model_fx(*example_inputs) self.assertEqual(quant_result_pt2e, quant_result_fx) def _verify_symmetric_xnnpack_qat_graph( self, m: torch.fx.GraphModule, example_inputs: Tuple[Any, ...], has_relu: bool, has_bias: bool = True, is_cuda: bool = False, expected_conv_literal_args: Optional[Tuple[Any, ...]] = None, # TODO: set this to true by default verify_convert: bool = False, ): self._verify_symmetric_xnnpack_qat_graph_helper( m, example_inputs, is_per_channel=True, has_relu=has_relu, has_bias=has_bias, is_cuda=is_cuda, expected_conv_literal_args=expected_conv_literal_args, verify_convert=verify_convert, ) self._verify_symmetric_xnnpack_qat_graph_helper( m, example_inputs, is_per_channel=False, has_relu=has_relu, has_bias=has_bias, is_cuda=is_cuda, expected_conv_literal_args=expected_conv_literal_args, verify_convert=verify_convert, ) def _verify_symmetric_xnnpack_qat_graph_helper( self, m: torch.fx.GraphModule, example_inputs: Tuple[Any, ...], is_per_channel: bool, has_relu: bool, has_bias: bool = True, is_cuda: bool = False, expected_conv_literal_args: Optional[Tuple[Any, ...]] = None, verify_convert: bool = False, ): """ Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern with fake quantizes inserted into the correct places. # TODO: also verify that metadata is copied over to the new nodes. 
""" m = copy.deepcopy(m) quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel, is_qat=True) ) m = capture_pre_autograd_graph( m, example_inputs, ) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) # Verify: getitem output activation fake quantize output_node = list(m.graph.nodes)[-1] output_fq_node = output_node.args[0][0] self.assertTrue(output_fq_node.target.startswith("activation_post_process_")) output_fq_mod = getattr(m, output_fq_node.target) self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize) self.assertEqual( type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver ) self.assertEqual(output_fq_mod.dtype, torch.int8) self.assertEqual(output_fq_mod.quant_min, -128) self.assertEqual(output_fq_mod.quant_max, 127) # Verify: getitem(bn, 0) or relu(getitem(bn, 0)) if has_relu: relu_node = output_fq_node.args[0] getitem_node = relu_node.args[0] self.assertEqual(relu_node.target, torch.ops.aten.relu.default) else: relu_node = None getitem_node = output_fq_node.args[0] is_training_ir_flag = capture_pre_autograd_graph_using_training_ir() if is_training_ir_flag: # The relu node takes in the output of bn. # See NOTE [training ir has no getitem for bn node]. bn_node = getitem_node self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default) else: # TODO: This branch is going through a deprecated branch and should be deleted soon, # after capture_pre_autograd_graph fully migrate to training IR # T199018392 self.assertEqual(getitem_node.target, operator.getitem) bn_node = getitem_node.args[0] expected_bn_op = None if is_cuda: if torch.version.cuda is not None: expected_bn_op = torch.ops.aten.cudnn_batch_norm.default elif torch.version.hip is not None: expected_bn_op = torch.ops.aten.miopen_batch_norm.default else: expected_bn_op = torch.ops.aten._native_batch_norm_legit.default self.assertEqual(bn_node.target, expected_bn_op) # Verify: conv / scale_factor.reshape [+ bias.reshape] if has_bias: add_bias_node = bn_node.args[0] (div_scale_factor_node, bias_reshape_node) = add_bias_node.args self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor) self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default) else: div_scale_factor_node = bn_node.args[0] (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args conv_op = conv_node.target self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor) self.assertTrue(_is_conv_node(conv_node)) self.assertEqual( scale_factor_reshape_node.target, torch.ops.aten.reshape.default ) # Verify: conv literal args if expected_conv_literal_args is not None: assert ( len(expected_conv_literal_args) == 6 ), "wrong num conv args, bad test setup" for i in range(6): if i + 3 < len(conv_node.args): self.assertEqual( conv_node.args[i + 3], expected_conv_literal_args[i] ) # Verify: conv input activation fake quantize conv_input_fq_node = conv_node.args[0] conv_input_node = conv_input_fq_node.args[0] self.assertTrue( conv_input_fq_node.target.startswith("activation_post_process_") ) conv_input_fq_mod = getattr(m, conv_input_fq_node.target) self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize) self.assertEqual( type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver ) self.assertEqual(conv_input_fq_mod.dtype, torch.int8) self.assertEqual(conv_input_fq_mod.quant_min, -128) self.assertEqual(conv_input_fq_mod.quant_max, 127) self.assertTrue(conv_input_node.op, "placeholder") # Verify: conv weight fake 
        conv_weight_fq_node = conv_node.args[1]
        self.assertTrue(
            conv_weight_fq_node.target.startswith("activation_post_process_")
        )
        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
        if is_per_channel:
            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
        else:
            expected_weight_observer_type = MovingAverageMinMaxObserver
        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_weight_fq_mod.activation_post_process),
            expected_weight_observer_type,
        )
        self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
        self.assertEqual(conv_weight_fq_mod.quant_max, 127)

        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
        zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
        (
            conv_weight_fq_node,
            scale_factor_reshape_node,
        ) = mul_weight_scale_factor_node.args
        if has_bias:
            self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
        else:
            self.assertTrue(zero_bias_node is None)
        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
        scale_factor_node = scale_factor_reshape_node.args[0]
        (bn_weight_node, sqrt_node) = scale_factor_node.args
        bn_running_var_add_node = sqrt_node.args[0]
        (bn_running_var_node, eps) = bn_running_var_add_node.args
        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
        if is_training_ir_flag:
            self.assertTrue("bn.weight" in bn_weight_node.target)
            self.assertTrue("bn.running_var" in bn_running_var_node.target)
        else:
            self.assertTrue("bn_weight" in bn_weight_node.target)
            self.assertTrue("bn_running_var" in bn_running_var_node.target)
        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
        self.assertEqual(eps, 1e-5)

        # Optionally check the converted graph
        if verify_convert:
            m = convert_pt2e(m)
            m(*example_inputs)

            if is_per_channel:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_channel.default
                    ): 1,
                }
            else:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 3,
                }
            node_list = [
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
                ns.call_function(conv_weight_dq_op),
                ns.call_function(conv_op),
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
            ]
            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence,
            )


class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
    """
    Base TestCase to be used for all conv-bn[-relu] fusion patterns.
    """

    # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
    # Otherwise it fails with the following error:
    #   torch._dynamo.exc.InternalTorchDynamoError:
    #   'QuantizationConfig' object has no attribute '__bool__'

    def setUp(self):
        # NB: Skip the test if this is a base class, this is to handle the test
        # discovery logic in buck which finds and runs all tests here including
        # the base class which we don't want to run
        if self.id() and "_Base" in self.id():
            self.skipTest("Skipping test running from base class")

    def test_qat_conv_no_bias(self):
        m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
        m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)

    def test_qat_conv_bn_fusion(self):
        m = self._get_conv_bn_model()
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_conv_bn_fusion_cuda(self):
        m = self._get_conv_bn_model().cuda()
        example_inputs = (self.example_inputs[0].cuda(),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_literal_args(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3, stride=2, padding=4)
                self.bn = bn_class(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2,), (4,), (1,), False, (0,), 1)
            example_inputs = (torch.randn(1, 3, 5),)
        else:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
            example_inputs = (torch.randn(1, 3, 5, 5),)
        m = M(self.conv_class, self.bn_class)

        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            expected_conv_literal_args=conv_args,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_no_conv_bias(self):
        class M2(torch.nn.Module):
            """
            Mixed conv + BN with and without conv bias.
""" def __init__(self, conv_class, bn_class): super().__init__() self.conv1 = conv_class(3, 3, 3, bias=False) self.bn1 = bn_class(3) self.conv2 = conv_class(3, 3, 3, bias=True) self.bn2 = bn_class(3) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.conv2(x) x = self.bn2(x) return x m1 = self._get_conv_bn_model(has_conv_bias=False) m2 = M2(self.conv_class, self.bn_class) assert self.dim in [1, 2] if self.dim == 1: example_inputs = (torch.randn(3, 3, 5),) else: example_inputs = (torch.randn(3, 3, 5, 5),) self._verify_symmetric_xnnpack_qat_graph( m1, example_inputs, has_relu=False, has_bias=False, ) self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs) self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs) def test_qat_conv_bn_relu_fusion(self): m = self._get_conv_bn_model(has_relu=True) self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_qat_conv_bn_relu_fusion_cuda(self): m = self._get_conv_bn_model(has_relu=True).cuda() example_inputs = (self.example_inputs[0].cuda(),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, has_relu=True, is_cuda=True, ) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) def test_qat_conv_bn_relu_fusion_no_conv_bias(self): m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True) self._verify_symmetric_xnnpack_qat_graph( m, self.example_inputs, has_relu=True, has_bias=False, ) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) def test_qat_inplace_add_relu(self): class M(torch.nn.Module): def __init__(self, conv_class): super().__init__() self.conv = conv_class(1, 1, 1) self.relu = torch.nn.ReLU(inplace=True) def forward(self, x): x0 = x x = self.conv(x) x += x0 x = self.relu(x) return x assert self.dim in [1, 2] if self.dim == 1: example_inputs = (torch.randn(1, 1, 3),) else: example_inputs = (torch.randn(1, 1, 3, 3),) m = M(self.conv_class) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self): """ Test the case where the placeholder node for the [conv - bn - getitem] pattern is also a getitem node: some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem We want the metadata to be copied from the `conv_bn_getitem` node, not from the `unrelated_getitem` node, which is not part of the conv-bn pattern but is returned as part of the match anyway (as a placeholder). """ class M(torch.nn.Module): def __init__(self, conv_class, bn_class): super().__init__() self.bn1 = bn_class(3) self.conv = conv_class(3, 3, 3) self.bn2 = bn_class(3) def forward(self, x): x = self.bn1(x) x = self.conv(x) x = self.bn2(x) return x def _get_getitem_nodes(m: torch.fx.GraphModule): """ Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph. 
""" unrelated_getitem_node, conv_bn_getitem_node = None, None for node in m.graph.nodes: if ( node.target != operator.getitem or node.args[0].target != torch.ops.aten._native_batch_norm_legit.default ): continue if node.args[0].args[0].op == "placeholder": unrelated_getitem_node = node else: conv_bn_getitem_node = node assert ( unrelated_getitem_node is not None ), "did not find unrelated getitem node, bad test setup" assert ( conv_bn_getitem_node is not None ), "did not find conv bn getitem node, bad test setup" return (unrelated_getitem_node, conv_bn_getitem_node) # Program capture m = M(self.conv_class, self.bn_class) m = capture_pre_autograd_graph(m, self.example_inputs) m.graph.eliminate_dead_code() m.recompile() (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m) # Prepare QAT quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True) ) m = prepare_qat_pt2e(m, quantizer) (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m) # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem` original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[ "quantization_annotation" ] conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"] self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta) self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta) def test_qat_update_shared_qspec(self): """ Test the case where nodes used in SharedQuantizationSpec were replaced during QAT subgraph rewriting. """ class M(torch.nn.Module): def __init__(self, conv_class, bn_class): super().__init__() self.conv = conv_class(3, 3, 3) self.bn = bn_class(3) self.hardtanh = torch.nn.Hardtanh() def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.hardtanh(x) return x m = M(self.conv_class, self.bn_class) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) def test_qat_preserve_source_fn_stack(self): """ Test whether `source_fn_stack` is preserved after QAT fusion. 
""" class M(torch.nn.Module): def __init__(self, conv_class, bn_class, backbone): super().__init__() self.conv = conv_class(5, 3, 3) self.bn = bn_class(3) self.relu = torch.nn.ReLU() self.backbone = backbone def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.backbone(x) return x assert self.dim in [1, 2] if self.dim == 1: example_inputs = (torch.randn(1, 5, 10),) else: example_inputs = (torch.randn(1, 5, 10, 10),) # QAT prepare + convert backbone = self._get_conv_bn_model(has_relu=True) m = M(self.conv_class, self.bn_class, backbone) quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) m = capture_pre_autograd_graph(m, example_inputs) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) # Extract the conv and relu nodes (bn was folded into conv) first_conv, first_relu, second_conv, second_relu = None, None, None, None for n in m.graph.nodes: if n.target == torch.ops.aten.relu.default: if first_relu is None: assert first_conv is None, "bad test setup" first_relu = n first_conv = n.args[0] else: assert second_conv is None, "bad test setup" second_relu = n second_conv = n.args[0] # Extract the conv weight and bias nodes def get_conv_weight_and_bias(conv_node: torch.fx.Node): weight_dq_node = conv_node.args[1] qweight_node = weight_dq_node.args[0] bias_node = conv_node.args[2] assert isinstance(qweight_node, torch.fx.Node) assert isinstance(bias_node, torch.fx.Node) return (qweight_node, bias_node) first_conv_qweight, first_conv_bias = get_conv_weight_and_bias(first_conv) second_conv_qweight, second_conv_bias = get_conv_weight_and_bias(second_conv) # Assert that each set of conv, conv weight, and conv bias are in the same partition def get_source_fn(node: torch.fx.Node): # E.g. 
            return node.meta["source_fn_stack"][0][0]

        # We don't preserve this for the quantized weight currently since it's folded,
        # but users can attach "quantization_tag" to the node and it will be preserved
        # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
        # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))

        self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
        self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))

        # Assert that different sets of convs and relus have different partitions
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
        self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
        self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))

        # Assert that "backbone" exists only in the second set of conv and relu's partition
        self.assertTrue("backbone" not in get_source_fn(first_conv))
        self.assertTrue("backbone" not in get_source_fn(first_relu))
        self.assertTrue("backbone" in get_source_fn(second_conv))
        self.assertTrue("backbone" in get_source_fn(second_relu))

    def test_qat_conv_bn_bias_derived_qspec(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnDerivedBiasQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that both weight and bias are quantized
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        bias_dq = conv_node.args[2]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        self.assertEqual(
            bias_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        weight_getattr = weight_dq.args[0]
        bias_getattr = bias_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )
        self.assertEqual(
            bias_getattr.op,
            "get_attr",
        )

        # Assert that bias scale = weight scale * input scale
        input_dq = conv_node.args[0]
        input_scale = input_dq.args[1]
        bias_scale = bias_dq.args[1]
        weight_scale = weight_dq.args[1]
        self.assertEqual(bias_scale, input_scale * weight_scale)

        # Assert that args for the bias' quantize and dequantize ops
        # are copied correctly after subgraph rewriting
        (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
        self.assertEqual(bias_qmin, -(2**31))
        self.assertEqual(bias_qmax, 2**31 - 1)
        self.assertEqual(bias_dtype, torch.int32)

    def test_qat_per_channel_weight_custom_dtype(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnInt32WeightQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that conv weight is quantized per channel
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        )
        weight_getattr = weight_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )

        # Assert that args for the weight's dequantize ops
        # are copied correctly after subgraph rewriting
        (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
        self.assertEqual(dq_axis, 0)
        self.assertEqual(dq_qmin, 0)
        self.assertEqual(dq_qmax, 2**31 - 1)
        self.assertEqual(dq_dtype, torch.int32)

    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
        # Use different in/out channel sizes to test if conv weight is
        # properly transposed in QAT pattern
        m = self._get_conv_bn_model(
            has_relu=has_relu,
            transpose=True,
            in_channels=3,
            out_channels=5,
            kernel_size=3,
        )
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=has_relu,
            verify_convert=True,
        )

    def test_qat_conv_transpose_bn(self):
        self._do_test_qat_conv_transpose_bn(has_relu=False)

    def test_qat_conv_transpose_bn_relu(self):
        self._do_test_qat_conv_transpose_bn(has_relu=True)

    def test_qat_conv_bn_per_channel_weight_bias(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Expected graph:
        #      x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
        # weight -> q_channel -> dq_channel /
        #   bias -> q_channel -> dq_channel /

        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        conv_op = conv_node.target
        conv_weight_dq_op = (
            torch.ops.quantized_decomposed.dequantize_per_channel.default
        )
        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_op),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list,
            expected_node_occurrence=node_occurrence,
        )

    def test_fold_bn_erases_bn_node(self):
        """
        Ensure the BN node is erased from the graph after folding it into conv
        in `convert_pt2e` even in train mode.
        """
        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = capture_pre_autograd_graph(m, self.example_inputs)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)
        (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
        self.assertTrue(conv_node is not None)
        self.assertTrue(bn_node is None)


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 1
    example_inputs = (torch.randn(1, 3, 5),)
    conv_class = torch.nn.Conv1d
    conv_transpose_class = torch.nn.ConvTranspose1d
    bn_class = torch.nn.BatchNorm1d


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 2
    example_inputs = (torch.randn(1, 3, 5, 5),)
    conv_class = torch.nn.Conv2d
    conv_transpose_class = torch.nn.ConvTranspose2d
    bn_class = torch.nn.BatchNorm2d


def _is_conv_node(n: torch.fx.Node):
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d,
        torch.ops.aten.conv_transpose2d.input,
    ]


def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
    """
    Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
""" model.graph.eliminate_dead_code() model.recompile() conv_node = None bn_node = None getitem_node = None for n in model.graph.nodes: if _is_conv_node(n): conv_node = n if n.target in ( torch.ops.aten._native_batch_norm_legit.default, torch.ops.aten.batch_norm.default, ): bn_node = n if n.target == operator.getitem: getitem_node = n assert conv_node is not None, "bad test setup" return (conv_node, bn_node, getitem_node) class ConvBnInt32WeightQuantizer(Quantizer): """ Dummy quantizer that annotates conv bn in such a way that the weights are quantized per channel to int32. """ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model) act_qspec = QuantizationSpec( dtype=torch.uint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, observer_or_fake_quant_ctr=default_fake_quant, ) weight_qspec = QuantizationSpec( dtype=torch.int32, quant_min=0, quant_max=2**31 - 1, qscheme=torch.per_channel_affine, observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAveragePerChannelMinMaxObserver, ), ) conv_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ conv_node.args[0]: act_qspec, conv_node.args[1]: weight_qspec, }, _annotated=True, ) if getitem_node is not None: # TODO: This branch is going through a deprecated branch and should be deleted soon, # after capture_pre_autograd_graph fully migrate to training IR # T199018392 getitem_node.meta["quantization_annotation"] = QuantizationAnnotation( output_qspec=act_qspec, _annotated=True, ) else: # See NOTE [training ir has no getitem for bn node]. assert capture_pre_autograd_graph_using_training_ir() bn_node.meta["quantization_annotation"] = QuantizationAnnotation( output_qspec=act_qspec, _annotated=True, ) return model def validate(self, model: torch.fx.GraphModule): pass class ConvBnDerivedBiasQuantizer(Quantizer): """ Dummy quantizer that annotates conv bn in such a way that the bias qparams are derived from the conv input activation and weight qparams. 
""" def __init__(self, is_per_channel: bool = False): super().__init__() self.is_per_channel = is_per_channel def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs): act_scale, _ = obs_or_fqs[0].calculate_qparams() weight_scale, _ = obs_or_fqs[1].calculate_qparams() if self.is_per_channel: bias_scale = act_scale * weight_scale bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32) else: bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32) bias_zero_point = torch.tensor([0], dtype=torch.int32) return bias_scale, bias_zero_point def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if self.is_per_channel: weight_qscheme = torch.per_channel_symmetric weight_fq = FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAveragePerChannelMinMaxObserver, ) else: weight_qscheme = torch.per_tensor_affine weight_fq = default_fake_quant conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model) act_qspec = QuantizationSpec( dtype=torch.uint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, observer_or_fake_quant_ctr=default_fake_quant, ) weight_qspec = QuantizationSpec( dtype=torch.uint8, quant_min=0, quant_max=255, qscheme=weight_qscheme, observer_or_fake_quant_ctr=weight_fq, ) bias_qspec = DerivedQuantizationSpec( derived_from=[ (conv_node.args[0], conv_node), (conv_node.args[1], conv_node), ], derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams, dtype=torch.int32, quant_min=-(2**31), quant_max=2**31 - 1, qscheme=weight_qscheme, ch_axis=0 if self.is_per_channel else None, ) conv_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ conv_node.args[0]: act_qspec, conv_node.args[1]: weight_qspec, conv_node.args[2]: bias_qspec, }, _annotated=True, ) if getitem_node is not None: # TODO: This branch is going through a deprecated branch and should be deleted soon, # after capture_pre_autograd_graph fully migrate to training IR # T199018392 getitem_node.meta["quantization_annotation"] = QuantizationAnnotation( output_qspec=act_qspec, _annotated=True, ) else: # NOTE [training ir has no getitem for bn node]. # getitem is None when we use the training IR. It outputs # aten.batch_norm.default, which do not need any getitem node. # In this case, we need to annotate on the batch norm node. # geteitem node should only be None if we are using training IR. 
            assert capture_pre_autograd_graph_using_training_ir()
            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=act_qspec,
                _annotated=True,
            )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


@skipIfNoQNNPACK
class TestQuantizePT2EQATModels(PT2EQATTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_resnet18(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_mobilenet_v2(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.mobilenet_v2()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)


class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
    class TwoLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class QATPTQTestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
            self.my_linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            linear_out = self.linears(permute_out)
            my_linear_out = self.my_linear(linear_out)
            # Hardtanh doesn't get quantized via the XNNPACK quantizer in this test
            # because it relies on the propagation rules. Need to fix this.
            return torch.nn.functional.hardtanh(my_linear_out)

    def _prepare_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
                if isinstance(child, torch.nn.Linear):
                    in_channels = child.weight.size(1)
                else:
                    in_channels = child.linear1.weight.size(1)

                example_input = (torch.rand((1, in_channels)),)
                traced_child = capture_pre_autograd_graph(child, example_input)
                quantizer = XNNPACKQuantizer()
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True, is_qat=True
                )
                quantizer.set_global(quantization_config)
                traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
                setattr(model, name, traced_child_prepared)
            else:
                self._prepare_qat_linears(child)

    def _convert_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, torch.fx.GraphModule):
                torch.ao.quantization.move_exported_model_to_eval(child)
                converted_child = convert_pt2e(child)
                setattr(model, name, converted_child)
            else:
                self._convert_qat_linears(child)

    def test_mixing_qat_ptq(self):
        example_inputs = (torch.randn(2, 3, 4, 4),)
        model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()

        self._prepare_qat_linears(model)

        after_prepare_result_pt2e = model(*example_inputs)  # must be fixed
        model.eval()
        self._convert_qat_linears(model)
        quant_result_pt2e = model(*example_inputs)

        model_pt2e = capture_pre_autograd_graph(
            model,
            example_inputs,
        )

        quantizer = XNNPACKQuantizer()
        quantizer.set_module_type(torch.nn.Linear, None)
        quantization_config = get_symmetric_quantization_config()
        quantizer.set_global(quantization_config)
        model_pt2e = prepare_pt2e(model_pt2e, quantizer)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)
        model_pt2e = convert_pt2e(model_pt2e)
        quant_result_pt2e = model_pt2e(*example_inputs)

        exported_model = torch.export.export(model_pt2e, example_inputs)
        node_occurrence = {
            # conv2d: 1 for act, 1 for weight, 1 for output
            # 3 x linear: 1 for act, 1 for output
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 8,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 3,
            # There needs to be one for hardtanh
        }
        self.checkGraphModuleNodes(
            exported_model.graph_module, expected_node_occurrence=node_occurrence
        )