# Structured Pruning

## Intro / Motivation

**Pruning** is the technique of removing parameters from a model in order to reduce its computational cost. The goal of pruning is to speed the model up while maintaining its accuracy.

### Unstructured vs. Structured Pruning
One way to do this is to consider each parameter individually. This gives us the greatest granularity when pruning and is called **unstructured pruning**.

For example, consider a simple linear regression model that is parametrized by a weight tensor W.

```
W = [[1 2 3]
     [4 5 6]
     [7 1 9]]
```

We can prune the lowest absolute value elements in W in order to preserve as much information as possible.
Below we've removed three parameters from W.

```
W_pruned = [[0 0 3]
            [4 5 6]
            [7 0 9]]
```
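
A minimal sketch of how this lowest-magnitude selection could be computed in PyTorch (illustrative only; the actual unstructured pruning flow uses parametrizations):

```python
import torch

W = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 1., 9.]])

# Zero out the 3 entries with the smallest absolute values
_, idx = W.abs().flatten().topk(3, largest=False)
W_pruned = W.clone()
W_pruned.view(-1)[idx] = 0  # recovers the W_pruned shown above
```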

Unfortunately, zeroing out parameters does not offer a speed-up to the model out of the box. We need custom sparse kernels that are designed to take advantage of sparsity to speed up computation. For more information about unstructured pruning check out our tutorials [here]().

However, if we zero out a row of parameters at a time instead of a single parameter, we can speed up computation by resizing the weight matrix. This is called **structured pruning** and is what this folder implements.

```
W_pruned = [[0 0 0]  =  [[4 5 6]
            [4 5 6]      [7 1 9]]
            [7 1 9]]
```
### Weight Resizing

However, since the pruned weight tensor has a different shape than the original weight tensor, subsequent operations will cause an error due to this shape mismatch. We need to remove both the rows of the original weight tensor and the columns of subsequent tensors that correspond to the pruned rows.

You can see an example of this below for a model containing two linear layers, one parametrized by W and another by U.

By removing a row from W and the corresponding column from U, we can avoid a shape mismatch.

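To make this concrete, here is a minimal PyTorch sketch (with made-up shapes) of the resizing:

```python
import torch

# Toy two-layer model: y = U @ relu(W @ x)
W = torch.randn(3, 4)  # first layer weight, shape (3, 4)
U = torch.randn(2, 3)  # second layer weight, shape (2, 3)
x = torch.randn(4)

# Suppose we prune row 0 of W. We must also remove column 0 of U,
# since that column consumed the activation produced by the pruned row.
W_pruned = W[1:, :]  # shape (2, 4)
U_pruned = U[:, 1:]  # shape (2, 2)

y = U_pruned @ torch.relu(W_pruned @ x)  # shapes line up, no error
```
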
One benefit of **structured pruning** is that it uses the same dense kernels that the original model uses, and does not rely on custom sparse kernels like **unstructured pruning** does.
However, structured pruning degrades accuracy more than unstructured pruning because of the lack of granularity, so it is not always the right choice.

Generally the structured pruning process looks something like this:
1. Define which layers in the model you want to structured prune.
2. Evaluate the importance of each row in each layer of the model.
3. Remove rows by resizing the weight matrices of each layer.
4. Stop if the target sparsity level is met.

The accuracy degradation of pruning can be quite large initially. Once we are satisfied with our pruned model, we usually retrain it in order to recover some of this lost accuracy.

## Quickstart Guide

**Your model must be FX symbolically traceable**.

You can test this with the following bit of code:

```python
from torch.fx import symbolic_trace
model = MyModel()
symbolic_trace(model)
```

Using `torch.fx` we can get a compute graph of our model. Each operation (add, multiply, ReLU) is a node in the graph, and the order of operations is defined by the edges of the graph.

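For instance, you can print the traced graph in tabular form (a quick sketch; `MyModel` here stands in for your own module):

```python
import torch
from torch import nn
from torch.fx import symbolic_trace

class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(8, 4)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))

traced = symbolic_trace(MyModel())
# Each row is a node: placeholder (input), call_module / call_function (ops), output
traced.graph.print_tabular()
```
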
Structured pruning works by traversing this graph and looking for specific **patterns**, which are just sequences of operations.

Each pattern is tied to a pruning function, which is responsible for structured pruning the graph nodes that match the pattern.

The above [example](#weight-resizing) of two linear layers would match against a `(nn.Linear, nn.Linear)` pattern. This is how we identify the rows to remove and the corresponding columns of the subsequent layer.

Structured pruning also works on patterns other than two adjacent linear layers:

- linear -> linear
- linear -> activation -> linear
- conv2d -> conv2d
- conv2d -> activation -> conv2d
- conv2d -> activation -> pool -> conv2d
- conv2d -> pool -> activation -> conv2d
- conv2d -> adaptive pool -> flatten -> linear

A complete set of the patterns we support can be found [here](https://github.com/pytorch/pytorch/blob/master/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py#L85).

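You can also inspect this mapping programmatically; a small sketch using the internal helper `_get_default_structured_pruning_patterns` (the same helper used in the custom-patterns example below):

```python
from torch.ao.pruning._experimental.pruner.base_structured_sparsifier import (
    _get_default_structured_pruning_patterns,
)

# Print each supported pattern and the pruning function it maps to
for pattern, fn in _get_default_structured_pruning_patterns().items():
    print(pattern, "->", getattr(fn, "__name__", fn))
```
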
If you are looking to prune a currently unsupported pattern, you can do so by modifying the pattern dict that we provide to the pruner; see [here](#writing-custom-patterns-and-pruning-functions-for-structured-pruning). Feel free to open a PR to add in new patterns.

Here is an example script that will prune away 50% of the rows for all the linear layers in the model, based on the saliency of each row.
```python
from torch import nn
from torch.ao.pruning._experimental.pruner import SaliencyPruner

# Define model
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(700, 500, bias=True),
            nn.ReLU(),
            nn.Linear(500, 800, bias=False),
            nn.ReLU(),
            nn.Linear(800, 600, bias=True),
            nn.ReLU(),
        )
        self.linear = nn.Linear(600, 4, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x

# Define pruning_config, which specifies which tensors you wish to prune.
# The SaliencyPruner also needs a sparsity_level parameter to specify what % of rows to prune.
pruning_config = [
    {"tensor_fqn": "seq.0.weight", "sparsity_level": 0.5},
    {"tensor_fqn": "seq.2.weight", "sparsity_level": 0.5},
    {"tensor_fqn": "seq.4.weight", "sparsity_level": 0.5},
    {"tensor_fqn": "linear.weight", "sparsity_level": 0.5},
]

original = Model()
# Define defaults: any configs passed in here are propagated to each
# entry of the pruning config.
# For structured pruning, we also prune biases by default.
defaults = {"prune_bias": True}
# Your selection criterion is decided by which pruner you use.
pruner = SaliencyPruner(defaults)

# Next we call `prepare`, which will attach `FakeStructuredSparsity` parametrizations
# to the tensors specified in the config. These parametrizations will zero out
# the appropriate weights in order to make the model behave as if it has been pruned.
pruner.prepare(original, pruning_config)

# Take one pruning step. This will update the masks.
pruner.enable_mask_update = True
pruner.step()

# pruner.prune() will find patterns and apply each pattern's pruning function to its matching nodes.
# The output of pruner.prune() is a model with resized weights and the masks / parametrizations removed.
pruned_model = pruner.prune()
```
Afterwards, by printing the name and size of each parameter in our model, we can see that it has been pruned.

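One way to produce such a summary (a minimal sketch, not part of the library):

```python
def summarize(model):
    total = 0
    for name, param in model.named_parameters():
        total += param.numel()
        print(f"{name:<20}| {str(tuple(param.shape)):<16}| {param.numel()}")
    print(f"=== Total Number of Parameters: {total} ===")

summarize(pruned_model)
```
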
```
# original model
Parameter name      | Shape           | # of elements
--------------------|-----------------|---------------
seq.0.weight        | 500, 700        | 350000
seq.0.bias          | 500             | 500
seq.2.weight        | 800, 500        | 400000
seq.4.weight        | 600, 800        | 480000
seq.4.bias          | 600             | 600
linear.weight       | 4, 600          | 2400
=== Total Number of Parameters: 1233500 ===
```
```
# pruned model
Parameter name      | Shape           | # of elements
--------------------|-----------------|---------------
seq.0.weight        | 250, 700        | 175000
seq.0.bias          | 250             | 250
seq.2.weight        | 400, 250        | 100000
seq.4.weight        | 300, 400        | 120000
seq.4.bias          | 300             | 300
linear.weight       | 2, 300          | 600
=== Total Number of Parameters: 396150 ===
```

Although we only pruned 50% of the rows, the total parameter count is far less than half of the original model's (396150 / 1233500 ≈ 32%).

This is because we remove both the rows of a weight tensor and the columns of the subsequent tensor, so each intermediate weight matrix shrinks to roughly (1 - 0.5) * (1 - 0.5) = 0.25 of its original size. The first layer's input columns are not pruned, which is why the overall ratio is a bit higher than 25%.


## Advanced Tutorial

### Pruning Config

To specify the layers to prune, we just need the fully qualified name (FQN) of the tensor we are looking to prune within the module.
You can get the FQN of a tensor by printing out `model.named_parameters()`.

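For example (a quick sketch, reusing the `Model` class from the Quickstart):

```python
model = Model()
for fqn, param in model.named_parameters():
    print(fqn, tuple(param.shape))
# seq.0.weight (500, 700)
# seq.0.bias (500,)
# ...
```
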
To prune multiple layers, we just append entries to the pruning config.
**tensor_fqn** is the only required key in the pruning config. You can pass additional information in the config, for example the sparsity level you want to prune to, by adding extra keys. You can then access this additional information when you update the masks.

### Implementing a Pruner

If you want to prune weights using a different pruning criterion than saliency, you'll need to implement your own pruner.

To do this, we need to extend a `BaseStructuredSparsifier` with a custom `update_mask` function.

This `update_mask` function contains the user logic for picking what weights to prune.

One common pruning criterion is to use the **saliency** of a row, which is defined as the sum of the absolute values (the L1 norm) of the weights in the row. For example, the saliency of the row [1, -2, 3] is |1| + |-2| + |3| = 6.
The idea is to remove the weights that are small, since they wouldn't contribute much to the final prediction.

Below we can see an implemented `SaliencyPruner`:

```python
from torch.ao.pruning._experimental.pruner import BaseStructuredSparsifier

class SaliencyPruner(BaseStructuredSparsifier):
    """
    Prune filters based on the saliency.
    The saliency for a filter is given by the sum of the absolute values of its weights.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        # tensor_name will give you the FQN; all other keys in the pruning config are present in kwargs
        weights = getattr(module, tensor_name)
        mask = getattr(module.parametrizations, tensor_name)[0].mask

        # use negative weights so we can use topk (we prune out the smallest)
        saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
        num_to_pick = int(len(mask) * kwargs["sparsity_level"])
        prune = saliency.topk(num_to_pick).indices

        # Set the mask to be false for the rows we want to prune
        mask.data[prune] = False
```
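
The custom pruner is then used exactly like the built-in one (a sketch, reusing `pruning_config` from the Quickstart):

```python
model = Model()
pruner = SaliencyPruner({"prune_bias": True})
pruner.prepare(model, pruning_config)
pruner.enable_mask_update = True
pruner.step()
pruned_model = pruner.prune()
```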

### Writing Custom Patterns and Pruning Functions for Structured Pruning
If you're working with linear/conv2d layers, it's very probable that you just need to add an entry to the pattern dict mapping your pattern to an existing pruning function.

This is because there are many modules, for example **pooling**, that behave the same way and do not need to be modified by the pruning code.

```python
from typing import Callable, Optional

from torch import Tensor, nn
from torch.ao.pruning._experimental.pruner import SaliencyPruner
from torch.ao.pruning._experimental.pruner.base_structured_sparsifier import (
    _get_default_structured_pruning_patterns,
)
from torch.ao.pruning._experimental.pruner.prune_functions import prune_conv2d_activation_conv2d

# The pool module does not need to be modified, so we delegate to an existing pruning function.
def prune_conv2d_pool_activation_conv2d(
    c1: nn.Conv2d,
    pool: nn.Module,
    activation: Optional[Callable[[Tensor], Tensor]],
    c2: nn.Conv2d,
) -> None:
    prune_conv2d_activation_conv2d(c1, activation, c2)

# note how the pattern defined in the key will be passed to the pruning function as args
my_patterns = {(nn.Conv2d, nn.MaxPool2d, nn.ReLU, nn.Conv2d): prune_conv2d_pool_activation_conv2d}

pruning_patterns = _get_default_structured_pruning_patterns()
pruning_patterns.update(my_patterns)

pruner = SaliencyPruner({}, patterns=pruning_patterns)
```
However, there are also modules like batch norm, which will not work properly unless they are pruned as well. In this instance, you would need to write a custom pruning function to handle that logic properly.

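To make the idea concrete, here is a hypothetical, standalone sketch (not the library's implementation) of what resizing a `BatchNorm2d` to match a pruned conv's surviving channels involves:

```python
import torch
from torch import nn

def resize_batchnorm(bn: nn.BatchNorm2d, keep: torch.Tensor) -> nn.BatchNorm2d:
    """Return a BatchNorm2d covering only the kept channel indices."""
    new_bn = nn.BatchNorm2d(len(keep))
    new_bn.weight = nn.Parameter(bn.weight[keep].detach().clone())
    new_bn.bias = nn.Parameter(bn.bias[keep].detach().clone())
    new_bn.running_mean = bn.running_mean[keep].clone()
    new_bn.running_var = bn.running_var[keep].clone()
    return new_bn

# e.g. a conv pruned from 4 output channels down to channels 0, 2, and 3:
bn = resize_batchnorm(nn.BatchNorm2d(4), torch.tensor([0, 2, 3]))
print(bn)  # BatchNorm2d(3, ...)
```
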
You can see the implemented pruning functions [here](https://github.com/pytorch/pytorch/blob/master/torch/ao/pruning/_experimental/pruner/prune_functions.py) for examples. Please feel free to open a PR to help us build out a complete set of patterns and pruning functions.