# Owner(s): ["module: unknown"] import copy import logging import random import torch from torch import nn from torch.ao.pruning._experimental.pruner import ( BaseStructuredSparsifier, FakeStructuredSparsity, FPGMPruner, LSTMSaliencyPruner, SaliencyPruner, ) from torch.nn.utils import parametrize from torch.testing._internal.common_pruning import ( Conv2dActivation, Conv2dBias, Conv2dPadBias, Conv2dPool, Conv2dPoolFlatten, Conv2dPoolFlattenFunctional, LinearActivation, LinearActivationFunctional, LinearBias, LSTMLayerNormLinearModel, LSTMLinearModel, rows_are_subset, SimpleConv2d, SimpleLinear, ) from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) DEVICES = { torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), } class SimplePruner(BaseStructuredSparsifier): def update_mask(self, module, tensor_name, **kwargs): getattr(module.parametrizations, tensor_name)[0].mask[1] = False class ImplementedPruner(BaseStructuredSparsifier): def update_mask(self, module, tensor_name, **kwargs): """Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning""" num_rows = len(module.parametrizations[tensor_name][0].mask) prune = random.sample(list(range(num_rows)), num_rows // 3) module.parametrizations[tensor_name][0].mask[prune] = False class BottomHalfLSTMPruner(BaseStructuredSparsifier): """ Pruner that will remove the bottom half of the rows. This is primarily meant for testing purposes """ def update_mask(self, module, tensor_name, **kwargs): for p in getattr(module.parametrizations, tensor_name): if isinstance(p, FakeStructuredSparsity): mask = p.mask masks = torch.split(mask, len(mask) // 4) for small in masks: num = len(small) small[num // 2 :] = False new_mask = torch.cat(masks) mask.data = new_mask.data class TestSaliencyPruner(TestCase): def test_saliency_pruner_update_mask(self): """Test that we prune out the row with the lowest saliency (first row)""" model = SimpleLinear() with torch.no_grad(): model.linear1.weight = nn.Parameter( torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]) ) pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}] pruner = SaliencyPruner({}) pruner.prepare(model, pruning_config) pruner.enable_mask_update = True pruner.step() pruned_model = pruner.prune() expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]]) pruned = pruned_model.linear1.weight assert expected.shape == pruned.shape assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() def test_lstm_saliency_pruner_update_mask(self): model = LSTMLinearModel( input_dim=2, hidden_dim=2, output_dim=2, num_layers=1, ) manual_weights = torch.Tensor( [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]] ) with torch.no_grad(): model.lstm.weight_ih_l0 = nn.Parameter(manual_weights) model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights)) model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0]) model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0]) config = [ {"tensor_fqn": "lstm.weight_ih_l0"}, {"tensor_fqn": "lstm.weight_hh_l0"}, ] lstm_input = torch.ones((1, 2)) fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5}) fx_pruner.prepare(model, config) fx_pruner.enable_mask_update = True fx_pruner.step() model.eval() pruned_model = fx_pruner.prune() pruned_model.eval() # make sure both models run model(lstm_input) pruned_model(lstm_input) # make sure lowest saliency rows are pruned expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]]) pruned = model.lstm.weight_ih_l0 assert expected.shape == pruned.shape assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() expected = torch.Tensor([[2], [2], [-2], [-2]]) pruned = model.lstm.weight_hh_l0 assert expected.shape == pruned.shape assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() expected = torch.Tensor([2, 2, -2, -2]) for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]: assert expected.shape == pruned.shape assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() class TestBaseStructuredSparsifier(TestCase): def _check_pruner_prepared(self, model, pruner, device): for config in pruner.groups: module = config["module"] assert module.weight.device.type == device.type # Check mask exists assert config["tensor_fqn"] in pruner.state # Check parametrization exists and is correct assert parametrize.is_parametrized(module) assert hasattr(module, "parametrizations") # Assume that this is the 1st/only parametrization assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity def _check_pruner_valid_before_step(self, model, pruner, device): for config in pruner.groups: modules = [] if type(config["module"]) is tuple: modules.extend(config["module"]) else: module = config["module"] modules.append(module) for module in modules: assert module.weight.device.type == device.type assert module.parametrizations.weight[0].mask.dtype == torch.bool def _check_pruner_valid_after_step(self, model, pruner, mask, device): for config in pruner.groups: modules = [] if type(config["module"]) is tuple: modules.extend(config["module"]) else: module = config["module"] modules.append(module) for module in modules: assert module.weight.device.type == device.type total = module.parametrizations.weight[0].mask.numel() assert ( module.parametrizations.weight[0].mask.count_nonzero() == total - mask ) def _test_constructor_on_device(self, model, device): self.assertRaisesRegex( TypeError, "BaseStructuredSparsifier.*update_mask", BaseStructuredSparsifier, ) model1 = copy.deepcopy(model).to(device) pruner = SimplePruner(None) pruner.prepare(model1, None) pruner.enable_mask_update = True for g in pruner.groups: module = g["module"] assert module.weight.device.type == device.type assert len(pruner.groups) == 5 pruner.step() # Can instantiate the model with configs model2 = copy.deepcopy(model).to(device) pruner = SimplePruner({"test": 3}) pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}]) assert len(pruner.groups) == 1 assert pruner.groups[0]["module_fqn"] == "seq.0" assert "test" in pruner.groups[0] assert pruner.groups[0]["test"] == 3 def test_constructor(self): model = SimpleLinear() for device in DEVICES: self._test_constructor_on_device(model, torch.device(device)) def _test_prepare_linear_on_device(self, model, device): model = copy.deepcopy(model).to(device) x = torch.ones(128, 7, device=device) pruner = SimplePruner(None) pruner.prepare(model, None) self._check_pruner_prepared(model, pruner, device) assert model(x).shape == (128, 10) def test_prepare_linear(self): models = [ SimpleLinear(), LinearBias(), LinearActivation(), LinearActivationFunctional(), ] # without and with bias for device in DEVICES: for model in models: self._test_prepare_linear_on_device(model, torch.device(device)) def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device): x = torch.ones((1, 1, 28, 28), device=device) pruner = SimplePruner(None) pruner.prepare(model, config) self._check_pruner_prepared(model, pruner, device) assert model(x).shape == expected_shape def test_prepare_conv2d(self): models = [ SimpleConv2d(), Conv2dBias(), Conv2dActivation(), Conv2dPadBias(), Conv2dPool(), ] shapes = [ (1, 52, 20, 20), (1, 52, 18, 18), (1, 52, 18, 18), (1, 52, 24, 24), (1, 52, 3, 3), ] configs = [None, None, None, None, None] for device in DEVICES: for model, shape, config in zip(models, shapes, configs): model = model.to(device) self._test_prepare_conv2d_on_device( model, shape, config, torch.device(device) ) def _test_step_linear_on_device(self, model, device): model = model.to(device) x = torch.ones(7, 7, device=device) pruner = SimplePruner(None) pruner.prepare(model, None) pruner.enable_mask_update = True self._check_pruner_valid_before_step(model, pruner, device) pruner.step() self._check_pruner_valid_after_step(model, pruner, 1, device) def test_step_linear(self): models = [ SimpleLinear(), LinearBias(), LinearActivation(), LinearActivationFunctional(), ] for device in DEVICES: for model in models: self._test_step_linear_on_device(model, torch.device(device)) def _test_step_conv2d_on_device(self, model, expected_shape, config, device): model = model.to(device) x = torch.ones((1, 1, 28, 28), device=device) pruner = SimplePruner(None) pruner.prepare(model, config) pruner.enable_mask_update = True self._check_pruner_valid_before_step(model, pruner, device) pruner.step() self._check_pruner_valid_after_step(model, pruner, 1, device) assert model(x).shape == expected_shape @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_step_conv2d(self): models = [ SimpleConv2d(), Conv2dBias(), Conv2dActivation(), Conv2dPadBias(), Conv2dPool(), ] shapes = [ (1, 52, 20, 20), (1, 52, 18, 18), (1, 52, 18, 18), (1, 52, 24, 24), (1, 52, 3, 3), ] configs = [None, None, None, None, None] for device in DEVICES: for model, shape, config in zip(models, shapes, configs): self._test_step_conv2d_on_device( model, shape, config, torch.device(device) ) def _check_pruner_pruned(self, model, pruner, device): for config in pruner.groups: module = config["module"] assert not hasattr(module, "parametrizations") assert not hasattr(module, "mask") def _test_linear_on_device( self, model, config, expected_shape, device, also_prune_bias ): model = model.to(device) model.eval() num_original_params = sum(p.numel() for p in model.parameters()) x = torch.ones(128, 7, device=device) pruner = ImplementedPruner({"prune_bias": also_prune_bias}) pruner.prepare(model, config) pruner.enable_mask_update = True pruner.step() y_expected = model(x) assert y_expected.shape == (128, 10) self._check_pruner_prepared(model, pruner, device) # Pruning step pruned = pruner.prune() y_pruned = pruned(x) num_pruned_params = sum(p.numel() for p in pruned.parameters()) assert y_pruned.shape == expected_shape self._check_pruner_pruned(model, pruner, device) if y_pruned.shape == y_expected.shape: assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all() assert num_pruned_params < num_original_params def test_prune_linear_linear(self): r"""test pruning linear-> linear modules""" configs, shapes = [], [] configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((128, 10)) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "seq.2.weight"}, {"tensor_fqn": "linear1.weight"}, ] ) shapes.append((128, 10)) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((128, 10)) for device in DEVICES: for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_linear_on_device( SimpleLinear(), config, shape, torch.device(device), also_prune_bias, ) def test_prune_linear_bias_linear(self): # linear(bias) -> linear(no bias) configs, shapes = [], [] configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, ] ) shapes.append((128, 10)) # linear(bias) -> linear(bias) configs.append( [ {"tensor_fqn": "seq.2.weight"}, {"tensor_fqn": "seq.3.weight"}, ] ) shapes.append((128, 10)) # linear(no bias) -> linear(bias) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((128, 10)) for device in DEVICES: for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_linear_on_device( LinearBias(), config, shape, torch.device(device), also_prune_bias, ) def test_prune_linear_activation_linear(self): config = [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.2.weight"}, {"tensor_fqn": "seq.4.weight"}, {"tensor_fqn": "linear1.weight"}, ] shape = (128, 10) for device in DEVICES: for also_prune_bias in [True, False]: # test version with nn.Modules self._test_linear_on_device( LinearActivation(), config, shape, torch.device(device), also_prune_bias, ) # test functional version self._test_linear_on_device( LinearActivationFunctional(), config, shape, torch.device(device), also_prune_bias, ) def _test_conv2d_on_device( self, model, config, x, expected_shape, device, also_prune_bias ): model = model.to(device) num_original_params = sum(p.numel() for p in model.parameters()) model.eval() pruner = ImplementedPruner({"prune_bias": also_prune_bias}) pruner.prepare(model, config) pruner.enable_mask_update = True pruner.step() y_expected = model(x) assert y_expected.shape == expected_shape self._check_pruner_prepared(model, pruner, device) # Fusion step pruned = pruner.prune() y_pruned = pruned(x) num_pruned_params = sum(p.numel() for p in pruned.parameters()) assert y_pruned.shape == expected_shape self._check_pruner_pruned(model, pruner, device) if y_pruned.shape == y_expected.shape: # TODO This rtol is a little high, need to double check if something specific is causing this to fail assert torch.isclose( y_expected, y_pruned, rtol=1e-3, atol=1e-3, ).all(), f"fail for {type(model)}" # only time this should be equal is when all layers have padding and we can't prune assert num_pruned_params <= num_original_params def test_prune_conv2d_conv2d(self): configs, shapes = [], [] # all within sequential blocks configs.append( [ {"tensor_fqn": "seq.0.weight"}, ] ) shapes.append((1, 52, 20, 20)) # prune across sequential blocks configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "conv2d1.weight"}, ] ) shapes.append((1, 52, 20, 20)) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_conv2d_on_device( SimpleConv2d(), config, x, shape, torch.device(device), also_prune_bias, ) def test_prune_conv2d_bias_conv2d(self): # Conv2d with Bias and no Activation configs, shapes = [], [] # conv2d(bias) -> conv2d(bias) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, ] ) shapes.append((1, 52, 18, 18)) # conv2d(no bias) -> conv2d(bias) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "conv2d1.weight"}, ] ) shapes.append((1, 52, 18, 18)) # conv2d(bias) -> conv2d(no bias) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.1.weight"}, {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((1, 52, 18, 18)) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_conv2d_on_device( Conv2dBias(), config, x, shape, torch.device(device), also_prune_bias, ) def test_prune_conv2d_activation_conv2d(self): # Conv2d with Activation and no Bias configs, shapes = [], [] # conv2d(no bias) -> activation -> conv2d(no bias) configs.append( [ {"tensor_fqn": "seq.4.weight"}, ] ) shapes.append((1, 52, 18, 18)) # conv2d(bias) -> activation -> conv2d(bias) configs.append( [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((1, 52, 18, 18)) # conv2d(bias) -> activation -> conv2d(no bias) configs.append( [ {"tensor_fqn": "seq.2.weight"}, {"tensor_fqn": "seq.4.weight"}, ] ) shapes.append((1, 52, 18, 18)) # conv2d(no bias) -> activation -> conv2d(bias) configs.append( [ {"tensor_fqn": "conv2d1.weight"}, ] ) shapes.append((1, 52, 18, 18)) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_conv2d_on_device( Conv2dActivation(), config, x, shape, torch.device(device), also_prune_bias, ) def test_prune_conv2d_padding_conv2d(self): # Conv2d with Padded layers after Bias layers configs, shapes = [], [] # conv(padded, bias) -> conv(padded, bias) configs.append( [ {"tensor_fqn": "seq.4.weight"}, ] ) shapes.append((1, 52, 24, 24)) # conv(no bias, no pad) -> conv(padded, bias) configs.append( [ {"tensor_fqn": "seq.2.weight"}, ] ) shapes.append((1, 52, 24, 24)) # conv(padded, bias) -> conv ( no bias ,no pad) configs.append( [ {"tensor_fqn": "seq.0.weight"}, ] ) shapes.append((1, 52, 24, 24)) # conv(pad, bias) -> conv(no pad, bias) configs.append( [ {"tensor_fqn": "seq.6.weight"}, ] ) shapes.append((1, 52, 24, 24)) # conv(no pad, bias) -> conv(pad, bias) configs.append( [ {"tensor_fqn": "seq.8.weight"}, ] ) shapes.append((1, 52, 24, 24)) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: for config, shape in zip(configs, shapes): self._test_conv2d_on_device( Conv2dPadBias(), config, x, shape, torch.device(device), also_prune_bias, ) def test_prune_conv2d_pool_conv2d(self): # Conv2d with Pooling layers config = [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.3.weight"}, {"tensor_fqn": "conv2d1.weight"}, {"tensor_fqn": "conv2d2.weight"}, ] shape = (1, 52, 3, 3) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: self._test_conv2d_on_device( Conv2dPool(), config, x, shape, torch.device(device), also_prune_bias, ) @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_complex_conv2d(self): """Test fusion for models that contain Conv2d & Linear modules. Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add""" config = [ {"tensor_fqn": "seq.0.weight"}, {"tensor_fqn": "seq.3.weight"}, {"tensor_fqn": "conv2d1.weight"}, {"tensor_fqn": "conv2d2.weight"}, ] shape = (1, 13) for device in DEVICES: x = torch.ones((1, 1, 28, 28), device=device) for also_prune_bias in [True, False]: self._test_conv2d_on_device( Conv2dPoolFlattenFunctional(), config, x, shape, torch.device(device), also_prune_bias, ) self._test_conv2d_on_device( Conv2dPoolFlatten(), config, x, shape, torch.device(device), also_prune_bias, ) def test_prune_lstm_linear_multiple_layer(self): """ Test fusion support for LSTM(multi-layer) -> Linear """ model = LSTMLinearModel( input_dim=8, hidden_dim=8, output_dim=8, num_layers=2, ) config = [ {"tensor_fqn": "lstm.weight_ih_l0"}, {"tensor_fqn": "lstm.weight_hh_l0"}, {"tensor_fqn": "lstm.weight_ih_l1"}, {"tensor_fqn": "lstm.weight_hh_l1"}, ] lstm_input = torch.ones((1, 8)) fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) fx_pruner.prepare(model, config) fx_pruner.enable_mask_update = True fx_pruner.step() model.eval() _, _ = model(lstm_input) pruned_model = fx_pruner.prune() pruned_model.eval() _, _ = pruned_model(lstm_input) expected_params = dict(model.named_parameters()) for name, param in model.named_parameters(): assert name in expected_params # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics # Instead we check that the weights of the new LSTM are a subset of the weights of # the old LSTM assert rows_are_subset(param, expected_params[name]) del expected_params[name] # assert we haven't deleted any keys assert len(expected_params) == 0 def test_prune_lstm_linear_single_layer(self): """ Test fusion support for LSTM (single-layer) -> Linear """ model = LSTMLinearModel( input_dim=8, hidden_dim=8, output_dim=8, num_layers=1, ) config = [ {"tensor_fqn": "lstm.weight_ih_l0"}, {"tensor_fqn": "lstm.weight_hh_l0"}, ] lstm_input = torch.ones((1, 8)) fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) fx_pruner.prepare(model, config) fx_pruner.enable_mask_update = True fx_pruner.step() model.eval() out_expected, lstm_out_expected = model(lstm_input) pruned_model = fx_pruner.prune() pruned_model.eval() out_pruned, lstm_out_pruned = pruned_model(lstm_input) r, c = lstm_out_expected.size() # We cannot check that y_expected == y_pruned as usual because # zeros vs. missing elements yield different numerical results. # Instead that we check that the pruned elements are the first half of the results # since we are using a BottomHalfLSTMPruner assert torch.isclose( lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07 ).all() # also check that output of linear is the same shape, this means we've resized # linear columns correctly. assert out_expected.shape == out_pruned.shape def test_prune_lstm_layernorm_linear_multiple_layer(self): """ Test fusion support for LSTM(multi-layer) -> Linear """ model = LSTMLayerNormLinearModel( input_dim=8, output_dim=8, hidden_dim=8, num_layers=2, ) config = [ {"tensor_fqn": "lstm.weight_ih_l0"}, {"tensor_fqn": "lstm.weight_hh_l0"}, {"tensor_fqn": "lstm.weight_ih_l1"}, {"tensor_fqn": "lstm.weight_hh_l1"}, ] lstm_input = torch.ones((1, 8)) fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) fx_pruner.prepare(model, config) fx_pruner.enable_mask_update = True fx_pruner.step() model.eval() _, _ = model(lstm_input) pruned_model = fx_pruner.prune() pruned_model.eval() _, _ = pruned_model(lstm_input) expected_params = dict(model.named_parameters()) for name, param in model.named_parameters(): assert name in expected_params # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics # Instead we check that the weights of the new LSTM are a subset of the weights of # the old LSTM assert rows_are_subset(param, expected_params[name]) del expected_params[name] # assert we haven't deleted any keys assert len(expected_params) == 0 def test_prune_lstm_layernorm_linear_single_layer(self): """ Test fusion support for LSTM (single-layer) -> Linear """ model = LSTMLinearModel( input_dim=8, hidden_dim=8, output_dim=8, num_layers=1, ) config = [ {"tensor_fqn": "lstm.weight_ih_l0"}, {"tensor_fqn": "lstm.weight_hh_l0"}, ] lstm_input = torch.ones((1, 8)) fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) fx_pruner.prepare(model, config) fx_pruner.enable_mask_update = True fx_pruner.step() model.eval() out_expected, lstm_out_expected = model(lstm_input) pruned_model = fx_pruner.prune() pruned_model.eval() out_pruned, lstm_out_pruned = pruned_model(lstm_input) r, c = lstm_out_expected.size() # We cannot check that y_expected == y_pruned as usual because # zeros vs. missing elements yield different numerical results. # Instead that we check that the pruned elements are the first half of the results # since we are using a BottomHalfLSTMPruner assert torch.isclose( lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07 ).all() # also check that output of linear is the same shape, this means we've resized # linear columns correctly. assert out_expected.shape == out_pruned.shape class TestFPGMPruner(TestCase): """ Test case for the implementation of paper: `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. """ class SimpleConvFPGM(nn.Module): def __init__(self) -> None: super().__init__() self.conv2d1 = nn.Conv2d( in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False ) # Manually set the filter weights for demonstration purposes """ Three filters' weight are manually set to values 3.0, 2.0, and 0.1. Different from the norm-based decision that prunes filter with value 0.1, FPGM will prune the one with value 2.0. """ weights = torch.tensor([3.0, 2.0, 0.1]) # Weight weights for each filter weights = weights[:, None, None, None] # broadcasting self.conv2d1.weight.data.copy_( torch.ones(self.conv2d1.weight.shape) * weights ) # Second Convolutional Layer self.conv2d2 = nn.Conv2d( in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False ) weights = torch.tensor([6.0, 7.0, 0.4, 0.5]) weights = weights[:, None, None, None] self.conv2d2.weight.data.copy_( torch.ones(self.conv2d2.weight.shape) * weights ) def forward(self, x): x = self.conv2d1(x) x = self.conv2d2(x) return x def test_compute_distance(self, device="cpu"): """Test the distance computation function""" model = TestFPGMPruner.SimpleConvFPGM().to(device) pruner = FPGMPruner(0.3) dist_conv1 = pruner._compute_distance(model.conv2d1.weight) # compute the distance matrix using torch.cdist flattened_filters = torch.Tensor( [ [ 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, ], [ 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, ], [ 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, ], ] ) """ Expected distance matrix should have the following values: [0.0000, 3.0000, 8.7000], [3.0000, 0.0000, 5.7000], [8.7000, 5.7000, 0.0000], the distance should therefore be: [11.7000, 8.7000, 14.4000] """ expected_dist_matrix_conv1 = torch.cdist( flattened_filters, flattened_filters, p=2 ) expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1) assert torch.isclose( dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07 ).all() def _test_update_mask_on_single_layer(self, expected_conv1, device): """Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value""" # test pruning with one layer of conv2d model = TestFPGMPruner.SimpleConvFPGM().to(device) x = torch.ones((1, 1, 32, 32), device=device) pruner = FPGMPruner(0.3) config = [{"tensor_fqn": "conv2d1.weight"}] pruner.prepare(model, config) pruner.enable_mask_update = True pruner.step() assert ( pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False ), "do not prune the least-norm filter" # fusion step pruned_model = pruner.prune() pruned_y = pruned_model(x) # assert shapes expected_conv1 = expected_conv1.to(device) assert pruned_y.shape == (1, 4, 32, 32) assert pruned_model.conv2d1.weight.shape == expected_conv1.shape assert pruned_model.conv2d2.weight.shape == ( 4, 2, 3, 3, ), "conv2d2 should have input channel pruned" # assert value assert torch.isclose( pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07 ).all() def _test_update_mask_on_multiple_layer( self, expected_conv1, expected_conv2, device ): # the second setting model = TestFPGMPruner.SimpleConvFPGM().to(device) x = torch.ones((1, 1, 32, 32), device=device) pruner = FPGMPruner(0.3) config = [ {"tensor_fqn": "conv2d1.weight"}, {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}, ] pruner.prepare(model, config) pruner.enable_mask_update = True pruner.step() # Get the masks for the two least-norm filters mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1] mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2] # Check if either of the least-norm filters is not pruned assert ( mask1.item() is not False or mask2.item() is not False ), "Do not prune all least-norm filters" # fusion step pruned_model = pruner.prune() pruned_y = pruned_model(x) # assert shapes expected_conv1 = expected_conv1.to(device) expected_conv2 = expected_conv2.to(device) assert pruned_y.shape == (1, 2, 32, 32) assert pruned_model.conv2d1.weight.shape == expected_conv1.shape assert pruned_model.conv2d2.weight.shape == expected_conv2.shape # assert values assert torch.isclose( pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07 ).all() assert torch.isclose( pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07 ).all() def test_update_mask(self): weights = torch.tensor([3.0, 0.1]) expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None] weights = torch.tensor([7.0, 0.4]) expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None] for device in DEVICES: self._test_update_mask_on_single_layer(expected_conv1, device) self._test_update_mask_on_multiple_layer( expected_conv1, expected_conv2, device )