# Activation Sparsifier

## Introduction
The activation sparsifier attaches itself to one or more layers in a model and prunes the activations passing through them. **Note that the layer weights are not pruned here.**

## How does it work?
The idea is to compute a mask to prune the activations. To compute the mask, we need a representative tensor that generalizes the activations coming from all the batches in the dataset.

There are 3 main steps involved:
1. **Aggregation**: The activations coming from inputs across all the batches are aggregated using a user-defined `aggregate_fn`. A simple example is the add function.
2. **Reduce**: The aggregated activations are then reduced using a user-defined `reduce_fn`. A simple example is average.
3. **Masking**: The reduced activations are then passed into a user-defined `mask_fn` to compute the mask.

Essentially, the high-level idea of computing the mask is:

```
>>> aggregated_tensor = aggregate_fn([activation for activation in all_activations])
>>> reduced_tensor = reduce_fn(aggregated_tensor)
>>> mask = mask_fn(reduced_tensor)
```

*The activation sparsifier also supports per-feature/channel sparsity. This means that a desired set of features in an activation can also be pruned. The mask will be stored per feature.*

```
>>> # when features = None, the mask is a single tensor computed on the entire activation tensor
>>> # otherwise, the mask is a list of tensors of length len(features), computed on each feature of the activations
>>>
>>> # At a high level, this is how the mask is computed when features is not None
>>> for i in range(len(features)):
...     aggregated_tensor_feature = aggregate_fn([activation[features[i]] for activation in all_activations])
...     mask[i] = mask_fn(reduce_fn(aggregated_tensor_feature))
```

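To make the per-feature case concrete, here is a minimal sketch of how such slicing could look, assuming `features` holds indices selected along `feature_dim` (e.g. channels of a convolution output). This is an illustration, not the sparsifier's exact internals:

```
import torch

activation = torch.randn(8, 16, 32, 32)  # a batch of conv outputs: (batch, channels, H, W)
features = [0, 3, 7]                     # hypothetical channel indices to mask separately
feature_dim = 1                          # slice along the channel dimension

# one slice per feature; each slice would get its own aggregate -> reduce -> mask pipeline
feature_slices = [
    torch.index_select(activation, feature_dim, torch.tensor([f]))
    for f in features
]
```
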
## Implementation Details
The activation sparsifier attaches itself to a set of layers in a model and then attempts to sparsify the activations flowing through them. *Attach* means registering a forward pre-hook on the layer via [`register_forward_pre_hook()`](https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#register_forward_pre_hook).

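For illustration, here is a minimal sketch of what such an aggregation hook could look like (the names `aggregation_hook` and `agg_tensor` are hypothetical, not the sparsifier's internals):

```
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
agg_tensor = None  # running aggregate kept by the hook

def aggregation_hook(module, inputs):
    # forward_pre_hook signature: inputs is a tuple of the layer's positional inputs
    global agg_tensor
    activation = inputs[0].detach()
    agg_tensor = activation if agg_tensor is None else agg_tensor + activation  # an add-style aggregate_fn

handle = layer.register_forward_pre_hook(aggregation_hook)
layer(torch.randn(2, 4))  # agg_tensor now holds the aggregated activation
```
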
Let's go over the 3 steps again:
1. **Aggregation**: The aggregation of activations happens by attaching a hook to the layer that applies `aggregate_fn` and stores the aggregated data. The aggregation happens per feature if features are specified; otherwise, it happens on the entire tensor.
The `aggregate_fn` should accept two input tensors and return an aggregated tensor. Example:
```
def aggregate_fn(tensor1, tensor2):
    return tensor1 + tensor2
```

2. **Reduce**: This is initiated once `step()` is called. The `reduce_fn()` is called on the aggregated tensor. The goal is to squash the aggregated tensor.
The `reduce_fn` should accept one tensor as argument and return a reduced tensor. Example:
```
def reduce_fn(agg_tensor):
    return agg_tensor.mean(dim=0)
```

3. **Masking**: The computation of the mask happens immediately after the reduce operation. The `mask_fn()` is applied on the reduced tensor. Again, this happens per feature if features are specified.
The `mask_fn` should accept the reduced tensor and the sparse config as arguments, and return a mask computed from the tensor according to the config. Example:
```
def mask_fn(tensor, threshold):  # threshold is the sparse config here
    mask = torch.ones_like(tensor)
    mask[torch.abs(tensor) < threshold] = 0.0
    return mask
```

## API Design
`ActivationSparsifier`: Attaches itself to model layers and sparsifies the activations flowing through those layers. The user can pass in the default `aggregate_fn`, `reduce_fn` and `mask_fn`. Additionally, `features` and `feature_dim` are also accepted.

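For example, a per-feature setup could be constructed as below (argument values are illustrative; the functions are the ones defined above):

```
act_sparsifier = ActivationSparsifier(
    model,
    aggregate_fn=aggregate_fn,
    reduce_fn=reduce_fn,
    mask_fn=mask_fn,
    features=[0, 3, 7],  # compute and store one mask per feature
    feature_dim=1,       # dimension along which the features are sliced
)
```
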
`register_layer`: Registers a layer for sparsification. Specifically, registers a `forward_pre_hook()` that performs the aggregation.

`step`: For each registered layer, applies the `reduce_fn` on the aggregated activations and then applies the `mask_fn` on the reduced tensor.

`squash_mask`: Unregisters the aggregation hook that was applied earlier and, if `attach_sparsify_hook=True`, registers sparsification hooks. A sparsification hook applies the computed mask to the activations before they flow into the registered layer.

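Conceptually, a sparsification hook is again a forward pre-hook that multiplies the incoming activation by the stored mask. A minimal sketch, with illustrative names:

```
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # e.g. a mask computed during step()

def sparsify_hook(module, inputs):
    # mask the incoming activation before it reaches the layer's forward()
    return (inputs[0] * mask,)

layer.register_forward_pre_hook(sparsify_hook)
out = layer(torch.ones(2, 4))  # activations are masked on the way in
```
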
## Example

```
# Fetch model
model = SomeModel()

# define some aggregate, reduce and mask functions
def aggregate_fn(tensor1, tensor2):
    return tensor1 + tensor2

def reduce_fn(tensor):
    return tensor.mean(dim=0)

def mask_fn(tensor, threshold):
    mask = torch.ones_like(tensor)
    mask[torch.abs(tensor) < threshold] = 0.0
    return mask

# sparse config
default_sparse_config = {"threshold": 0.5}

# define activation sparsifier
act_sparsifier = ActivationSparsifier(model=model, aggregate_fn=aggregate_fn, reduce_fn=reduce_fn, mask_fn=mask_fn, **default_sparse_config)

# register a layer to sparsify its activations
act_sparsifier.register_layer(model.some_layer, threshold=0.8)  # custom sparse config

for epoch in range(EPOCHS):
    for input, target in dataset:
        ...
        out = model(input)
        ...
    act_sparsifier.step()  # mask is computed

act_sparsifier.squash_mask(attach_sparsify_hook=True)  # activations are multiplied with the computed mask before flowing through the layer
```