# mypy: allow-untyped-defs
import copy
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn


__all__ = ["ActivationSparsifier"]


class ActivationSparsifier:
    r"""
    The Activation sparsifier class aims to sparsify/prune activations in a neural
    network. The idea is to attach the sparsifier to a layer (or layers) so that it
    zeroes out the activations based on the mask_fn (or sparsification function)
    supplied by the user.
    The mask_fn is applied once all the inputs are aggregated and reduced i.e.
    mask = mask_fn(reduce_fn(aggregate_fn(activations)))

    Note::
        The sparsification mask is computed on the input **before it goes through the attached layer**.

    Args:
        model (nn.Module):
            The model whose layers will be sparsified. The layers that need to be
            sparsified should be added separately using the register_layer() function
        aggregate_fn (Optional, Callable):
            default aggregate_fn that is used if not specified while registering the layer.
            Specifies how inputs should be aggregated over time.
            The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor.
            Example
                def add_agg_fn(tensor1, tensor2):  return tensor1 + tensor2
        reduce_fn (Optional, Callable):
            default reduce_fn that is used if not specified while registering the layer.
            reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after
            calling agg_fn() on all inputs.
            Example
                def mean_reduce_fn(agg_tensor):    return agg_tensor.mean(dim=0)
        mask_fn (Optional, Callable):
            default mask_fn that is used to create the sparsification mask using the tensor obtained after
            calling the reduce_fn(). This is used by default if a custom one is not passed to
            register_layer().
            Note that the mask_fn() definition should accept the sparse arguments that are passed in the
            sparse_config argument.
        features (Optional, list):
            default selected features to sparsify.
            If this is non-empty, then the mask_fn will be applied for each feature of the input.
            For example,
                mask = [mask_fn(reduce_fn(aggregate_fn(input[feature]))) for feature in features]
        feature_dim (Optional, int):
            default dimension of input features. Again, features along this dim will be chosen
            for sparsification.
        sparse_config (Dict):
            Default configuration for the mask_fn. This config will be passed
            with the mask_fn()

    Example:
        >>> # xdoctest: +SKIP
        >>> model = SomeModel()
        >>> act_sparsifier = ActivationSparsifier(model)  # init activation sparsifier
        >>> # Initialize aggregate_fn
        >>> def agg_fn(x, y):
        >>>     return x + y
        >>>
        >>> # Initialize reduce_fn
        >>> def reduce_fn(x):
        >>>     return torch.mean(x, dim=0)
        >>>
        >>> # Initialize mask_fn
        >>> def mask_fn(data):
        >>>     return torch.eye(*data.shape).to(data.device)
        >>>
        >>>
        >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn)
        >>>
        >>> # start training process
        >>> for _ in [...]:
        >>>     # epoch starts
        >>>         # model.forward(), compute_loss() and loss.backward()
        >>>     # epoch ends
        >>>     act_sparsifier.step()
        >>> # end training process
        >>> act_sparsifier.squash_mask()
    """

    def __init__(
        self,
        model: nn.Module,
        aggregate_fn=None,
        reduce_fn=None,
        mask_fn=None,
        features=None,
        feature_dim=None,
        **sparse_config,
    ):
        self.model = model
        self.defaults: Dict[str, Any] = defaultdict()
        self.defaults["sparse_config"] = sparse_config

        # functions
        self.defaults["aggregate_fn"] = aggregate_fn
        self.defaults["reduce_fn"] = reduce_fn
        self.defaults["mask_fn"] = mask_fn

        # default feature and feature_dim
        self.defaults["features"] = features
        self.defaults["feature_dim"] = feature_dim

        self.data_groups: Dict[str, Dict] = defaultdict(
            dict
        )  # contains all relevant info w.r.t each registered layer

        self.state: Dict[str, Any] = defaultdict(dict)  # layer name -> mask

    @staticmethod
    def _safe_rail_checks(args):
        """Validates that the functions and attributes are passed in correctly"""

        # if features are not None, then feature_dim must not be None
        features, feature_dim = args["features"], args["feature_dim"]
        if features is not None:
            assert feature_dim is not None, "need feature dim to select features"

        # all the *_fns should be callable
        fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"]
        for key in fn_keys:
            fn = args[key]
            assert callable(fn), "function should be callable"

    def _aggregate_hook(self, name):
        """Returns hook that computes aggregate of activations passing through."""

        # gather some data
        feature_dim = self.data_groups[name]["feature_dim"]
        features = self.data_groups[name]["features"]
        agg_fn = self.data_groups[name]["aggregate_fn"]

        def hook(module, input) -> None:
            input_data = input[0]

            data = self.data_groups[name].get("data")  # aggregated data
            if features is None:
                # no features associated, data should not be a list
                if data is None:
                    data = torch.zeros_like(input_data)
                    self.state[name]["mask"] = torch.ones_like(input_data)
                out_data = agg_fn(data, input_data)
            else:
                # data should be a list [aggregated over each feature only]
                if data is None:
                    out_data = [
                        0 for _ in range(0, len(features))
                    ]  # create one in case of 1st forward
                    self.state[name]["mask"] = [0 for _ in range(0, len(features))]
                else:
                    out_data = data  # a list

                # compute aggregate over each feature
                for feature_idx in range(len(features)):
                    # each feature is either a list or scalar, convert it to torch tensor
                    feature_tensor = (
                        torch.Tensor([features[feature_idx]])
                        .long()
                        .to(input_data.device)
                    )
                    data_feature = torch.index_select(
                        input_data, feature_dim, feature_tensor
                    )
                    if data is None:
                        curr_data = torch.zeros_like(data_feature)
                        self.state[name]["mask"][feature_idx] = torch.ones_like(
                            data_feature
                        )
                    else:
                        curr_data = data[feature_idx]
                    out_data[feature_idx] = agg_fn(curr_data, data_feature)
            self.data_groups[name]["data"] = out_data

        return hook

    def register_layer(
        self,
        layer: nn.Module,
        aggregate_fn=None,
        reduce_fn=None,
        mask_fn=None,
        features=None,
        feature_dim=None,
        **sparse_config,
    ):
        r"""
        Registers a layer for sparsification. The layer should be part of self.model.
        Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn
        and store the aggregated activations that are input over each step.

        Note::
            - There is no need to pass in the name of the layer as it is automatically computed as per
              the fqn convention.

            - All the functions (fn) passed as arguments will be called at a dim, feature level.
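
        Example (an illustrative sketch; ``model.some_layer`` and the ``agg_fn``,
        ``reduce_fn`` and ``mask_fn`` callables are assumed to be defined as in the
        class-level example above)::

            >>> # xdoctest: +SKIP
            >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn)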
        """
        name = module_to_fqn(self.model, layer)
        assert name is not None, "layer not found in the model"  # satisfy mypy

        if name in self.data_groups:  # unregister layer if already present
            warnings.warn(
                "layer already attached to the sparsifier, deregistering the layer and registering with new config"
            )
            self.unregister_layer(name=name)

        local_args = copy.deepcopy(self.defaults)
        update_dict = {
            "aggregate_fn": aggregate_fn,
            "reduce_fn": reduce_fn,
            "mask_fn": mask_fn,
            "features": features,
            "feature_dim": feature_dim,
            "layer": layer,
        }
        local_args.update(
            (arg, val) for arg, val in update_dict.items() if val is not None
        )
        local_args["sparse_config"].update(sparse_config)

        self._safe_rail_checks(local_args)

        self.data_groups[name] = local_args
        agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name))

        self.state[name][
            "mask"
        ] = None  # mask will be created when model forward is called.

        # attach agg hook
        self.data_groups[name]["hook"] = agg_hook

        # for serialization purposes, we know whether aggregate_hook is attached
        # or sparsify_hook()
        self.data_groups[name]["hook_state"] = "aggregate"  # aggregate hook is attached

    def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None):
        """
        Returns the mask associated with the layer.

        The mask is
            - a torch tensor if features for that layer is None.
            - a list of torch tensors, one for each feature, otherwise.

        Note::
            The shape of the mask is unknown until model.forward() is applied.
            Hence, if get_mask() is called before model.forward(), an
            error will be raised.
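
        Example (an illustrative sketch; assumes a layer was registered whose fqn in the
        model is ``'some_layer'`` and that ``model.forward()`` has run at least once)::

            >>> # xdoctest: +SKIP
            >>> mask = act_sparsifier.get_mask(name='some_layer')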
        """
        assert (
            name is not None or layer is not None
        ), "Need at least name or layer obj to retrieve mask"

        if name is None:
            assert layer is not None
            name = module_to_fqn(self.model, layer)
            assert name is not None, "layer not found in the specified model"

        if name not in self.state:
            raise ValueError("Error: layer with the given name not found")

        mask = self.state[name].get("mask", None)

        if mask is None:
            raise ValueError(
                "Error: shape unknown, call layer() routine at least once to infer mask"
            )
        return mask

    def unregister_layer(self, name):
        """Detaches the sparsifier from the layer"""

        # detach any hooks attached
        self.data_groups[name]["hook"].remove()

        # pop from the state dict
        self.state.pop(name)

        # pop from the data groups
        self.data_groups.pop(name)

    def step(self):
        """Internally calls the update_mask() function for each layer"""
        with torch.no_grad():
            for name, configs in self.data_groups.items():
                data = configs["data"]
                self.update_mask(name, data, configs)

                self.data_groups[name].pop("data")  # reset the accumulated data

    def update_mask(self, name, data, configs):
        """
        Called for each registered layer and does the following:
            1. apply reduce_fn on the aggregated activations
            2. use mask_fn to compute the sparsification mask

        Note:
            the reduce_fn and mask_fn are called for each feature, dim over the data
        """
        mask = self.get_mask(name)
        sparse_config = configs["sparse_config"]
        features = configs["features"]
        reduce_fn = configs["reduce_fn"]
        mask_fn = configs["mask_fn"]
        if features is None:
            data = reduce_fn(data)
            mask.data = mask_fn(data, **sparse_config)
        else:
            for feature_idx in range(len(features)):
                data_feature = reduce_fn(data[feature_idx])
                mask[feature_idx].data = mask_fn(data_feature, **sparse_config)

    def _sparsify_hook(self, name):
        """Returns hook that applies sparsification mask to input entering the attached layer"""
        mask = self.get_mask(name)
        features = self.data_groups[name]["features"]
        feature_dim = self.data_groups[name]["feature_dim"]

        def hook(module, input):
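            # NOTE: a tensor returned from a forward pre-hook replaces the input
            # that the layer receives, so the masked data below is what actually
            # flows into the attached layer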
            input_data = input[0]
            if features is None:
                # apply to all the features
                return input_data * mask
            else:
                # apply per feature, feature_dim
                for feature_idx in range(0, len(features)):
                    feature = (
                        torch.Tensor([features[feature_idx]])
                        .long()
                        .to(input_data.device)
                    )
                    sparsified = (
                        torch.index_select(input_data, feature_dim, feature)
                        * mask[feature_idx]
                    )
                    input_data.index_copy_(feature_dim, feature, sparsified)
                return input_data

        return hook

    def squash_mask(self, attach_sparsify_hook=True, **kwargs):
        """
        Unregisters the aggregate hook that was applied earlier and registers sparsification hooks if
        attach_sparsify_hook = True.
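
        Example (an illustrative sketch; typically called once the training loop shown
        in the class-level example has finished)::

            >>> # xdoctest: +SKIP
            >>> act_sparsifier.squash_mask()  # replaces aggregate hooks with sparsify hooks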
        """
        for name, configs in self.data_groups.items():
            # unhook agg hook
            configs["hook"].remove()
            configs.pop("hook")
            self.data_groups[name]["hook_state"] = "None"
            if attach_sparsify_hook:
                configs["hook"] = configs["layer"].register_forward_pre_hook(
                    self._sparsify_hook(name)
                )
            configs[
                "hook_state"
            ] = "sparsify"  # signals that sparsify hook is now attached

    def _get_serializable_data_groups(self):
        """Exclude hook and layer from the config keys before serializing

        TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing.
              For the time being, functions are treated the same way as other attributes
        """
        data_groups: Dict[str, Any] = defaultdict()
        for name, config in self.data_groups.items():
            new_config = {
                key: value
                for key, value in config.items()
                if key not in ["hook", "layer"]
            }
            data_groups[name] = new_config
        return data_groups

    def _convert_mask(self, states_dict, sparse_coo=True):
        r"""Converts the masks to sparse COO or dense tensors depending on the `sparse_coo` argument.
        If `sparse_coo=True`, the masks are stored as sparse COO tensors, else as dense tensors.
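
        Example (an illustrative sketch; converts the current masks to dense form)::

            >>> # xdoctest: +SKIP
            >>> dense_states = act_sparsifier._convert_mask(act_sparsifier.state, sparse_coo=False)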
        """
        states = copy.deepcopy(states_dict)
        for state in states.values():
            if state["mask"] is not None:
                if isinstance(state["mask"], List):
                    for idx in range(len(state["mask"])):
                        if sparse_coo:
                            state["mask"][idx] = state["mask"][idx].to_sparse_coo()
                        else:
                            state["mask"][idx] = state["mask"][idx].to_dense()
                else:
                    if sparse_coo:
                        state["mask"] = state["mask"].to_sparse_coo()
                    else:
                        state["mask"] = state["mask"].to_dense()
        return states

    def state_dict(self) -> Dict[str, Any]:
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - contains name -> mask mapping.
        * data_groups - a dictionary containing all config information for each
            layer
        * defaults - the default config supplied to the constructor
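
        Example (an illustrative sketch; the file name ``act_sparsifier.pt`` is hypothetical)::

            >>> # xdoctest: +SKIP
            >>> torch.save(act_sparsifier.state_dict(), 'act_sparsifier.pt')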
        """
        data_groups = self._get_serializable_data_groups()
        state = self._convert_mask(self.state)
        return {"state": state, "data_groups": data_groups, "defaults": self.defaults}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""The load_state_dict() restores the state of the sparsifier based on the state_dict

        Args:
        * state_dict - the dictionary from which the state of the sparsifier will be restored
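
        Example (an illustrative sketch; continues the hypothetical ``act_sparsifier.pt``
        round-trip from the state_dict() example)::

            >>> # xdoctest: +SKIP
            >>> act_sparsifier.load_state_dict(torch.load('act_sparsifier.pt'))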
        """
        state = state_dict["state"]
        data_groups, defaults = state_dict["data_groups"], state_dict["defaults"]

        self.__set_state__(
            {"state": state, "data_groups": data_groups, "defaults": defaults}
        )

    def __get_state__(self) -> Dict[str, Any]:
        data_groups = self._get_serializable_data_groups()
        state = self._convert_mask(self.state)
        return {
            "defaults": self.defaults,
            "state": state,
            "data_groups": data_groups,
        }

    def __set_state__(self, state: Dict[str, Any]) -> None:
        state["state"] = self._convert_mask(
            state["state"], sparse_coo=False
        )  # convert mask to dense tensor
        self.__dict__.update(state)

        # need to attach layer and hook info into the data_groups
        for name, config in self.data_groups.items():
            # fetch layer
            layer = fqn_to_module(self.model, name)
            assert layer is not None  # satisfy mypy

            # if hook_state is "aggregate", the layer is in aggregate mode
            if "hook_state" in config and config["hook_state"] == "aggregate":
                hook = layer.register_forward_pre_hook(self._aggregate_hook(name))

            elif "hook_state" in config and config["hook_state"] == "sparsify":
                hook = layer.register_forward_pre_hook(self._sparsify_hook(name))

            config["layer"] = layer
            config["hook"] = hook  # type: ignore[possibly-undefined]

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for name, config in self.data_groups.items():
            format_string += "\n"
            format_string += "\tData Group\n"
            format_string += f"\t    name: {name}\n"
            for key in sorted(config.keys()):
                if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]:
                    continue
                format_string += f"\t    {key}: {config[key]}\n"
        format_string += ")"
        return format_string