# Owner(s): ["oncall: quantization"] from .common import AOMigrationTestCase class TestAOMigrationQuantizationFx(AOMigrationTestCase): def test_function_import_quantize_fx(self): function_list = [ "_check_is_graph_module", "_swap_ff_with_fxff", "_fuse_fx", "QuantizationTracer", "_prepare_fx", "_prepare_standalone_module_fx", "fuse_fx", "Scope", "ScopeContextManager", "prepare_fx", "prepare_qat_fx", "_convert_fx", "convert_fx", "_convert_standalone_module_fx", ] self._test_function_import("quantize_fx", function_list) def test_function_import_fx(self): function_list = [ "prepare", "convert", "fuse", ] self._test_function_import("fx", function_list) def test_function_import_fx_graph_module(self): function_list = [ "FusedGraphModule", "ObservedGraphModule", "_is_observed_module", "ObservedStandaloneGraphModule", "_is_observed_standalone_module", "QuantizedGraphModule", ] self._test_function_import("fx.graph_module", function_list) def test_function_import_fx_pattern_utils(self): function_list = [ "QuantizeHandler", "_register_fusion_pattern", "get_default_fusion_patterns", "_register_quant_pattern", "get_default_quant_patterns", "get_default_output_activation_post_process_map", ] self._test_function_import("fx.pattern_utils", function_list) def test_function_import_fx_equalize(self): function_list = [ "reshape_scale", "_InputEqualizationObserver", "_WeightEqualizationObserver", "calculate_equalization_scale", "EqualizationQConfig", "input_equalization_observer", "weight_equalization_observer", "default_equalization_qconfig", "fused_module_supports_equalization", "nn_module_supports_equalization", "node_supports_equalization", "is_equalization_observer", "get_op_node_and_weight_eq_obs", "maybe_get_weight_eq_obs_node", "maybe_get_next_input_eq_obs", "maybe_get_next_equalization_scale", "scale_input_observer", "scale_weight_node", "scale_weight_functional", "clear_weight_quant_obs_node", "remove_node", "update_obs_for_equalization", "convert_eq_obs", "_convert_equalization_ref", "get_layer_sqnr_dict", "get_equalization_qconfig_dict", ] self._test_function_import("fx._equalize", function_list) def test_function_import_fx_quantization_patterns(self): function_list = [ "QuantizeHandler", "BinaryOpQuantizeHandler", "CatQuantizeHandler", "ConvReluQuantizeHandler", "LinearReLUQuantizeHandler", "BatchNormQuantizeHandler", "EmbeddingQuantizeHandler", "RNNDynamicQuantizeHandler", "DefaultNodeQuantizeHandler", "FixedQParamsOpQuantizeHandler", "CopyNodeQuantizeHandler", "CustomModuleQuantizeHandler", "GeneralTensorShapeOpQuantizeHandler", "StandaloneModuleQuantizeHandler", ] self._test_function_import( "fx.quantization_patterns", function_list, new_package_name="fx.quantize_handler", ) def test_function_import_fx_match_utils(self): function_list = ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"] self._test_function_import("fx.match_utils", function_list) def test_function_import_fx_prepare(self): function_list = ["prepare"] self._test_function_import("fx.prepare", function_list) def test_function_import_fx_convert(self): function_list = ["convert"] self._test_function_import("fx.convert", function_list) def test_function_import_fx_fuse(self): function_list = ["fuse"] self._test_function_import("fx.fuse", function_list) def test_function_import_fx_fusion_patterns(self): function_list = ["FuseHandler", "DefaultFuseHandler"] self._test_function_import( "fx.fusion_patterns", function_list, new_package_name="fx.fuse_handler", ) # we removed matching test for torch.quantization.fx.quantization_types # old: torch.quantization.fx.quantization_types # new: torch.ao.quantization.utils # both are valid, but we'll deprecate the old path in the future def test_function_import_fx_utils(self): function_list = [ "get_custom_module_class_keys", "get_linear_prepack_op_for_dtype", "get_qconv_prepack_op", "get_new_attr_name_with_prefix", "graph_module_from_producer_nodes", "assert_and_get_unique_device", "create_getattr_from_value", "all_node_args_have_no_tensors", "get_non_observable_arg_indexes_and_types", "maybe_get_next_module", ] self._test_function_import("fx.utils", function_list)