# Owner(s): ["oncall: quantization"] import copy import math import operator import unittest import torch import torch.nn as nn import torch.nn.functional as F from torch.ao.quantization import ( default_dynamic_qconfig, QConfigMapping, get_default_qconfig_mapping, ) import torch.ao.nn.quantized as nnq toq = torch.ops.quantized from torch.ao.quantization.quantize_fx import ( convert_fx, convert_to_reference_fx, prepare_fx, prepare_qat_fx, ) from torch.testing._internal.common_quantization import ( ConvBnModel, ConvBnReLUModel, ConvModel, QuantizationTestCase, skipIfNoFBGEMM, skipIfNoQNNPACK, withQNNPACKBackend, SingleLayerLinearDynamicModel, SingleLayerLinearModel, LSTMwithHiddenDynamicModel, SparseNNModel, skip_if_no_torchvision, TwoLayerLinearModel ) from torch.testing._internal.common_utils import skipIfTorchDynamo from torch.ao.quantization.quantization_mappings import ( get_default_static_quant_module_mappings, get_default_dynamic_quant_module_mappings, get_default_float_to_quantized_operator_mappings, ) from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import NodeSpec as ns from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns import torch.ao.quantization.fx.quantize_handler as qh from torch.ao.ns.fx.pattern_utils import ( get_type_a_related_to_b, ) from torch.ao.ns.fx.graph_matcher import ( get_matching_subgraph_pairs, GraphMatchingException, ) from torch.ao.ns.fx.utils import ( compute_sqnr, compute_normalized_l2_error, compute_cosine_similarity, ) from torch.ao.ns.fx.mappings import ( get_node_type_to_io_type_map, get_unmatchable_types_map, get_base_name_to_sets_of_related_ops, get_base_name_for_op, add_op_to_sets_of_related_ops, ) from torch.ao.ns.fx.weight_utils import ( get_op_to_type_to_weight_extraction_fn, ) from torch.ao.ns._numeric_suite_fx import ( extract_weights, _extract_weights_impl, add_loggers, _add_loggers_impl, OutputLogger, add_shadow_loggers, 
_add_shadow_loggers_impl, extract_logger_info, extract_shadow_logger_info, extend_logger_results_with_comparison, prepare_n_shadows_model, convert_n_shadows_model, extract_results_n_shadows_model, OutputComparisonLogger, print_comparisons_n_shadows_model, loggers_set_enabled, loggers_set_save_activations, _prepare_n_shadows_add_loggers_model, _n_shadows_compare_weights, ) from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping from torch.ao.quantization.backend_config import get_native_backend_config from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers # Note: these models are not for use outside of this file. While it's good # to reuse code, we also need to be able to iterate on tests # quickly when debugging. If a test model has a large number of callsites # across various different files, speed of debugging on individual test cases # decreases. class LinearReluFunctional(nn.Module): def __init__(self) -> None: super().__init__() self.w1 = nn.Parameter(torch.empty(4, 4)) self.b1 = nn.Parameter(torch.zeros(4)) torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) def forward(self, x): x = F.linear(x, self.w1, self.b1) x = F.relu(x) return x class LinearFunctional(nn.Module): def __init__(self) -> None: super().__init__() self.w1 = nn.Parameter(torch.empty(4, 4)) self.b1 = nn.Parameter(torch.zeros(4)) torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) def forward(self, x): x = F.linear(x, self.w1, self.b1) return x class LinearReluLinearFunctional(nn.Module): def __init__(self) -> None: super().__init__() self.w = nn.Parameter(torch.Tensor(4, 4)) self.b = nn.Parameter(torch.zeros(4)) torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5)) def forward(self, x): x = F.linear(x, self.w, self.b) x = F.relu(x) x = F.linear(x, self.w, self.b) return x class AddMulFunctional(nn.Module): def forward(self, x, y): x = x + 1.0 x = x * 1.0 x = 1.0 + x x = 1.0 * x x = x + y x = x * y return x class 
AllConvAndLinearFusionModules(torch.nn.Module): def __init__(self) -> None: super().__init__() # conv1d self.conv1d_0 = nn.Conv1d(1, 1, 1) # conv1d - relu self.conv1d_1 = nn.Conv1d(1, 1, 1) self.relu_0 = nn.ReLU() # conv1d - bn (qat only) self.conv1d_2 = nn.Conv1d(1, 1, 1) self.bn1d_0 = nn.BatchNorm1d(1) # conv1d - bn - relu (qat only) self.conv1d_3 = nn.Conv1d(1, 1, 1) self.bn1d_1 = nn.BatchNorm1d(1) self.relu_4 = nn.ReLU() # conv2d self.conv2d_0 = nn.Conv2d(1, 1, 1) # conv2d - relu self.conv2d_1 = nn.Conv2d(1, 1, 1) self.relu_1 = nn.ReLU() # conv2d - bn (qat only) self.conv2d_2 = nn.Conv2d(1, 1, 1) self.bn2d_0 = nn.BatchNorm2d(1) # conv2d - bn - relu (qat only) self.conv2d_3 = nn.Conv2d(1, 1, 1) self.bn2d_1 = nn.BatchNorm2d(1) self.relu_5 = nn.ReLU() # conv3d self.conv3d_0 = nn.Conv3d(1, 1, 1) # conv3d - relu self.conv3d_1 = nn.Conv3d(1, 1, 1) self.relu_2 = nn.ReLU() # conv3d - bn (qat only) self.conv3d_2 = nn.Conv3d(1, 1, 1) self.bn3d_0 = nn.BatchNorm3d(1) # conv3d - bn - relu (qat only) self.conv3d_3 = nn.Conv3d(1, 1, 1) self.bn3d_1 = nn.BatchNorm3d(1) self.relu_6 = nn.ReLU() # linear self.linear_0 = nn.Linear(1, 1) # linear - relu self.linear_1 = nn.Linear(1, 1) self.relu_3 = nn.ReLU() def forward(self, x): # conv1d x = self.conv1d_0(x) x = self.conv1d_1(x) x = self.relu_0(x) x = self.conv1d_2(x) x = self.bn1d_0(x) x = self.conv1d_3(x) x = self.bn1d_1(x) x = self.relu_4(x) # conv2d x = x.reshape(1, 1, 1, 1) x = self.conv2d_0(x) x = self.conv2d_1(x) x = self.relu_1(x) x = self.conv2d_2(x) x = self.bn2d_0(x) x = self.conv2d_3(x) x = self.bn2d_1(x) x = self.relu_5(x) # conv3d x = x.reshape(1, 1, 1, 1, 1) x = self.conv3d_0(x) x = self.conv3d_1(x) x = self.relu_2(x) x = self.conv3d_2(x) x = self.bn3d_0(x) x = self.conv3d_3(x) x = self.bn3d_1(x) x = self.relu_6(x) # linear x = x.reshape(1, 1) x = self.linear_0(x) x = self.linear_1(x) x = self.relu_3(x) return x class AllConvFunctional(torch.nn.Module): def __init__(self, weight1d, weight2d, weight3d, bias1d, bias2d, 
bias3d): super().__init__() self.weight1d = torch.nn.Parameter(weight1d) self.weight2d = torch.nn.Parameter(weight2d) self.weight3d = torch.nn.Parameter(weight3d) self.bias1d = torch.nn.Parameter(bias1d) self.bias2d = torch.nn.Parameter(bias2d) self.bias3d = torch.nn.Parameter(bias3d) self.stride1d = 1 self.padding1d = 0 self.dilation1d = 1 self.stride2d = (1, 1) self.padding2d = (0, 0) self.dilation2d = (1, 1) self.groups = 1 self.stride3d = (1, 1, 1) self.padding3d = (0, 0, 0) self.dilation3d = (1, 1, 1) def forward(self, x): x = F.conv1d( x, self.weight1d, self.bias1d, self.stride1d, self.padding1d, self.dilation1d, self.groups) x = F.conv1d( x, self.weight1d, self.bias1d, self.stride1d, self.padding1d, self.dilation1d, self.groups) x = F.relu(x) x = F.conv2d( x, self.weight2d, self.bias2d, self.stride2d, self.padding2d, self.dilation2d, self.groups) x = F.conv2d( x, self.weight2d, self.bias2d, self.stride2d, self.padding2d, self.dilation2d, self.groups) x = F.relu(x) x = F.conv3d( x, self.weight3d, self.bias3d, self.stride3d, self.padding3d, self.dilation3d, self.groups) x = F.conv3d( x, self.weight3d, self.bias3d, self.stride3d, self.padding3d, self.dilation3d, self.groups) x = F.relu(x) return x @torch.fx.wrap def _wrapped_hardswish(x): return F.hardswish(x) @torch.fx.wrap def _wrapped_hardswish_fp16(x): x = x.dequantize() x = F.hardswish(x) x = x.to(torch.float16) return x @torch.fx.wrap def _wrapped_sigmoid(x): return F.sigmoid(x) @torch.fx.wrap def _wrapped_linear(x, w, b): return F.linear(x, w, b) def get_all_quant_patterns(): """ we are in the process to migrate the frontend of fx graph mode quant to use backend_config_dict, so some of the patterns are moved to backend_config_dict this function will include these patterns so that we can still have all the patterns """ # TODO: we can remove this call, and get all patterns from backend_config_dict in # the future when the frontend refactor is done in fx graph mode quantization all_quant_patterns = 
get_default_quant_patterns() # some of the patterns are moved to (native) backend_config_dict so we need to # add them back here for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items(): all_quant_patterns[pattern] = quantize_handler return all_quant_patterns class TestFXGraphMatcher(QuantizationTestCase): @skipIfNoFBGEMM def test_simple_mod(self): m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() conv_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, nn.Conv2d) + '_0' expected_types = { conv_name_0: ((nn.Conv2d, torch.ao.quantization.MinMaxObserver), (nnq.Conv2d, nnq.Conv2d)), } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) @skipIfNoFBGEMM def test_simple_fun(self): class M(nn.Module): def __init__(self) -> None: super().__init__() self.w = nn.Parameter(torch.empty(1, 4)) self.b = nn.Parameter(torch.zeros(1)) torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5)) def forward(self, x): return F.linear(x, self.w, self.b) m = M().eval() mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() linear_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, F.linear) + '_0' expected_types = { linear_name_0: ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear, toq.linear)) } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) @skipIfNoFBGEMM def test_simple_fusion(self): m = LinearReluFunctional().eval() mp = 
prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(4, 4),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() linear_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, F.linear) + '_0' expected_types = { linear_name_0: ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear_relu, toq.linear_relu)), } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) @skipIfNoFBGEMM def test_simple_mod_multi(self): m = nn.Sequential( nn.Sequential( nn.Conv2d(1, 1, 1), ), nn.Conv2d(1, 1, 1), ).eval() mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions results = get_matching_subgraph_pairs(mp, mq) @skipIfNoFBGEMM def test_simple_tensor_ops(self): class M(nn.Module): def forward(self, x, y): z = x + y return z m = M().eval() example_inputs = (torch.randn(1), torch.randn(1)) mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions results = get_matching_subgraph_pairs(mp, mq) @skipIfNoFBGEMM def test_matching_failure_node_count(self): # verify that matching graphs with matching node types but # different counts of matchable nodes fails m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() example_inputs = (torch.randn(1, 1, 1, 1),) mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) with self.assertRaises(GraphMatchingException) as ex: results = get_matching_subgraph_pairs(mp1, mp2) @skipIfNoFBGEMM def 
test_matching_failure_node_type(self): # verify that matching graphs with non-matching node types fails m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() m2 = nn.Sequential(nn.Linear(1, 1)).eval() example_inputs = (torch.randn(1, 1, 1, 1),) mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) example_inputs = (torch.randn(1, 1),) mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) with self.assertRaises(GraphMatchingException) as ex: results = get_matching_subgraph_pairs(mp1, mp2) @skipIfNoFBGEMM def test_nodes_before_cat(self): # verify that nodes before cat get matched class M(nn.Module): def forward(self, x0): x1 = torch.add(x0, 1.0) y1 = torch.add(x0, 1.0) x2 = torch.cat([x1, y1]) return x2 m = M().eval() example_inputs = (torch.randn(1),) mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() cat_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.cat) + '_0' add_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.add) + '_0' add_name_1 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.add) + '_1' expected_types = { cat_name_0: ((torch.cat, torch.cat), (torch.cat, torch.cat)), add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)), add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)), } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) @skipIfNoFBGEMM def test_dict_return_type(self): # verify that we can traverse up nodes which return dictionaries class M(nn.Module): def forward(self, x0): x1 = torch.add(x0, 1.0) y1 = torch.add(x0, 1.0) z1 = torch.add(x0, 1.0) a1 = {'x1': x1, 'y1': (y1,), 
'z1': [{'key': (z1,)}]} return a1 m = M().eval() example_inputs = (torch.randn(1),) mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() add_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.add) + '_0' add_name_1 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.add) + '_1' add_name_2 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.add) + '_2' expected_types = { add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)), add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)), add_name_2: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)), } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) @skipIfNoFBGEMM def test_nodes_with_equal_types_get_matched(self): class M(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = torch.mul(x, x) x = torch.sigmoid(x) x = F.relu(x) return x m = M().eval() # prevent conv2 from getting quantized, so we can test # modules with equal types qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping().set_module_name("conv2", None) example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() conv_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, nn.Conv2d) + '_0' conv_name_1 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, 
nn.Conv2d) + '_1' mul_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.mul) + '_0' relu_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.relu) + '_0' sigmoid_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.sigmoid) + '_0' # all of these should be matched expected_types = { conv_name_1: ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nnq.Conv2d, nnq.Conv2d)), conv_name_0: ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nn.Conv2d, nn.Conv2d)), mul_name_0: ((torch.mul, torch.ao.quantization.HistogramObserver), (toq.mul, toq.mul)), relu_name_0: ((F.relu, torch.ao.quantization.FixedQParamsObserver), (F.relu, F.relu)), sigmoid_name_0: ((torch.sigmoid, torch.ao.quantization.FixedQParamsObserver), (torch.sigmoid, torch.sigmoid)), } self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq) def test_methods(self): """ Verify that graph matching works on methods """ class M(nn.Module): def forward(self, x): x = x.sigmoid() return x m1 = M().eval() m2 = M().eval() qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1),) m1p = prepare_fx(m1, qconfig_mapping, example_inputs=example_inputs) m2p = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs) results = get_matching_subgraph_pairs(m1p, m2p) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() sigmoid_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, torch.sigmoid) + '_0' expected_types = { sigmoid_name_0: (('sigmoid', torch.ao.quantization.FixedQParamsObserver), ('sigmoid', torch.ao.quantization.FixedQParamsObserver)), } self.assert_types_for_matched_subgraph_pairs( results, expected_types, m1p, m2p) def test_op_relationship_mapping(self): """ Tests that the mapping of op relationships is complete. 
""" base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() type_a_related_to_b = \ get_type_a_related_to_b(base_name_to_sets_of_related_ops) # 1. check static quant module mappings static_quant_mod_mappings = get_default_static_quant_module_mappings() for fp32_type, int8_type in static_quant_mod_mappings.items(): # skip quants and dequants, for the purposes of Numerical Suite types_to_skip = ( torch.ao.quantization.QuantStub, torch.ao.quantization.DeQuantStub, nnq.FloatFunctional, # the ConvTranspose3d swap is not implemented in FX Graph # mode quantization yet nn.ConvTranspose3d, # the GroupNorm swap is not implemented in FX Graph # mode quantization yet nn.GroupNorm, # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can # take quantized inputs nn.ReLU6, ) if fp32_type in types_to_skip: continue # verify relatedness in_type_a_related_to_b = \ (fp32_type, int8_type) in type_a_related_to_b self.assertTrue( in_type_a_related_to_b, f"{fp32_type} and {int8_type} need a relationship mapping") # 2. check static quant op mappings static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings() for fp32_type, int8_type in static_quant_fun_mappings.items(): # verify relatedness in_type_a_related_to_b = \ (fp32_type, int8_type) in type_a_related_to_b self.assertTrue( in_type_a_related_to_b, f"{fp32_type} and {int8_type} need a relationship mapping") # 3. check dynamic quant mappings dynamic_quant_mappings = get_default_dynamic_quant_module_mappings() for fp32_type, int8_type in dynamic_quant_mappings.items(): # TODO(future PR): enable correct weight extraction for these # and remove from this list. types_to_skip = ( nn.GRUCell, nn.GRU, nn.LSTMCell, nn.RNNCell, ) if fp32_type in types_to_skip: continue # verify relatedness in_type_a_related_to_b = \ (fp32_type, int8_type) in type_a_related_to_b self.assertTrue( in_type_a_related_to_b, f"{fp32_type} and {int8_type} need a relationship mapping") # 4. 
go through the ops mapped to each QuantizeHandler type, and verify # correctness. def _op_in_base_sets_of_related_ops(op): for ops in base_name_to_sets_of_related_ops.values(): if op in ops: return True return False unmatchable_types_map = get_unmatchable_types_map() FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable'] MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable'] METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable'] def _op_is_unmatchable(op): return ( op in FUNS_UNMATCHABLE or op in MODS_UNMATCHABLE or op in METHS_UNMATCHABLE ) default_quant_patterns = get_all_quant_patterns() for pattern, qhandler_cls in default_quant_patterns.items(): base_op = None if isinstance(pattern, tuple): base_op = pattern[-1] elif isinstance(pattern, str): base_op = pattern else: base_op = pattern qhandler_cls_all_ops_quantizeable = [ qh.CatQuantizeHandler, qh.ConvReluQuantizeHandler, qh.LinearReLUQuantizeHandler, qh.BatchNormQuantizeHandler, qh.EmbeddingQuantizeHandler, qh.RNNDynamicQuantizeHandler, ] qhandler_cls_quant_op_same_signature = [ qh.FixedQParamsOpQuantizeHandler, qh.CopyNodeQuantizeHandler, qh.GeneralTensorShapeOpQuantizeHandler, ] if qhandler_cls == qh.BinaryOpQuantizeHandler: # these ops do not have quantized equivalents ops_to_skip = [ torch.bmm, torch.div, torch.sub, operator.truediv, operator.sub ] if base_op in ops_to_skip: continue self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") elif qhandler_cls == qh.RNNDynamicQuantizeHandler: # TODO(future PR): add support for all classes in # RNNDynamicQuantizeHandler pass elif qhandler_cls == qh.DefaultNodeQuantizeHandler: self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") elif qhandler_cls in qhandler_cls_quant_op_same_signature: # these ops use the same op signature for fp32 and quantized # tensors self.assertTrue( _op_in_base_sets_of_related_ops(base_op) or _op_is_unmatchable(base_op), 
f"{base_op} not in sets of related ops or unmatchable") elif qhandler_cls in qhandler_cls_all_ops_quantizeable: self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") else: # torch.sum does not have quantized equivalents if base_op in [ torch.sum, nn.GRUCell, nn.GRU, nn.LSTMCell, nn.RNNCell, ]: continue if isinstance(base_op, tuple): # skip fusion patterns continue # didn't match explicit quantize handler class, we can check if the # operator is in the related op set directly if not (_op_in_base_sets_of_related_ops(base_op) or _op_is_unmatchable(base_op)): raise AssertionError( f"handling for {qhandler_cls} for op {base_op} not implemented") @skipIfNoFBGEMM def test_user_defined_function(self): """ Verify that graph matching works on user defined functions """ class M1(nn.Module): def forward(self, x): x = F.hardswish(x) return x class M2(nn.Module): def forward(self, x): x = _wrapped_hardswish(x) return x qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1, 1, 1),) m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs) m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() add_op_to_sets_of_related_ops( base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish) results = get_matching_subgraph_pairs( m1, m2, base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops) hardswish_name_0 = 'base_op_' + get_base_name_for_op( base_name_to_sets_of_related_ops, F.hardswish) + '_0' expected_types = { hardswish_name_0: ((F.hardswish, torch.ao.quantization.HistogramObserver), (_wrapped_hardswish, _wrapped_hardswish)), } self.assert_types_for_matched_subgraph_pairs( results, expected_types, m1, m2) @skipIfNoFBGEMM def test_results_order(self): m = nn.Sequential( nn.Conv2d(1, 1, 1), nn.Linear(1, 1), ).eval() example_inputs = (torch.randn(1, 1, 1, 1),) 
mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) self.assertTrue(len(results) == 2) results_iter = iter(results.items()) _, (subgraph_a_0, subgraph_b_0) = next(results_iter) self.assertTrue(subgraph_a_0.start_node.name == '_0' and subgraph_b_0.start_node.name == '_0') _, (subgraph_a_1, subgraph_b_1) = next(results_iter) self.assertTrue(subgraph_a_1.start_node.name == '_1' and subgraph_b_1.start_node.name == '_1') class TestFXGraphMatcherModels(QuantizationTestCase): @skipIfTorchDynamo("too slow") @skipIfNoFBGEMM @skip_if_no_torchvision def test_mobilenet_v2(self): # verify that mobilenetv2 graph is able to be matched import torchvision m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float() example_inputs = (torch.randn(1, 3, 224, 224),) mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) # assume success if no exceptions results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions results_mp_mq = get_matching_subgraph_pairs(mp, mq) @skipIfNoFBGEMM @skip_if_no_torchvision def test_mobilenet_v2_qat(self): # verify that mobilenetv2 graph is able to be matched import torchvision m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float() example_inputs = (torch.randn(1, 3, 224, 224),) mp = prepare_qat_fx( copy.deepcopy(m), {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, example_inputs=example_inputs) # assume success if no exceptions results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions results_mp_mq = get_matching_subgraph_pairs(mp, mq) class FXNumericSuiteQuantizationTestCase(QuantizationTestCase): def 
_test_extract_weights( self, m, example_inputs, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx ): m = torch.fx.symbolic_trace(m) if qconfig_dict is None: qconfig_dict = {'': torch.ao.quantization.default_qconfig} mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # test both the public API as well as the internal GraphModule API for extract_weights_fun in (extract_weights, _extract_weights_impl): # test both m vs mp and mp vs mq for m1, m2 in ((m, mp), (mp, mq)): results = extract_weights_fun('a', m1, 'b', m2) self.assertTrue( len(results) == results_len, f"expected len {results_len}, got len {len(results)}") self.assert_ns_compare_dict_valid(results) extend_logger_results_with_comparison( results, 'a', 'b', compute_sqnr, 'sqnr') extend_logger_results_with_comparison( results, 'a', 'b', compute_normalized_l2_error, 'l2_error') extend_logger_results_with_comparison( results, 'a', 'b', compute_cosine_similarity, 'cosine_similarity') def _test_match_activations( self, m, data, prepared_expected_node_occurrence=None, results_len=0, should_log_inputs=False, qconfig_dict=None, skip_scripting=False, prepare_fn=prepare_fx, ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() if prepare_fn == prepare_fx: m.eval() else: m.train() mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data) mp(*data) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) m_ns, mp_ns2 = add_loggers( 'a', m, 'b', copy.deepcopy(mp), OutputLogger, should_log_inputs=should_log_inputs) mp_ns, mq_ns = add_loggers( 'a', mp, 'b', mq, OutputLogger, should_log_inputs=should_log_inputs) if prepared_expected_node_occurrence: self.checkGraphModuleNodes( m_ns, expected_node_occurrence=prepared_expected_node_occurrence) self.checkGraphModuleNodes( mp_ns2, expected_node_occurrence=prepared_expected_node_occurrence) self.checkGraphModuleNodes( mp_ns, 
expected_node_occurrence=prepared_expected_node_occurrence) self.checkGraphModuleNodes( mq_ns, expected_node_occurrence=prepared_expected_node_occurrence) if not skip_scripting: m_ns = torch.jit.script(m_ns) mp_ns = torch.jit.script(mp_ns) mq_ns = torch.jit.script(mq_ns) # calibrate m_ns(*data) mp_ns2(*data) mp_ns(*data) mq_ns(*data) # check activation result correctness results = [] for m1, m2 in ((m_ns, mp_ns2), (mp_ns, mq_ns)): act_compare_dict = extract_logger_info( m1, m2, OutputLogger, 'b') self.assertTrue( len(act_compare_dict) == results_len, f"expected len {results_len}, got len {len(act_compare_dict)}") self.assert_ns_compare_dict_valid(act_compare_dict) extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr') extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error') extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_cosine_similarity, 'cosine_similarity') results.append(act_compare_dict) return results def _test_match_shadow_activations( self, m, data, prepared_expected_node_occurrence=None, results_len=None, should_log_inputs=False, qconfig_dict=None, skip_scripting=False, prepare_fn=prepare_fx, compare_fp32_vs_fp32_prepared=True, ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() if prepare_fn == prepare_fx: m.eval() else: m.train() print("qconfig_dict:", qconfig_dict) mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data) print("prepared:", mp) mp(*data) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) print("quantized:", mq) if compare_fp32_vs_fp32_prepared: m_shadows_mp = add_shadow_loggers( 'a', copy.deepcopy(m), 'b', copy.deepcopy(mp), OutputLogger, should_log_inputs=should_log_inputs) mp_shadows_mq = add_shadow_loggers( 'a', mp, 'b', mq, OutputLogger, should_log_inputs=should_log_inputs) if prepared_expected_node_occurrence: if compare_fp32_vs_fp32_prepared: 
self.checkGraphModuleNodes( m_shadows_mp, expected_node_occurrence=prepared_expected_node_occurrence) self.checkGraphModuleNodes( mp_shadows_mq, expected_node_occurrence=prepared_expected_node_occurrence) if not skip_scripting: if compare_fp32_vs_fp32_prepared: m_shadows_mp = torch.jit.script(m_shadows_mp) mp_shadows_mq = torch.jit.script(mp_shadows_mq) # calibrate if compare_fp32_vs_fp32_prepared: m_shadows_mp(*data) mp_shadows_mq(*data) # check activation result correctness results = [] models = (m_shadows_mp, mp_shadows_mq) if \ compare_fp32_vs_fp32_prepared else (mp_shadows_mq,) for model in models: act_compare_dict = extract_shadow_logger_info( model, OutputLogger, 'b') if results_len is not None: self.assertTrue( len(act_compare_dict) == results_len, f"expected len {results_len}, got len {len(act_compare_dict)}") self.assert_ns_compare_dict_valid(act_compare_dict) extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr') extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error') extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_cosine_similarity, 'cosine_similarity') results.append(act_compare_dict) return results class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase): @skipIfNoFBGEMM def test_extract_weights_mod_ptq(self): m = AllConvAndLinearFusionModules().eval() example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights(m, example_inputs, results_len=14) @skipIfNoFBGEMM def test_extract_weights_mod_qat(self): m = AllConvAndLinearFusionModules().train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights( m, example_inputs, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_extract_weights_linear_fun_ptq(self): m = LinearReluLinearFunctional().eval() example_inputs = (torch.randn(1, 4),) 
self._test_extract_weights(m, example_inputs, results_len=2) @skipIfNoFBGEMM def test_extract_weights_linear_fun_qat(self): m = LinearReluLinearFunctional().train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} example_inputs = (torch.randn(1, 4),) self._test_extract_weights( m, example_inputs, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_extract_weights_conv_fun_ptq(self): w1d = torch.randn(1, 1, 1) w2d = torch.randn(1, 1, 1, 1) w3d = torch.randn(1, 1, 1, 1, 1) b1d = torch.randn(1) b2d = torch.randn(1) b3d = torch.randn(1) m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval() example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights(m, example_inputs, results_len=6) @skipIfNoFBGEMM def test_extract_weights_conv_fun_qat(self): w1d = torch.randn(1, 1, 1) w2d = torch.randn(1, 1, 1, 1) w3d = torch.randn(1, 1, 1, 1, 1) b1d = torch.randn(1) b2d = torch.randn(1) b3d = torch.randn(1) m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights( m, example_inputs, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_extract_weights_dynamic(self): # TODO(future PR): add Linear-ReLU, after #55393 is fixed. 
m = nn.Sequential(nn.Linear(1, 1)).eval() qconfig_dict = { 'object_type': [ (nn.Linear, default_dynamic_qconfig), ], } example_inputs = (torch.randn(1, 1),) self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_extract_weights_fqn(self): m = nn.Sequential( nn.Sequential(nn.Conv2d(1, 1, 1)), nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) results = extract_weights('a', mp, 'b', mq) fqn_a_0 = results['_0_0']['weight']['a'][0]['fqn'] fqn_b_0 = results['_0_0']['weight']['b'][0]['fqn'] self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0) fqn_a_1 = results['_1']['weight']['a'][0]['fqn'] fqn_b_1 = results['_1']['weight']['b'][0]['fqn'] self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1) def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx): m = nn.Sequential( torch.ao.quantization.QuantStub(), nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None if prepare_fn == prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { ns.call_module(OutputLogger): 2, } self._test_match_activations( m, (torch.randn(2, 1, 2, 2),), prepared_expected_node_occurrence=expected_occurrence, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_fn) @skipIfNoFBGEMM def test_match_activations_mod_ptq(self): self._test_match_activations_mod_impl(prepare_fn=prepare_fx) @skipIfNoFBGEMM def test_match_activations_mod_qat(self): self._test_match_activations_mod_impl(prepare_fn=prepare_qat_fx) def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional().eval() qconfig_dict = None if prepare_fn == prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { 
ns.call_module(OutputLogger): 2, } self._test_match_activations( m, (torch.randn(4, 4),), prepared_expected_node_occurrence=expected_occurrence, results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_match_activations_fun_ptq(self): self._test_match_activations_fun_impl(prepare_fn=prepare_fx) @skipIfNoFBGEMM def test_match_activations_fun_qat(self): self._test_match_activations_fun_impl(prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_match_activations_meth_ptq(self): """ Verify that add_loggers works on methods """ class M(nn.Module): def forward(self, x): x = x.sigmoid() return x m = M().eval() res = self._test_match_activations( m, (torch.randn(4, 4),), results_len=1) @skipIfNoFBGEMM def test_match_activations_fqn(self): m = nn.Sequential( nn.Sequential(nn.Conv2d(1, 1, 1)), nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) mp_ns, mq_ns = add_loggers('a', mp, 'b', mq, OutputLogger) datum = torch.randn(1, 1, 1, 1) mp_ns(datum) mq_ns(datum) results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'b') fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn'] fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn'] self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0) fqn_a_1 = results['_1']['node_output']['a'][0]['fqn'] fqn_b_1 = results['_1']['node_output']['b'][0]['fqn'] self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1) def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx): m = nn.Sequential( nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None if prepare_fn == prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(1, 1, 4, 4),), results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def 
test_add_shadow_loggers_mod_ptq(self): self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_fx) @skipIfNoFBGEMM def test_add_shadow_loggers_mod_qat(self): self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_qat_fx) def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional() qconfig_dict = None if prepare_fn == prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_add_shadow_loggers_fun_ptq(self): self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_fx) @skipIfNoFBGEMM def test_add_shadow_loggers_fun_qat(self): self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_add_shadow_loggers_meth_ptq(self): """ Verify that add_loggers works on methods """ class M(nn.Module): def forward(self, x): x = x.sigmoid() return x m = M().eval() res = self._test_match_shadow_activations( m, (torch.randn(4, 4),), # For now, sigmoid is not supported for shadowing because the dtype # inference for it is not implemented yet. So, this is just testing # that shadowing models with method calls does not crash. 
results_len=0) @skipIfNoFBGEMM def test_shadow_activations_fqn(self): m = nn.Sequential( nn.Sequential(nn.Conv2d(1, 1, 1)), nn.Conv2d(1, 1, 1), ).eval() qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger) datum = torch.randn(1, 1, 1, 1) mp_shadows_mq(datum) results = extract_shadow_logger_info(mp_shadows_mq, OutputLogger, 'b') fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn'] fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn'] self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0) fqn_a_1 = results['_1']['node_output']['a'][0]['fqn'] fqn_b_1 = results['_1']['node_output']['b'][0]['fqn'] self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1) @skipIfNoFBGEMM def test_logging_inputs(self): """ Verifies that logging inputs works correctly """ class M(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) x = torch.cat([x, x], dim=0) return x m = M().eval() self._test_match_shadow_activations( m, (torch.randn(1, 1, 4, 4),), results_len=1, should_log_inputs=True) @skipIfNoFBGEMM def test_ops_with_same_fp32_and_int8_signature(self): """ Verifies that we can match pairs of ops which have the same aten signature for fp32 and int8 tensors. 
""" class M(nn.Module): def __init__(self) -> None: super().__init__() self.max_pool_2d = nn.MaxPool2d(2) def forward(self, x): x = self.max_pool_2d(x) x = F.relu(x) return x m = M().eval() self._test_match_activations( m, (torch.randn(1, 1, 2, 2),), results_len=2) @skipIfNoFBGEMM def test_add_mul_inputs_activations(self): m = AddMulFunctional().eval() res = self._test_match_activations( m, (torch.randn(2, 2), torch.randn(2, 2)), results_len=6, should_log_inputs=True) @skipIfNoFBGEMM def test_linear_fp16_weights(self): qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} m = LinearReluFunctional().eval() example_inputs = (torch.randn(1, 4),) self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_linear_fp16_activations(self): for should_log_inputs in (True, False): qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} m = LinearReluFunctional().eval() num_loggers = 2 if should_log_inputs else 1 expected_occurrence = { ns.call_module(OutputLogger): num_loggers, } res = self._test_match_activations( m, (torch.randn(4, 4),), prepared_expected_node_occurrence=expected_occurrence, results_len=1, qconfig_dict=qconfig_dict, should_log_inputs=should_log_inputs) @skipIfNoFBGEMM def test_linear_fp16_shadow_activations(self): for should_log_inputs in (True, False): qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} m = LinearReluFunctional().eval() num_loggers = 4 if should_log_inputs else 2 expected_occurrence = { ns.call_module(OutputLogger): num_loggers, } res2 = self._test_match_shadow_activations( m, (torch.randn(4, 4),), prepared_expected_node_occurrence=expected_occurrence, results_len=1, qconfig_dict=qconfig_dict, should_log_inputs=should_log_inputs) @skipIfNoFBGEMM def test_linear_fp16_vs_linear_fp16_shadow_activations(self): m = LinearFunctional().eval() qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} example_inputs = (torch.randn(1, 4),) mp = 
prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq1 = convert_fx(copy.deepcopy(mp)) mq2 = convert_fx(copy.deepcopy(mp)) mq1_shadows_mq2 = _add_shadow_loggers_impl( 'a', mq1, 'b', mq2, OutputLogger, should_log_inputs=False) mq1_shadows_mq2(torch.randn(4, 4)) act_compare_dict = extract_shadow_logger_info( mq1_shadows_mq2, OutputLogger, 'b') self.assertTrue(len(act_compare_dict) == 1) self.assert_ns_compare_dict_valid(act_compare_dict) @skipIfNoFBGEMM def test_op_with_either_fp32_or_int8_input(self): """ Verify that shadowing works with ops which accept either fp32 or int8 inputs. """ class M(nn.Module): def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() def forward(self, x): x = self.relu(x) x = F.relu(x) return x m = M() res = self._test_match_shadow_activations( m, (torch.randn(4, 4),), # Note: shadowing relu by itself is currently not supported, # this test is just testing that it does not crash results_len=0) def _test_int8_shadows_int8_impl(self, m): """ Verify that shadowing works where both modules are int8 """ qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(4, 1, 4, 4),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp(*example_inputs) mq1 = convert_fx(copy.deepcopy(mp)) mq2 = convert_fx(mp) mq1_shadows_mq2 = add_shadow_loggers('a', mq1, 'b', mq2, OutputLogger) mq1_shadows_mq2(torch.randn(4, 1, 4, 4)) act_compare_dict = extract_shadow_logger_info( mq1_shadows_mq2, OutputLogger, 'b') self.assertTrue(len(act_compare_dict) == 1) self.assert_ns_compare_dict_valid(act_compare_dict) @skipIfNoFBGEMM def test_int8_shadows_int8_mod(self): m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() self._test_int8_shadows_int8_impl(m) @skipIfNoFBGEMM def test_int8_shadows_int8_fun(self): m = LinearFunctional().eval() self._test_int8_shadows_int8_impl(m) @skipIfNoFBGEMM def test_user_module_scriptable(self): # Logging of the output of this class is not supported, because it is # neither a tensor or 
an RNN return type. class M1(nn.Module): def forward(self, x): x1 = x * 2 x2 = x * 4 return (x1, x2) class M2(nn.Module): def __init__(self) -> None: super().__init__() self.m1 = M1() def forward(self, x): x1, x2 = self.m1(x) return x1, x2 m = M2().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} prepare_custom_config_dict = { 'non_traceable_module_class': [M1], } example_inputs = (torch.randn(1),) mp1 = prepare_fx( m, qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict) mp2 = copy.deepcopy(mp1) unmatchable_types_map = get_unmatchable_types_map() unmatchable_types_map['mods_unmatchable'].add(M1) mp1_ns, mp2_ns = _add_loggers_impl( 'a', mp1, 'b', mp2, OutputLogger, should_log_inputs=False, unmatchable_types_map=unmatchable_types_map) # Scripting a model with loggers should succeed. If it fails because of # incorrect dtypes, we can blocklist the associated types from being instrumented. mp1_ns_scripted = torch.jit.script(mp1_ns) mp2_ns_scripted = torch.jit.script(mp2_ns) @skipIfNoFBGEMM def test_user_module(self): """ For user defined modules, 1. weight extraction should not crash 2. unshadowed activations should only have loggers for known types 3. 
shadowed activations should only have loggers for known types with known dtypes """ class UserModule(nn.Module): def forward(self, x): return x class M(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(1, 1) self.user_module = UserModule() def forward(self, x): x = self.linear(x) x = self.user_module(x) return x m = M().eval() # quantize without tracing through UserModule qconfig_dict = {'': torch.ao.quantization.default_qconfig} prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']} example_inputs = (torch.randn(1, 1, 1),) mp = prepare_fx( m, qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict) mp(*example_inputs) mq = convert_fx(copy.deepcopy(mp)) # weight extraction should not crash weights = _extract_weights_impl('fp32_prepared', mp, 'int8', mq) # unshadowed activations should have loggers # add loggers, without retracing # note: converting again because we cannot copy a quantized linear mp_ns, mq_ns = _add_loggers_impl( 'fp32_prepared', copy.deepcopy(mp), 'int8', convert_fx(copy.deepcopy(mp)), OutputLogger, should_log_inputs=True) # both fp32 and int8 models should have 2 loggers each, 2 for I/O # of linear, and 0 for I/O of user_module unshadowed_expected_occurrence = { ns.call_module(OutputLogger): 2, } self.checkGraphModuleNodes( mp_ns, expected_node_occurrence=unshadowed_expected_occurrence) self.checkGraphModuleNodes( mq_ns, expected_node_occurrence=unshadowed_expected_occurrence) # shadowed activations should only have loggers for nodes where # the types are known and we can do a dtype cast # add shadow loggers, without retracing mp_shadows_mq_ns = _add_shadow_loggers_impl( 'fp32_prepared', mp, 'int8', mq, OutputLogger, should_log_inputs=True) # 4 loggers for I/O of linear, 0 loggers for I/O of user_module shadowed_expected_occurrence = { ns.call_module(OutputLogger): 4, } self.checkGraphModuleNodes( mp_shadows_mq_ns, 
expected_node_occurrence=shadowed_expected_occurrence) def test_op_io_dtype_coverage(self): """ Tests that all the ops quantization cares about have input and output dtypes defined. """ base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() type_a_related_to_b = \ get_type_a_related_to_b(base_name_to_sets_of_related_ops) # TODO(future PR): clean this up node_type_to_io_type_map = get_node_type_to_io_type_map() FUNS_IO_TYPE_FP32 = node_type_to_io_type_map['funs_io_type_fp32'] FUNS_IO_TYPE_INT8 = node_type_to_io_type_map['funs_io_type_int8'] FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['funs_io_type_fp32_or_int8'] MODS_IO_TYPE_FP32 = node_type_to_io_type_map['mods_io_type_fp32'] MODS_IO_TYPE_INT8 = node_type_to_io_type_map['mods_io_type_int8'] MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['mods_io_type_fp32_or_int8'] METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['meths_io_type_fp32_or_int8'] unmatchable_types_map = get_unmatchable_types_map() FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable'] MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable'] METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable'] # 1. 
check static quant module mappings static_quant_mod_mappings = get_default_static_quant_module_mappings() for fp32_type, int8_type in static_quant_mod_mappings.items(): types_to_skip = ( torch.ao.quantization.QuantStub, torch.ao.quantization.DeQuantStub, nnq.FloatFunctional, # TODO(future PR): look into whether shadowing embeddings # makes sense nn.Embedding, nn.EmbeddingBag, # the ConvTranspose3d swap is not implemented in FX Graph # mode quantization yet nn.ConvTranspose3d, # the GroupNorm swap is not implemented in FX Graph # mode quantization yet nn.GroupNorm, # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can # take quantized inputs nn.ReLU6, ) if fp32_type in types_to_skip: continue self.assertTrue( fp32_type in MODS_IO_TYPE_FP32, f"missing IO type handling for f{fp32_type}") self.assertTrue( int8_type in MODS_IO_TYPE_INT8, f"missing IO type handling for f{int8_type}") # 2. check static quant op mappings static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings() for fp32_type, int8_type in static_quant_fun_mappings.items(): self.assertTrue( fp32_type in FUNS_IO_TYPE_FP32, f"missing IO type handling for f{fp32_type}") self.assertTrue( int8_type in FUNS_IO_TYPE_INT8, f"missing IO type handling for f{int8_type}") # 3. check dynamic quant mappings dynamic_quant_mappings = get_default_dynamic_quant_module_mappings() for fp32_type1, fp32_type2 in dynamic_quant_mappings.items(): # TODO(future PR): verify correct I/O for these and remove from # this list. types_to_skip = ( nn.GRUCell, nn.GRU, nn.LSTMCell, nn.RNNCell, # TODO(future PR): look into whether shadowing embeddings # makes sense nn.Embedding, nn.EmbeddingBag, ) if fp32_type1 in types_to_skip: continue self.assertTrue( fp32_type1 in MODS_IO_TYPE_FP32, f"missing IO type handling for f{fp32_type1}") self.assertTrue( fp32_type2 in MODS_IO_TYPE_FP32, f"missing IO type handling for f{fp32_type2}") # 4. go through the ops mapped to each QuantizeHandler type, and verify # correctness. 
default_quant_patterns = get_all_quant_patterns() for pattern, qhandler_cls in default_quant_patterns.items(): base_op = None if isinstance(pattern, tuple): base_op = pattern[-1] elif isinstance(pattern, str): base_op = pattern else: base_op = pattern if ( qhandler_cls in ( qh.BinaryOpQuantizeHandler, qh.RNNDynamicQuantizeHandler, ) ): # TODO(future PR): implement shadowing for binary ops # TODO(future PR): implement shadowing for RNN ops continue elif qhandler_cls == qh.CatQuantizeHandler: self.assertTrue( base_op in FUNS_IO_TYPE_FP32_OR_INT8, f"missing IO type handling for {base_op}") elif ( qhandler_cls in ( qh.ConvReluQuantizeHandler, qh.LinearReLUQuantizeHandler, qh.BatchNormQuantizeHandler, qh.DefaultNodeQuantizeHandler, ) ): self.assertTrue( (base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32), f"missing IO type handling for {base_op}") elif ( qhandler_cls in ( qh.FixedQParamsOpQuantizeHandler, qh.CopyNodeQuantizeHandler, qh.GeneralTensorShapeOpQuantizeHandler, ) ): if ( base_op in FUNS_UNMATCHABLE or base_op in MODS_UNMATCHABLE or base_op in METHS_UNMATCHABLE ): continue self.assertTrue( (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or (base_op in MODS_IO_TYPE_FP32_OR_INT8) or (base_op in METHS_IO_TYPE_FP32_OR_INT8) or # Softmax has a different signature for the quantized # version, so it does not fit into the cases above. 
(base_op is torch.nn.Softmax), f"missing IO type handling for {base_op}") elif qhandler_cls == qh.EmbeddingQuantizeHandler: # embedding shadowing is not implemented, for now continue else: if ( base_op in FUNS_UNMATCHABLE or base_op in MODS_UNMATCHABLE or base_op in METHS_UNMATCHABLE ): continue if qhandler_cls(None, {}).is_general_tensor_value_op(): self.assertTrue( (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or (base_op in MODS_IO_TYPE_FP32_OR_INT8) or (base_op in METHS_IO_TYPE_FP32_OR_INT8), f"missing IO type handling for {base_op} using {qhandler_cls}") else: self.assertTrue( (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or (base_op in MODS_IO_TYPE_FP32_OR_INT8) or (base_op in METHS_IO_TYPE_FP32_OR_INT8) or (base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32) or f"missing IO type handling for {base_op} using {qhandler_cls}") @skipIfNoFBGEMM def test_user_defined_function(self): """ Verify that NS APIs work on user defined functions """ class M1(nn.Module): def __init__(self) -> None: super().__init__() self.w1 = nn.Parameter(torch.empty(1, 1)) self.b1 = nn.Parameter(torch.zeros(1)) torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) def forward(self, x): x = F.hardswish(x) x = x.sigmoid() x = F.linear(x, self.w1, self.b1) return x class M2(nn.Module): def __init__(self) -> None: super().__init__() self.w1 = nn.Parameter(torch.empty(1, 1)) self.b1 = nn.Parameter(torch.zeros(1)) torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) def forward(self, x): x = _wrapped_hardswish(x) x = _wrapped_sigmoid(x) x = _wrapped_linear(x, self.w1, self.b1) return x qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1),) m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs) m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs) data = torch.randn(1, 1) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() add_op_to_sets_of_related_ops( 
base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish) add_op_to_sets_of_related_ops( base_name_to_sets_of_related_ops, _wrapped_sigmoid, F.sigmoid) add_op_to_sets_of_related_ops( base_name_to_sets_of_related_ops, _wrapped_linear, F.linear) op_to_type_to_weight_extraction_fn = \ get_op_to_type_to_weight_extraction_fn() op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \ torch.ao.ns.fx.weight_utils.get_linear_fun_weight # test compare weights results = extract_weights( 'a', m1, 'b', m2, base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn) self.assertTrue(len(results) == 1) self.assertTrue(len(results['_wrapped_linear']['weight']) == 2) # test unshadowed activations m1_ns, m2_ns = _add_loggers_impl( 'a', copy.deepcopy(m1), 'b', copy.deepcopy(m2), OutputLogger, should_log_inputs=False, base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops) # calibrate m1_ns(data) m2_ns(data) # check activation result correctness act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b') self.assertTrue(len(act_compare_dict) == 3) self.assert_ns_compare_dict_valid(act_compare_dict) # test shadowed activations node_type_to_io_type_map = get_node_type_to_io_type_map() node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_hardswish) node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_sigmoid) m2_shadows_m1_ns = _add_shadow_loggers_impl( 'a', m2, 'b', m1, OutputLogger, should_log_inputs=False, base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, node_type_to_io_type_map=node_type_to_io_type_map) # calibrate m2_shadows_m1_ns(data) # check activation result correctness act_compare_dict = extract_shadow_logger_info( m2_shadows_m1_ns, OutputLogger, 'b') self.assertTrue(len(act_compare_dict) == 2) self.assert_ns_compare_dict_valid(act_compare_dict) @skipIfNoFBGEMM def test_layer_names(self): m = nn.Sequential( nn.Conv2d(1, 1, 
1), nn.Conv2d(1, 1, 1), nn.Sigmoid(), ).eval() qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping("fbgemm") example_inputs = (torch.randn(1, 1, 1, 1),) mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) # extract weights results = extract_weights('fp32', mp, 'int8', mq) mq_node_names = [node.name for node in mq.graph.nodes] for layer_name in results.keys(): self.assertTrue(layer_name in mq_node_names) # match activations mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mp_ns, mq_ns = add_loggers( 'fp32', copy.deepcopy(mp), 'int8', mq, OutputLogger) data = torch.randn(1, 1, 1, 1) mp_ns(data) mq_ns(data) results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'int8') mq_node_names = [node.name for node in mq_ns.graph.nodes] for layer_name in results.keys(): self.assertTrue(layer_name in mq_node_names) # match shadow activations mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers( 'fp32', mp, 'int8', mq, OutputLogger) mp_shadows_mq(data) results = extract_shadow_logger_info( mp_shadows_mq, OutputLogger, 'int8') mq_node_names = [node.name for node in mp_shadows_mq.graph.nodes] for layer_name in results.keys(): self.assertTrue(layer_name in mq_node_names) @skipIfNoFBGEMM def test_extend_logger_results_with_comparison(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) mp = torch.ao.quantization.quantize_fx.prepare_fx( m, qconfig_dict, example_inputs=example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) # extract weights results = extract_weights('fp32', mp, 'int8', mq) extend_logger_results_with_comparison( results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32') extend_logger_results_with_comparison( results, 
'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32') extend_logger_results_with_comparison( results, 'fp32', 'int8', compute_cosine_similarity, 'cosine_similarity_int8_vs_fp32') for layer_results in results.values(): assert 'sqnr_int8_vs_fp32' in \ layer_results['weight']['int8'][0].keys() assert 'l2_error_int8_vs_fp32' in \ layer_results['weight']['int8'][0].keys() assert 'cosine_similarity_int8_vs_fp32' in \ layer_results['weight']['int8'][0].keys() @skipIfNoFBGEMM def test_int8_shadows_fp32_simple(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), nn.ReLU()).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) mp = torch.ao.quantization.quantize_fx.prepare_fx( m, qconfig_dict, example_inputs=example_inputs) mp(torch.randn(1, 1, 1, 1)) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers( 'int8', mq, 'fp32', mp, OutputLogger) # verify that scale and zp were extracted correctly # for the first op, the scale+zp live as attributes on the module scale_0 = mp_shadows_mq._0_input_scale_0 scale_0_ref = getattr(mq_ref, '0_input_scale_0') self.assertEqual(scale_0, scale_0_ref) zp_0 = mp_shadows_mq._0_input_zero_point_0 zp_0_ref = getattr(mq_ref, '0_input_zero_point_0') self.assertEqual(zp_0, zp_0_ref) # for the second op, the scale and zp of input to second op # must equal to scale and zp of output of first op scale_1 = mp_shadows_mq._1_input_scale_0 scale_1_ref = getattr(mq_ref, '0').scale self.assertEqual(scale_1, scale_1_ref) zp_1 = mp_shadows_mq._1_input_zero_point_0 zp_1_ref = getattr(mq_ref, '0').zero_point self.assertEqual(zp_1, zp_1_ref) # verify running data works mp_shadows_mq(torch.randn(1, 1, 1, 1)) act_compare_dict = extract_shadow_logger_info( mp_shadows_mq, OutputLogger, 'fp32') self.assertTrue(len(act_compare_dict) == 2) 
self.assert_ns_compare_dict_valid(act_compare_dict) @skipIfNoFBGEMM def test_int8_shadows_fp32_coverage(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.adaptive_avg_pool(x) # input qparams of conv will be input qparams of adaptive_avg_pool x = self.conv(x) x = torch.mul(x, x) x = self.conv(x) x = torch.add(x, x) x = F.relu(x) x = self.conv(x) return x m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp(*example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers( 'int8', mq, 'fp32', mp, OutputLogger) mp_shadows_mq(torch.randn(1, 1, 1, 1)) act_compare_dict = extract_shadow_logger_info( mp_shadows_mq, OutputLogger, 'fp32') self.assertTrue(len(act_compare_dict) == 3) self.assert_ns_compare_dict_valid(act_compare_dict) @skipIfNoFBGEMM def test_loggers_preserve_qat_numerics(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) mp(*example_inputs) mc = convert_fx(copy.deepcopy(mp)) mp.apply(torch.ao.quantization.disable_observer) ref_fp32 = mp(*example_inputs) ref_int8 = mc(*example_inputs) mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger) ref_fp32_ns = mp_ns(*example_inputs) ref_int8_ns = mc_ns(*example_inputs) self.assertEqual(ref_fp32, ref_fp32_ns) self.assertEqual(ref_int8, ref_int8_ns) @skipIfNoFBGEMM def test_shadow_loggers_preserve_qat_numerics(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) qconfig_dict = {'': 
torch.ao.quantization.get_default_qat_qconfig('fbgemm')} example_inputs = (torch.randn(1, 1, 1, 1),) mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) mp(*example_inputs) mc = convert_fx(copy.deepcopy(mp)) mp.apply(torch.ao.quantization.disable_observer) ref_fp32 = mp(*example_inputs) ref_int8 = mc(*example_inputs) mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger) ref_shadow = mc_shadows_mp(*example_inputs) self.assertEqual(ref_fp32, ref_shadow) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_extract_weights_cuda(self): # Note: this is not using quantization because quantized kernels do not # work on cuda yet. m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() results = extract_weights('a', m1, 'b', m2) extend_logger_results_with_comparison( results, 'a', 'b', compute_sqnr, 'sqnr') self.assert_ns_compare_dict_valid(results) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_add_loggers_cuda(self): # Note: this is not using quantization because quantized kernels do not # work on cuda yet. m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() m1_ns, m2_ns = add_loggers('a', m1, 'b', m2, OutputLogger) datum = torch.randn(1, 1, 1, 1) datum = datum.cuda() m1_ns(datum) m2_ns(datum) act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b') extend_logger_results_with_comparison( act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr') @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_add_shadow_loggers_cuda(self): # Note: this is not using quantization because quantized kernels do not # work on cuda yet. 
# NOTE(review): the statements below finish a test method whose `def` line
# (and any decorators) precede this view. They shadow one CUDA Conv2d model
# with another and extend the logged results with an SQNR comparison.
        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
        m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
        m1_shadows_m2 = add_shadow_loggers('a', m1, 'b', m2, OutputLogger)
        datum = torch.randn(1, 1, 1, 1)
        datum = datum.cuda()
        m1_shadows_m2(datum)
        act_compare_dict = extract_shadow_logger_info(m1_shadows_m2, OutputLogger, 'b')
        extend_logger_results_with_comparison(
            act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')

    def test_fp16_shadows_fp32(self):
        # Smoke test: an fp16 reference-quantized copy of the model can be
        # paired with its fp32 original by add_shadow_loggers. The shadowed
        # model is neither run nor inspected; only the call itself is checked.
        m = LinearReluFunctional().eval()
        example_inputs = (torch.randn(1, 4),)
        qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig}
        mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
        mq = convert_to_reference_fx(mp)
        mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)

    def test_mul_add_cat_stack_skips_shadowing(self):
        class M(nn.Module):
            def forward(self, x):
                x = x * x
                x = torch.mul(x, x)
                x = x + x
                x = torch.add(x, x)
                x = torch.cat([x])
                x = torch.stack([x])
                return x

        # None of these ops should be shadowed, so no comparison results
        # are expected.
        m = M().eval()
        self._test_match_shadow_activations(
            m, (torch.randn(1, 1, 4, 4),), results_len=0)

    def test_op_with_only_kwargs_skips_shadowing(self):
        class M(nn.Module):
            def forward(self, x):
                x = torch.cat(tensors=[x])
                x = torch.stack(tensors=[x])
                return x

        # Ops whose tensor inputs are passed only as kwargs are expected to
        # be skipped, leaving no comparison results.
        m = M().eval()
        self._test_match_shadow_activations(
            m, (torch.randn(1, 1, 4, 4),), results_len=0)

    def test_unsupported_op_copy_skips_shadowing(self):
        """
        Copying a `call_function` node is not implemented; test that
        shadowing skips such a node instead of crashing.
        """
        class M(nn.Module):
            def forward(self, x):
                # the second argument leads to attempting to copy a
                # call_function node
                x = F.layer_norm(x, x.shape[1:])
                return x

        m = M().eval()
        self._test_match_shadow_activations(
            m, (torch.randn(1, 1, 4, 4),), results_len=0)

    def test_linear_kwargs_shadow(self):

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.empty(4, 4))
                self.b1 = nn.Parameter(torch.zeros(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(input=x, weight=self.w1, bias=self.b1)
                return x

        # note: FX graph mode quantization does not have good support
        # for kwargs-only right now, so we pass in two unquantized
        # models
        m = M().eval()
        mt = torch.fx.symbolic_trace(m)
        mt_copy = copy.deepcopy(mt)
        mt_shadows_mt_copy = add_shadow_loggers(
            'a', mt, 'b', mt_copy, OutputLogger)

        mt_shadows_mt_copy(torch.randn(4, 4))
        act_compare_dict = extract_shadow_logger_info(
            mt_shadows_mt_copy, OutputLogger, 'b')
        # the single kwargs-only F.linear call should yield exactly one entry
        self.assertTrue(len(act_compare_dict) == 1)


@skipIfNoQNNPACK
class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
    """
    Tests the "n shadows" workflow: prepare_n_shadows_model, calibration,
    convert_n_shadows_model, and extracting/printing the logged comparisons.
    The whole class is skipped when QNNPACK is unavailable.
    """

    def _test_impl(self, m, example_input, qconfig_mappings):
        """
        Run the n-shadows workflow end to end on ``m``.

        Prepares ``m`` with ``qconfig_mappings`` and the native backend
        config, calibrates it twice, converts it, runs it once with loggers
        enabled, then extracts and prints the comparisons.

        Returns the converted model so callers can inspect its shadow
        wrappers.
        """
        backend_config = get_native_backend_config()

        # test that input is valid
        _ = m(*example_input)

        msp = prepare_n_shadows_model(
            m, example_input, qconfig_mappings, backend_config)
        # print('msp', msp)

        for _ in range(2):
            msp(*example_input)

        msq = convert_n_shadows_model(msp)
        loggers_set_enabled(msq, True)
        msq(*example_input)

        results = extract_results_n_shadows_model(msq)
        print_comparisons_n_shadows_model(results)
        return msq

    @withQNNPACKBackend
    def test_linear_mod(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 2)

            def forward(self, x):
                x = self.fc1(x)
                return x

        m = M().eval()
        example_input = (torch.randn(2, 2),)

        qconfig_mappings = \
            QConfigMultiMapping().set_global([torch.ao.quantization.default_qconfig])
        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_linear_relu_mod(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 2)
                self.fc2 = nn.Linear(2, 2)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.fc1(x)
                x = self.fc2(x)
                x = self.relu(x)
                return x

        m = M().eval()
        example_input = (torch.randn(2, 2),)

        # two global candidates: static and dynamic quantization
        qconfig_mappings = (
            QConfigMultiMapping().set_global([
                torch.ao.quantization.default_qconfig,
                torch.ao.quantization.default_dynamic_qconfig
            ])
        )
        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_conv_bn_relu_mod(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        m = M().eval()
        example_input = (torch.randn(32, 1, 16, 16),)

        # two global candidates: per-tensor and per-channel static quantization
        qconfig_mappings = QConfigMultiMapping() \
            .set_global([
                torch.ao.quantization.default_qconfig,
                torch.ao.quantization.default_per_channel_qconfig
            ])
        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_functions(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.randn(2, 2))
                self.b1 = nn.Parameter(torch.zeros(2))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = F.sigmoid(x)
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w1[:], self.b1)
                x = F.relu(x)
                x = x + x
                x = torch.cat([x])
                x = torch.cat((x,))
                x = torch.cat(tensors=[x])
                # TODO(future PR): enable layernorm
                # blocked on FX graph mode quant not inserting observer for
                # second arg, if the second arg is a module input
                # x = F.layer_norm(x, x.shape)
                # x = F.layer_norm(x, x.shape[1:])
                # x = x.reshape(1, -1) * 2
                # x = F.layer_norm(x.reshape(1, -1), x.shape[1:])
                x = torch.matmul(x, x.reshape(2, 2))
                x = torch.matmul(x.reshape(2, 2), x.reshape(2, 2))
                # TODO(future PR): enable below after FX graph mode quantization handles
                # it, currently this is not supported
                # x = F.linear(input=x, weight=self.w1, bias=self.b1)
                return x

        m = M().eval()
        example_input = (torch.randn(2, 2),)

        qconfig_mappings = QConfigMultiMapping() \
            .set_global([torch.ao.quantization.default_qconfig])
        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_partial_qconfig_mapping(self):

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = nn.Linear(2, 2)
                self.w1 = nn.Parameter(torch.randn(2, 2))
                self.b1 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = self.fc(x)
                x = F.linear(x, self.w1, self.b1)
                x = F.relu(x)
                x = x + x
                return x

        m = M().eval()
        example_input = (torch.randn(2, 2),)
        qconfig = torch.ao.quantization.default_qconfig

        # no global qconfig: only F.linear and F.relu get a qconfig
        qconfig_mappings = QConfigMultiMapping() \
            .set_object_type(F.linear, [qconfig]) \
            .set_object_type(F.relu, [qconfig])
        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_logger_enabled_and_save_activations_flags(self):
        m = nn.Sequential(nn.Linear(1, 1)).eval()
        example_input = (torch.randn(1, 1),)

        qconfig_mappings = QConfigMultiMapping() \
            .set_global([torch.ao.quantization.default_qconfig])
        backend_config = get_native_backend_config()

        msp = prepare_n_shadows_model(
            m, example_input, qconfig_mappings, backend_config)

        for _ in range(2):
            msp(*example_input)

        def _check_logger_count(model, exp_count_stats, exp_count_comparisons):
            """
            Assert that every OutputLogger in ``model`` holds
            ``exp_count_stats`` entries in ``stats``, and every
            OutputComparisonLogger holds ``exp_count_comparisons`` entries in
            ``comparisons``.
            """
            for name, mod in model.named_modules():
                if isinstance(mod, OutputLogger):
                    self.assertTrue(
                        len(mod.stats) == exp_count_stats,
                        f'stats: expected {len(mod.stats)} to equal {exp_count_stats}')
                if isinstance(mod, OutputComparisonLogger):
                    self.assertTrue(
                        len(mod.comparisons) == exp_count_comparisons,
                        f'comparisons: expected {len(mod.comparisons)} to equal {exp_count_comparisons}')

        # check behavior with save_activations enabled
        # (convert a deepcopy so the calibrated `msp` can be reused below)
        msq = convert_n_shadows_model(copy.deepcopy(msp))
        loggers_set_enabled(msq, True)
        loggers_set_save_activations(msq, True)
        # after prepare calibration but before convert calibration, loggers
        # should not have anything saved
        _check_logger_count(msq, 0, 0)
        msq(*example_input)
        # loggers should save each item after calibration
        _check_logger_count(msq, 1, 1)

        # check behavior with save_activations disabled
        msq = convert_n_shadows_model(copy.deepcopy(msp))
        loggers_set_enabled(msq, True)
        loggers_set_save_activations(msq, False)
        # after prepare calibration but before convert calibration, loggers
        # should not have anything saved
        _check_logger_count(msq, 0, 0)
        msq(*example_input)
        # stats should be empty, but comparisons should be there
        _check_logger_count(msq, 0, 1)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    @withQNNPACKBackend
    def test_mobilenet_v2(self):
        import torchvision
        m = torchvision.models.quantization.mobilenet_v2(
            pretrained=False, quantize=False).eval()
        example_input = (torch.randn(1, 3, 224, 224),)

        qconfig_mappings = QConfigMultiMapping() \
            .set_global([torch.ao.quantization.default_qconfig, torch.ao.quantization.default_dynamic_qconfig])

        self._test_impl(m, example_input, qconfig_mappings)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_deduplication(self):
        # check that insertion deduplicates qconfigs
        qconfig_multi_mapping = QConfigMultiMapping().set_global(
            [torch.ao.quantization.default_qconfig, torch.ao.quantization.default_qconfig]
        )
        self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 1)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_insert_padding(self):
        # test that inserting a higher priority qconfig style with fewer elements than a lower priority qconfig will
        # result in adding None to the extra QConfigMappings at that same style+key
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
            .set_module_name_object_type_order(
                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
            )
        )

        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
                torch.nn.Linear
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
                "fc"
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[
                1
            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
            None,
        )

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_retroactive_padding(self):
        # test that inserting a lower priority qconfig style (here: global) with more elements than the
        # higher priority qconfig styles already set will result in the new QConfigMapping having None at
        # all previously existing styles+keys
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
            .set_module_name_object_type_order(
                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
            )
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
        )

        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
                torch.nn.Linear
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
                "fc"
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[
                1
            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
            None,
        )

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_end_to_end(self):
        # test that the prepare/convert_n_shadows_model works as expected
        # with qconfig_multi_mapping and avoids unwanted matches

        m = TwoLayerLinearModel().eval()
        example_input = m.get_example_inputs()

        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
            .set_module_name("fc2", [None, torch.ao.quantization.default_qconfig])
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )
        msq = self._test_impl(m, example_input, qconfig_multi_mapping)

        # NOTE(review): wrapper names appear to encode
        # <subgraph index>_<qconfig index>; since fc2 has None in the second
        # mapping, no shadow_wrapper_1_2 is expected to exist.
        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
        self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
        self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_from_list(self):
        # test QConfigMultiMapping.from_list_qconfig_mapping works as expected

        m = TwoLayerLinearModel().eval()
        example_input = m.get_example_inputs()

        qconfig_mappings_list = [
            QConfigMapping().set_global(torch.ao.quantization.default_qconfig),
            QConfigMapping()
            .set_global(torch.ao.quantization.default_dynamic_qconfig)
            .set_module_name("fc2", torch.ao.quantization.default_qconfig),
        ]

        qconfig_multi_mapping = QConfigMultiMapping().from_list_qconfig_mapping(
            qconfig_mappings_list
        )
        # the same expectations as test_qconfig_multi_mapping_end_to_end,
        # but built from a list of QConfigMappings
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )

        msq = self._test_impl(m, example_input, qconfig_multi_mapping)

        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
        self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
        self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_ordering(self):
        # test that the module ordering ignores None

        m = TwoLayerLinearModel().eval()
        example_input = m.get_example_inputs()
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
            .set_module_name(
                "fc2",
                [
                    None,
                    torch.ao.quantization.default_dynamic_qconfig,
                    torch.ao.quantization.default_qat_qconfig_v2,
                ],
            )
        )
        # the None entry for fc2 does not add a third QConfigMapping
        self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 2)
        msq = self._test_impl(m, example_input, qconfig_multi_mapping)

        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_1_1.mod_0, torch.qint8)
        self.checkQuantizedLinear(msq.shadow_wrapper_1_2.mod_0)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_repr(self):
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
            .set_module_name(
                "fc2",
                [
                    None,
                    torch.ao.quantization.default_dynamic_qconfig,
                    torch.ao.quantization.default_qat_qconfig_v2,
                ],
            )
        )
        self.assertTrue(isinstance(qconfig_multi_mapping.__repr__(), str))

    @withQNNPACKBackend
    def test_custom_functions_and_tracer(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 2)
                self.fc2 = nn.Linear(2, 2)

            def forward(self, x):
                x = self.fc1(x)
                x = self.fc2(x)
                return x

        m = M().eval()
        example_inputs = (torch.randn(2, 2),)

        qconfig_mappings = QConfigMultiMapping().set_global(
            [torch.ao.quantization.default_qat_qconfig]
        )

        # NOTE(review): ["fc2"] is presumably the tracer's list of skipped
        # module names (second arg: skipped module classes) - confirm against
        # QuantizationTracer's signature.
        custom_tracer = torch.ao.quantization.quantize_fx.QuantizationTracer(
            ["fc2"], []
        )
        custom_prepare_fn = torch.ao.quantization.quantize_fx.prepare_qat_fx

        # Custom convert hook: prints its extra kwarg, then defers to
        # convert_fx. `to_print` is supplied via custom_convert_kwargs below.
        def custom_convert_fn(module, to_print):
            print(to_print)
            mod = torch.ao.quantization.quantize_fx.convert_fx(module)
            return mod

        backend_config = get_native_backend_config()

        # test that input is valid
        _ = m(*example_inputs)

        kwargs = {"to_print": "working"}

        msp = prepare_n_shadows_model(
            m,
            example_inputs,
            qconfig_mappings,
            backend_config,
            custom_prepare_fn=custom_prepare_fn,
            custom_prepare_kwargs=None,
            custom_tracer=custom_tracer,
        )

        for _ in range(2):
            msp(*example_inputs)

        msq = convert_n_shadows_model(
            msp, custom_convert_fn=custom_convert_fn, custom_convert_kwargs=kwargs
        )
        print(msq)
        loggers_set_enabled(msq, True)
        msq(*example_inputs)

        results = extract_results_n_shadows_model(msq)
        print_comparisons_n_shadows_model(results)

    def _test_extract_weights_impl(self, m, example_input, qconfig_mapping):
        """
        Compare the weights of ``m`` against their quantized counterparts
        under ``qconfig_mapping`` (native backend config) and print the
        resulting comparisons.
        """
        backend_config = get_native_backend_config()
        results = _n_shadows_compare_weights(
            m, example_input, qconfig_mapping, backend_config)
        print_comparisons_n_shadows_model(results)

    @withQNNPACKBackend
    def test_extract_weights_linear(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.randn(2, 2))
                self.b1 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                self.w2 = nn.Parameter(torch.randn(2, 2))
                self.b2 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
                self.w3 = nn.Parameter(torch.randn(2, 2))
                self.b3 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5))
                self.w4 = nn.Parameter(torch.randn(2, 2))
                self.b4 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w4, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                x = F.relu(x)
                x = F.linear(x, self.w3, self.b3)
                x = F.linear(x, self.w4, self.b4)
                return x

        per_tensor_qconfig = torch.ao.quantization.default_qconfig

        m = M().eval()
        example_input = (torch.randn(2, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        # NOTE(review): the order index is presumably the 0-based position of
        # the F.linear call in forward (2 -> w3, 3 -> w4) - confirm.
        # test unquantized
        qconfig_mapping.set_module_name_object_type_order(
            '', F.linear, 2, None)
        # test per-tensor
        qconfig_mapping.set_module_name_object_type_order(
            '', F.linear, 3, per_tensor_qconfig)
        self._test_extract_weights_impl(m, example_input, qconfig_mapping)

    def _test_add_loggers_impl(self, m, example_input, qconfig_mapping):
        """
        Run the add-loggers variant of the n-shadows workflow on ``m`` and
        check its outputs against references.

        The fp32 output of the converted model must be allclose to ``m``'s
        own output, and the last value logged for the last subgraph must be
        allclose to the output of ``m`` quantized with prepare_fx/convert_fx
        under the same ``qconfig_mapping``.

        Returns the converted model.
        """
        backend_config = get_native_backend_config()
        # untouched copy used to build the fp32 and quantized references
        m_copy = copy.deepcopy(m)

        # test that input is valid
        _ = m(*example_input)

        msp = _prepare_n_shadows_add_loggers_model(
            m, example_input, qconfig_mapping, backend_config)
        # print('msp', msp)

        # NOTE(review): calibrated once here, while mp_ref below is calibrated
        # twice; with identical inputs the observers presumably end up with
        # the same qparams - confirm if observers change.
        msp(*example_input)

        msq = convert_n_shadows_model(msp)
        # print('msq', msq)

        loggers_set_enabled(msq, True)
        output_fp32 = msq(*example_input)

        results = extract_results_n_shadows_model(msq)
        # print(results)
        # print_comparisons_n_shadows_model(results)

        # get the last quantized output from results
        inner_results = results['model']['node_output']
        last_subgraph = list(inner_results.keys())[-1]
        output_shadow = inner_results[last_subgraph][0]['values'][-1]

        # verify that both fp32 and quantized output matches reference
        output_fp32_ref = m_copy(*example_input)
        mp_ref = prepare_fx(m_copy, qconfig_mapping, example_input)
        for _ in range(2):
            mp_ref(*example_input)
        mq_ref = convert_fx(mp_ref)
        output_shadow_ref = mq_ref(*example_input)
        self.assertTrue(
            torch.allclose(output_fp32, output_fp32_ref),
            f"fp32 comparison: {output_fp32} not close to {output_fp32_ref}")

        # print('shadow', output_shadow.shape, output_shadow)
        # print('shadow_ref', output_shadow_ref.shape, output_shadow_ref)
        self.assertTrue(
            torch.allclose(output_shadow, output_shadow_ref),
            f"shadow comparison: {output_shadow} not close to {output_shadow_ref}")

        return msq

    @withQNNPACKBackend
    def test_add_loggers_linear_mod_quant_quant(self):
        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        example_input = (torch.randn(2, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_linear_mod_fp32_quant(self):
        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        example_input = (torch.randn(2, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        # leave the first Linear in fp32
        qconfig_mapping.set_module_name('0', None)
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_linear_mod_quant_fp32(self):
        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        example_input = (torch.randn(2, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        # leave the second Linear in fp32
        qconfig_mapping.set_module_name('1', None)
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_linear_mod_fp32_fp32(self):
        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        example_input = (torch.randn(2, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        # leave both Linears in fp32
        qconfig_mapping.set_module_name('0', None)
        qconfig_mapping.set_module_name('1', None)
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_conv_bn_relu_fusion_quant(self):
        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU())
        m.eval()
        example_input = (torch.randn(16, 1, 4, 4),)
        qconfig_mapping = get_default_qconfig_mapping()
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_conv_bn_relu_fusion_fp32(self):
        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU())
        m.eval()
        example_input = (torch.randn(16, 1, 4, 4),)
        qconfig_mapping = get_default_qconfig_mapping()
        # leave conv, bn and relu all in fp32
        qconfig_mapping.set_module_name('0', None)
        qconfig_mapping.set_module_name('1', None)
        qconfig_mapping.set_module_name('2', None)
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @withQNNPACKBackend
    def test_add_loggers_functions(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.randn(2, 2))
                self.b1 = nn.Parameter(torch.randn(2))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.relu(x)
                x = x + x
                x = x + 1
                # TODO(future PR): support first arg being a scalar
                # x = 1 + x
                x = torch.cat([x, x])
                x = torch.cat([x, x])
                x = torch.cat(tensors=[x, x])
                # function not matchable by quantization
                x = torch.nn.functional.rrelu(x)
                x = F.linear(x, self.w1, self.b1)
                return x

        m = M().eval()
        example_input = (torch.randn(16, 2),)
        # run once with the default mapping and once with an empty mapping
        for qconfig_mapping in (
            get_default_qconfig_mapping(),
            QConfigMapping(),
        ):
            self._test_add_loggers_impl(m, example_input, qconfig_mapping)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    @withQNNPACKBackend
    def test_add_loggers_mobilenet_v2(self):
        import torchvision
        m = torchvision.models.quantization.mobilenet_v2(
            pretrained=False, quantize=False).eval()
        example_input = (torch.randn(8, 3, 224, 224),)
        qconfig_mapping = get_default_qconfig_mapping()
        self._test_add_loggers_impl(m, example_input, qconfig_mapping)


class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
    """
    Tests numeric suite core APIs (weight extraction, activation matching
    and activation shadowing) on non-toy models.
    """

    @skipIfNoFBGEMM
    def test_compare_weights_conv(self):
        test_cases = (
            (ConvModel(),),
            (ConvBnModel(),),
            (ConvBnReLUModel(),),
        )
        for m, in test_cases:
            m.eval()
            example_inputs = (torch.randn(1, 3, 5, 5),)
            self._test_extract_weights(m, example_inputs, results_len=1)

    @skipIfNoFBGEMM
    def test_compare_weights_linear(self):
        test_cases = (
            (SingleLayerLinearModel(), None),
            (
                SingleLayerLinearDynamicModel(),
                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
            ),
        )
        for m, qconfig_dict in test_cases:
            m.eval()
            example_inputs = (torch.randn(1, 3, 5, 5),)
            res = self._test_extract_weights(
                m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_compare_weights_lstm_dynamic(self):
        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
        example_inputs = (lstm_input, lstm_hidden)
        m = LSTMwithHiddenDynamicModel().eval()
        res = self._test_extract_weights(
            m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_compare_activations_conv(self):
        test_cases = (
            (ConvModel(),),
            (ConvBnModel(),),
            (ConvBnReLUModel(),),
        )
        for m, in test_cases:
            m.eval()
            res = self._test_match_activations(
                m, (torch.randn(1, 3, 4, 4),), results_len=1)

    @skipIfNoFBGEMM
    def test_compare_activations_linear(self):
        test_cases = (
            (SingleLayerLinearModel(), None),
            (
                SingleLayerLinearDynamicModel(),
                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
            ),
        )
        for m, qconfig_dict in test_cases:
            m.eval()
            res = self._test_match_activations(
                m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_compare_activations_lstm_dynamic(self):
        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
        m = LSTMwithHiddenDynamicModel().eval()
        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
        # TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
        res = self._test_match_activations(
            m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
            skip_scripting=True)

    @skipIfNoFBGEMM
    def test_compare_shadow_activations_conv(self):
        test_cases = (
            (ConvModel(),),
            (ConvBnModel(),),
            (ConvBnReLUModel(),),
        )
        for m, in test_cases:
            m.eval()
            res = self._test_match_shadow_activations(
                m, (torch.randn(1, 3, 4, 4),), results_len=1)

    @skipIfNoFBGEMM
    def test_compare_shadow_activations_linear(self):
        test_cases = (
            (SingleLayerLinearModel(), None),
            (
                SingleLayerLinearDynamicModel(),
                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
            ),
        )
        for m, qconfig_dict in test_cases:
            m.eval()
            res = self._test_match_shadow_activations(
                m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_compare_shadow_activations_lstm_dynamic(self):
        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
        m = LSTMwithHiddenDynamicModel().eval()
        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
        # TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
        res = self._test_match_shadow_activations(
            m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
            skip_scripting=True)

    @skipIfNoFBGEMM
    def test_sparsenn_compare_activations(self):
        for should_log_inputs in (True, False):
            sparse_nn = SparseNNModel().eval()
            idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
            offsets = torch.LongTensor([0, 4])
            x = torch.randn(2, 4)
            self._test_match_activations(
                sparse_nn, (idx, offsets, x),
                results_len=5,
                should_log_inputs=should_log_inputs)

    @skipIfNoFBGEMM
    def test_sparsenn_shadow(self):
        for should_log_inputs in (True, False):
            sparse_nn = SparseNNModel().eval()
            idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
            offsets = torch.LongTensor([0, 4])
            x = torch.randn(2, 4)
            self._test_match_shadow_activations(
                sparse_nn, (idx, offsets, x),
                results_len=3,
                should_log_inputs=should_log_inputs)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    def test_resnet18(self):
        import torchvision
        m = torchvision.models.quantization.resnet18(pretrained=False, quantize=False).eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        self._test_match_shadow_activations(
            m, (torch.randn(1, 3, 224, 224),),
            qconfig_dict=qconfig_dict,
            should_log_inputs=False)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    def test_mobilenet_v2(self):
        import torchvision
        m = torchvision.models.quantization.mobilenet_v2(pretrained=False, quantize=False).eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        self._test_match_shadow_activations(
            m, (torch.randn(1, 3, 224, 224),),
            qconfig_dict=qconfig_dict,
            should_log_inputs=False)