# Owner(s): ["module: unknown"] import copy import logging from typing import List import torch import torch.nn as nn import torch.nn.functional as F from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ( ActivationSparsifier, ) from torch.ao.pruning.sparsifier.utils import module_to_fqn from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 32, kernel_size=3) self.conv2 = nn.Conv2d(32, 32, kernel_size=3) self.identity1 = nn.Identity() self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.linear1 = nn.Linear(4608, 128) self.identity2 = nn.Identity() self.linear2 = nn.Linear(128, 10) def forward(self, x): out = self.conv1(x) out = self.conv2(out) out = self.identity1(out) out = self.max_pool1(out) batch_size = x.shape[0] out = out.reshape(batch_size, -1) out = F.relu(self.identity2(self.linear1(out))) out = self.linear2(out) return out class TestActivationSparsifier(TestCase): def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config): """Helper function to check if the model, defaults and sparse_config are loaded correctly in the activation sparsifier """ sparsifier_defaults = activation_sparsifier.defaults combined_defaults = {**defaults, "sparse_config": sparse_config} # more keys are populated in activation sparsifier (eventhough they may be None) assert len(combined_defaults) <= len(activation_sparsifier.defaults) for key, config in sparsifier_defaults.items(): # all the keys in combined_defaults should be present in sparsifier defaults assert config == combined_defaults.get(key, None) def _check_register_layer( self, activation_sparsifier, defaults, sparse_config, layer_args_list ): """Checks if layers in the model are correctly mapped to it's arguments. Args: activation_sparsifier (sparsifier object) activation sparsifier object that is being tested. defaults (Dict) all default config (except sparse_config) sparse_config (Dict) default sparse config passed to the sparsifier layer_args_list (list of tuples) Each entry in the list corresponds to the layer arguments. First entry in the tuple corresponds to all the arguments other than sparse_config Second entry in the tuple corresponds to sparse_config """ # check args data_groups = activation_sparsifier.data_groups assert len(data_groups) == len(layer_args_list) for layer_args in layer_args_list: layer_arg, sparse_config_layer = layer_args # check sparse config sparse_config_actual = copy.deepcopy(sparse_config) sparse_config_actual.update(sparse_config_layer) name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"]) assert data_groups[name]["sparse_config"] == sparse_config_actual # assert the rest other_config_actual = copy.deepcopy(defaults) other_config_actual.update(layer_arg) other_config_actual.pop("layer") for key, value in other_config_actual.items(): assert key in data_groups[name] assert value == data_groups[name][key] # get_mask should raise error with self.assertRaises(ValueError): activation_sparsifier.get_mask(name=name) def _check_pre_forward_hook(self, activation_sparsifier, data_list): """Registering a layer attaches a pre-forward hook to that layer. This function checks if the pre-forward hook works as expected. Specifically, checks if the input is aggregated correctly. Basically, asserts that the aggregate of input activations is the same as what was computed in the sparsifier. Args: activation_sparsifier (sparsifier object) activation sparsifier object that is being tested. data_list (list of torch tensors) data input to the model attached to the sparsifier """ # can only check for the first layer data_agg_actual = data_list[0] model = activation_sparsifier.model layer_name = module_to_fqn(model, model.conv1) agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"] for i in range(1, len(data_list)): data_agg_actual = agg_fn(data_agg_actual, data_list[i]) assert "data" in activation_sparsifier.data_groups[layer_name] assert torch.all( activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual ) return data_agg_actual def _check_step(self, activation_sparsifier, data_agg_actual): """Checks if .step() works as expected. Specifically, checks if the mask is computed correctly. Args: activation_sparsifier (sparsifier object) activation sparsifier object that is being tested. data_agg_actual (torch tensor) aggregated torch tensor """ model = activation_sparsifier.model layer_name = module_to_fqn(model, model.conv1) assert layer_name is not None reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"] data_reduce_actual = reduce_fn(data_agg_actual) mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"] sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"] mask_actual = mask_fn(data_reduce_actual, **sparse_config) mask_model = activation_sparsifier.get_mask(layer_name) assert torch.all(mask_model == mask_actual) for config in activation_sparsifier.data_groups.values(): assert "data" not in config def _check_squash_mask(self, activation_sparsifier, data): """Makes sure that squash_mask() works as usual. Specifically, checks if the sparsifier hook is attached correctly. This is achieved by only looking at the identity layers and making sure that the output == layer(input * mask). Args: activation_sparsifier (sparsifier object) activation sparsifier object that is being tested. data (torch tensor) dummy batched data """ # create a forward hook for checking output == layer(input * mask) def check_output(name): mask = activation_sparsifier.get_mask(name) features = activation_sparsifier.data_groups[name].get("features") feature_dim = activation_sparsifier.data_groups[name].get("feature_dim") def hook(module, input, output): input_data = input[0] if features is None: assert torch.all(mask * input_data == output) else: for feature_idx in range(0, len(features)): feature = torch.Tensor( [features[feature_idx]], device=input_data.device ).long() inp_data_feature = torch.index_select( input_data, feature_dim, feature ) out_data_feature = torch.index_select( output, feature_dim, feature ) assert torch.all( mask[feature_idx] * inp_data_feature == out_data_feature ) return hook for name, config in activation_sparsifier.data_groups.items(): if "identity" in name: config["layer"].register_forward_hook(check_output(name)) activation_sparsifier.model(data) def _check_state_dict(self, sparsifier1): """Checks if loading and restoring of state_dict() works as expected. Basically, dumps the state of the sparsifier and loads it in the other sparsifier and checks if all the configuration are in line. This function is called at various times in the workflow to makes sure that the sparsifier can be dumped and restored at any point in time. """ state_dict = sparsifier1.state_dict() new_model = Model() # create an empty new sparsifier sparsifier2 = ActivationSparsifier(new_model) assert sparsifier2.defaults != sparsifier1.defaults assert len(sparsifier2.data_groups) != len(sparsifier1.data_groups) sparsifier2.load_state_dict(state_dict) assert sparsifier2.defaults == sparsifier1.defaults for name, state in sparsifier2.state.items(): assert name in sparsifier1.state mask1 = sparsifier1.state[name]["mask"] mask2 = state["mask"] if mask1 is None: assert mask2 is None else: assert type(mask1) == type(mask2) if isinstance(mask1, List): assert len(mask1) == len(mask2) for idx in range(len(mask1)): assert torch.all(mask1[idx] == mask2[idx]) else: assert torch.all(mask1 == mask2) # make sure that the state dict is stored as torch sparse for state in state_dict["state"].values(): mask = state["mask"] if mask is not None: if isinstance(mask, List): for idx in range(len(mask)): assert mask[idx].is_sparse else: assert mask.is_sparse dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups for layer_name, config in dg1.items(): assert layer_name in dg2 # exclude hook and layer config1 = { key: value for key, value in config.items() if key not in ["hook", "layer"] } config2 = { key: value for key, value in dg2[layer_name].items() if key not in ["hook", "layer"] } assert config1 == config2 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_activation_sparsifier(self): """Simulates the workflow of the activation sparsifier, starting from object creation till squash_mask(). The idea is to check that everything works as expected while in the workflow. """ # defining aggregate, reduce and mask functions def agg_fn(x, y): return x + y def reduce_fn(x): return torch.mean(x, dim=0) def _vanilla_norm_sparsifier(data, sparsity_level): r"""Similar to data norm sparsifier but block_shape = (1,1). Simply, flatten the data, sort it and mask out the values less than threshold """ data_norm = torch.abs(data).flatten() _, sorted_idx = torch.sort(data_norm) threshold_idx = round(sparsity_level * len(sorted_idx)) sorted_idx = sorted_idx[:threshold_idx] mask = torch.ones_like(data_norm) mask.scatter_(dim=0, index=sorted_idx, value=0) mask = mask.reshape(data.shape) return mask # Creating default function and sparse configs # default sparse_config sparse_config = {"sparsity_level": 0.5} defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn} # simulate the workflow # STEP 1: make data and activation sparsifier object model = Model() # create model activation_sparsifier = ActivationSparsifier(model, **defaults, **sparse_config) # Test Constructor self._check_constructor(activation_sparsifier, model, defaults, sparse_config) # STEP 2: Register some layers register_layer1_args = { "layer": model.conv1, "mask_fn": _vanilla_norm_sparsifier, } sparse_config_layer1 = {"sparsity_level": 0.3} register_layer2_args = { "layer": model.linear1, "features": [0, 10, 234], "feature_dim": 1, "mask_fn": _vanilla_norm_sparsifier, } sparse_config_layer2 = {"sparsity_level": 0.1} register_layer3_args = { "layer": model.identity1, "mask_fn": _vanilla_norm_sparsifier, } sparse_config_layer3 = {"sparsity_level": 0.3} register_layer4_args = { "layer": model.identity2, "features": [0, 10, 20], "feature_dim": 1, "mask_fn": _vanilla_norm_sparsifier, } sparse_config_layer4 = {"sparsity_level": 0.1} layer_args_list = [ (register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2), ] layer_args_list += [ (register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4), ] # Registering.. for layer_args in layer_args_list: layer_arg, sparse_config_layer = layer_args activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer) # check if things are registered correctly self._check_register_layer( activation_sparsifier, defaults, sparse_config, layer_args_list ) # check state_dict after registering and before model forward self._check_state_dict(activation_sparsifier) # check if forward pre hooks actually work # some dummy data data_list = [] num_data_points = 5 for _ in range(0, num_data_points): rand_data = torch.randn(16, 1, 28, 28) activation_sparsifier.model(rand_data) data_list.append(rand_data) data_agg_actual = self._check_pre_forward_hook(activation_sparsifier, data_list) # check state_dict() before step() self._check_state_dict(activation_sparsifier) # STEP 3: sparsifier step activation_sparsifier.step() # check state_dict() after step() and before squash_mask() self._check_state_dict(activation_sparsifier) # self.check_step() self._check_step(activation_sparsifier, data_agg_actual) # STEP 4: squash mask activation_sparsifier.squash_mask() self._check_squash_mask(activation_sparsifier, data_list[0]) # check state_dict() after squash_mask() self._check_state_dict(activation_sparsifier)