# mypy: allow-untyped-defs
import copy
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn


__all__ = ["ActivationSparsifier"]


class ActivationSparsifier:
    r"""
    The ActivationSparsifier class aims to sparsify/prune activations in a neural
    network. The idea is to attach the sparsifier to a layer (or layers), and it
    zeroes out the activations based on the mask_fn (or sparsification function)
    provided by the user.
    The mask_fn is applied once all the inputs are aggregated and reduced, i.e.

        mask = mask_fn(reduce_fn(aggregate_fn(activations)))

    Note::
        The sparsification mask is computed on the input **before it goes through the attached layer**.

    Args:
        model (nn.Module):
            The model whose layers will be sparsified. The layers that need to be
            sparsified should be added separately using the register_layer() function.
        aggregate_fn (Optional, Callable):
            default aggregate_fn that is used if not specified while registering the layer.
            It specifies how inputs should be aggregated over time.
            The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor.
            Example
                def add_agg_fn(tensor1, tensor2):
                    return tensor1 + tensor2
        reduce_fn (Optional, Callable):
            default reduce_fn that is used if not specified while registering the layer.
            reduce_fn will be called on the aggregated tensor, i.e. the tensor obtained after
            calling agg_fn() on all inputs.
            Example
                def mean_reduce_fn(agg_tensor):
                    return agg_tensor.mean(dim=0)
        mask_fn (Optional, Callable):
            default mask_fn that is used to create the sparsification mask using the tensor obtained after
            calling the reduce_fn(). This is used by default if a custom one is not passed in
            register_layer().
            Note that the mask_fn() definition should accept the sparse arguments that are passed in
            via the sparse_config arguments.
        features (Optional, list):
            default selected features to sparsify.
            If this is non-empty, then the mask_fn will be applied for each feature of the input.
            For example,
                mask = [mask_fn(reduce_fn(aggregate_fn(input[feature]))) for feature in features]
        feature_dim (Optional, int):
            default dimension of input features. Again, features along this dim will be chosen
            for sparsification.
        sparse_config (Dict):
            Default configuration for the mask_fn. This config will be passed
            to the mask_fn().

    Example:
        >>> # xdoctest: +SKIP
        >>> model = SomeModel()
        >>> act_sparsifier = ActivationSparsifier(...)  # init activation sparsifier
        >>> # Initialize aggregate_fn
        >>> def agg_fn(x, y):
        >>>     return x + y
        >>>
        >>> # Initialize reduce_fn
        >>> def reduce_fn(x):
        >>>     return torch.mean(x, dim=0)
        >>>
        >>> # Initialize mask_fn
        >>> def mask_fn(data):
        >>>     return (torch.rand_like(data) > 0.5).to(data.dtype)
        >>>
        >>>
        >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn)
        >>>
        >>> # start training process
        >>> for _ in [...]:
        >>>     # epoch starts
        >>>     # model.forward(), compute_loss() and loss.backward()
        >>>     # epoch ends
        >>>     act_sparsifier.step()
        >>> # end training process
        >>> act_sparsifier.squash_mask()
    """

    def __init__(
        self,
        model: nn.Module,
        aggregate_fn=None,
        reduce_fn=None,
        mask_fn=None,
        features=None,
        feature_dim=None,
        **sparse_config,
    ):
        self.model = model
        self.defaults: Dict[str, Any] = defaultdict()
        self.defaults["sparse_config"] = sparse_config

        # functions
        self.defaults["aggregate_fn"] = aggregate_fn
        self.defaults["reduce_fn"] = reduce_fn
        self.defaults["mask_fn"] = mask_fn

        # default feature and feature_dim
        self.defaults["features"] = features
        self.defaults["feature_dim"] = feature_dim

        # contains all relevant info w.r.t. each registered layer
        self.data_groups: Dict[str, Dict] = defaultdict(dict)

        self.state: Dict[str, Any] = defaultdict(dict)  # layer name -> mask
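
    # A construction sketch with constructor-level defaults (the fns here are
    # illustrative; ``sparsity_level`` is an assumed sparse_config key that the
    # supplied mask_fn must accept as a keyword argument):
    #
    #   sparsifier = ActivationSparsifier(
    #       model,
    #       aggregate_fn=lambda acc, x: acc + x,
    #       reduce_fn=lambda agg: agg.mean(dim=0),
    #       mask_fn=lambda data, sparsity_level: (
    #           data.abs() > data.abs().flatten().quantile(sparsity_level)
    #       ).to(data.dtype),
    #       sparsity_level=0.5,
    #   )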

    @staticmethod
    def _safe_rail_checks(args):
        """Makes sure that some of the functions and attributes are not passed incorrectly"""

        # if features are not None, then feature_dim must not be None
        features, feature_dim = args["features"], args["feature_dim"]
        if features is not None:
            assert feature_dim is not None, "need feature dim to select features"

        # all the *_fns should be callable
        fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"]
        for key in fn_keys:
            fn = args[key]
            assert callable(fn), "function should be callable"

    def _aggregate_hook(self, name):
        """Returns hook that computes aggregate of activations passing through."""

        # gather some data
        feature_dim = self.data_groups[name]["feature_dim"]
        features = self.data_groups[name]["features"]
        agg_fn = self.data_groups[name]["aggregate_fn"]

        def hook(module, input) -> None:
            input_data = input[0]

            data = self.data_groups[name].get("data")  # aggregated data
            if features is None:
                # no features associated, data should not be a list
                if data is None:
                    data = torch.zeros_like(input_data)
                    self.state[name]["mask"] = torch.ones_like(input_data)
                out_data = agg_fn(data, input_data)
            else:
                # data should be a list [aggregated over each feature only]
                if data is None:
                    # create one in case of the 1st forward
                    out_data = [0 for _ in range(0, len(features))]
                    self.state[name]["mask"] = [0 for _ in range(0, len(features))]
                else:
                    out_data = data  # a list

                # compute aggregate over each feature
                for feature_idx in range(len(features)):
                    # each feature is either a list or scalar, convert it to torch tensor
                    feature_tensor = (
                        torch.Tensor([features[feature_idx]])
                        .long()
                        .to(input_data.device)
                    )
                    data_feature = torch.index_select(
                        input_data, feature_dim, feature_tensor
                    )
                    if data is None:
                        curr_data = torch.zeros_like(data_feature)
                        self.state[name]["mask"][feature_idx] = torch.ones_like(
                            data_feature
                        )
                    else:
                        curr_data = data[feature_idx]
                    out_data[feature_idx] = agg_fn(curr_data, data_feature)
            self.data_groups[name]["data"] = out_data

        return hook
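
    # After the first forward pass, ``data_groups[name]["data"]`` holds either a
    # single aggregated tensor (when features is None) or a list with one
    # aggregated tensor per feature. A shape sketch (hypothetical sizes):
    #
    #   x = torch.randn(8, 16)  # input entering the hooked layer
    #   # features=None                  -> data: Tensor of shape (8, 16)
    #   # features=[0, 3], feature_dim=1 -> data: [Tensor(8, 1), Tensor(8, 1)]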

    def register_layer(
        self,
        layer: nn.Module,
        aggregate_fn=None,
        reduce_fn=None,
        mask_fn=None,
        features=None,
        feature_dim=None,
        **sparse_config,
    ):
        r"""
        Registers a layer for sparsification. The layer should be part of self.model.
        Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn
        and store the aggregated activations that are input over each step.

        Note::
            - There is no need to pass in the name of the layer as it is automatically computed
              as per the fqn convention.

            - All the functions (fn) passed as arguments will be called at a feature, dim level.
        """
        name = module_to_fqn(self.model, layer)
        assert name is not None, "layer not found in the model"  # satisfy mypy

        if name in self.data_groups:  # unregister layer if already present
            warnings.warn(
                "layer already attached to the sparsifier, deregistering the layer and registering with new config"
            )
            self.unregister_layer(name=name)

        local_args = copy.deepcopy(self.defaults)
        update_dict = {
            "aggregate_fn": aggregate_fn,
            "reduce_fn": reduce_fn,
            "mask_fn": mask_fn,
            "features": features,
            "feature_dim": feature_dim,
            "layer": layer,
        }
        local_args.update(
            (arg, val) for arg, val in update_dict.items() if val is not None
        )
        local_args["sparse_config"].update(sparse_config)

        self._safe_rail_checks(local_args)

        self.data_groups[name] = local_args
        agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name))

        # mask will be created when model forward is called
        self.state[name]["mask"] = None

        # attach agg hook
        self.data_groups[name]["hook"] = agg_hook

        # for serialization purposes, we record whether the aggregate hook or
        # the sparsify hook is attached
        self.data_groups[name]["hook_state"] = "aggregate"  # aggregate hook is attached
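
    # A registration sketch: per-call arguments override the constructor
    # defaults for this layer only, and extra kwargs extend its sparse_config
    # (``model.linear1`` and ``my_mask_fn`` are hypothetical):
    #
    #   sparsifier.register_layer(
    #       model.linear1,
    #       mask_fn=my_mask_fn,
    #       features=[0, 1, 2],
    #       feature_dim=1,
    #       sparsity_level=0.7,
    #   )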
257 """ 258 assert ( 259 name is not None or layer is not None 260 ), "Need at least name or layer obj to retrieve mask" 261 262 if name is None: 263 assert layer is not None 264 name = module_to_fqn(self.model, layer) 265 assert name is not None, "layer not found in the specified model" 266 267 if name not in self.state: 268 raise ValueError("Error: layer with the given name not found") 269 270 mask = self.state[name].get("mask", None) 271 272 if mask is None: 273 raise ValueError( 274 "Error: shape unknown, call layer() routine at least once to infer mask" 275 ) 276 return mask 277 278 def unregister_layer(self, name): 279 """Detaches the sparsifier from the layer""" 280 281 # detach any hooks attached 282 self.data_groups[name]["hook"].remove() 283 284 # pop from the state dict 285 self.state.pop(name) 286 287 # pop from the data groups 288 self.data_groups.pop(name) 289 290 def step(self): 291 """Internally calls the update_mask() function for each layer""" 292 with torch.no_grad(): 293 for name, configs in self.data_groups.items(): 294 data = configs["data"] 295 self.update_mask(name, data, configs) 296 297 self.data_groups[name].pop("data") # reset the accumulated data 298 299 def update_mask(self, name, data, configs): 300 """ 301 Called for each registered layer and does the following- 302 1. apply reduce_fn on the aggregated activations 303 2. use mask_fn to compute the sparsification mask 304 305 Note: 306 the reduce_fn and mask_fn is called for each feature, dim over the data 307 """ 308 mask = self.get_mask(name) 309 sparse_config = configs["sparse_config"] 310 features = configs["features"] 311 reduce_fn = configs["reduce_fn"] 312 mask_fn = configs["mask_fn"] 313 if features is None: 314 data = reduce_fn(data) 315 mask.data = mask_fn(data, **sparse_config) 316 else: 317 for feature_idx in range(len(features)): 318 data_feature = reduce_fn(data[feature_idx]) 319 mask[feature_idx].data = mask_fn(data_feature, **sparse_config) 320 321 def _sparsify_hook(self, name): 322 """Returns hook that applies sparsification mask to input entering the attached layer""" 323 mask = self.get_mask(name) 324 features = self.data_groups[name]["features"] 325 feature_dim = self.data_groups[name]["feature_dim"] 326 327 def hook(module, input): 328 input_data = input[0] 329 if features is None: 330 # apply to all the features 331 return input_data * mask 332 else: 333 # apply per feature, feature_dim 334 for feature_idx in range(0, len(features)): 335 feature = ( 336 torch.Tensor([features[feature_idx]]) 337 .long() 338 .to(input_data.device) 339 ) 340 sparsified = ( 341 torch.index_select(input_data, feature_dim, feature) 342 * mask[feature_idx] 343 ) 344 input_data.index_copy_(feature_dim, feature, sparsified) 345 return input_data 346 347 return hook 348 349 def squash_mask(self, attach_sparsify_hook=True, **kwargs): 350 """ 351 Unregisters aggregate hook that was applied earlier and registers sparsification hooks if 352 attach_sparsify_hook = True. 
353 """ 354 for name, configs in self.data_groups.items(): 355 # unhook agg hook 356 configs["hook"].remove() 357 configs.pop("hook") 358 self.data_groups[name]["hook_state"] = "None" 359 if attach_sparsify_hook: 360 configs["hook"] = configs["layer"].register_forward_pre_hook( 361 self._sparsify_hook(name) 362 ) 363 configs[ 364 "hook_state" 365 ] = "sparsify" # signals that sparsify hook is now attached 366 367 def _get_serializable_data_groups(self): 368 """Exclude hook and layer from the config keys before serializing 369 370 TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. 371 For time-being, functions are treated the same way as other attributes 372 """ 373 data_groups: Dict[str, Any] = defaultdict() 374 for name, config in self.data_groups.items(): 375 new_config = { 376 key: value 377 for key, value in config.items() 378 if key not in ["hook", "layer"] 379 } 380 data_groups[name] = new_config 381 return data_groups 382 383 def _convert_mask(self, states_dict, sparse_coo=True): 384 r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. 385 If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor 386 """ 387 states = copy.deepcopy(states_dict) 388 for state in states.values(): 389 if state["mask"] is not None: 390 if isinstance(state["mask"], List): 391 for idx in range(len(state["mask"])): 392 if sparse_coo: 393 state["mask"][idx] = state["mask"][idx].to_sparse_coo() 394 else: 395 state["mask"][idx] = state["mask"][idx].to_dense() 396 else: 397 if sparse_coo: 398 state["mask"] = state["mask"].to_sparse_coo() 399 else: 400 state["mask"] = state["mask"].to_dense() 401 return states 402 403 def state_dict(self) -> Dict[str, Any]: 404 r"""Returns the state of the sparsifier as a :class:`dict`. 405 406 It contains: 407 * state - contains name -> mask mapping. 

    def state_dict(self) -> Dict[str, Any]:
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
            * state - contains the name -> mask mapping.
            * data_groups - a dictionary containing all config information for each
              layer
            * defaults - the default config passed to the constructor
        """
        data_groups = self._get_serializable_data_groups()
        state = self._convert_mask(self.state)
        return {"state": state, "data_groups": data_groups, "defaults": self.defaults}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""Restores the state of the sparsifier based on the state_dict

        Args:
            * state_dict - the dictionary to which the state of the current
              sparsifier should be restored
        """
        state = state_dict["state"]
        data_groups, defaults = state_dict["data_groups"], state_dict["defaults"]

        self.__set_state__(
            {"state": state, "data_groups": data_groups, "defaults": defaults}
        )

    def __get_state__(self) -> Dict[str, Any]:
        data_groups = self._get_serializable_data_groups()
        state = self._convert_mask(self.state)
        return {
            "defaults": self.defaults,
            "state": state,
            "data_groups": data_groups,
        }

    def __set_state__(self, state: Dict[str, Any]) -> None:
        state["state"] = self._convert_mask(
            state["state"], sparse_coo=False
        )  # convert mask to dense tensor
        self.__dict__.update(state)

        # need to attach layer and hook info into the data_groups
        for name, config in self.data_groups.items():
            # fetch layer
            layer = fqn_to_module(self.model, name)
            assert layer is not None  # satisfy mypy

            # if hook_state is "aggregate", then the layer is in aggregate mode
            if "hook_state" in config and config["hook_state"] == "aggregate":
                hook = layer.register_forward_pre_hook(self._aggregate_hook(name))

            elif "hook_state" in config and config["hook_state"] == "sparsify":
                hook = layer.register_forward_pre_hook(self._sparsify_hook(name))

            config["layer"] = layer
            config["hook"] = hook  # type: ignore[possibly-undefined]

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for name, config in self.data_groups.items():
            format_string += "\n"
            format_string += "\tData Group\n"
            format_string += f"\t    name: {name}\n"
            for key in sorted(config.keys()):
                if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]:
                    continue
                format_string += f"\t    {key}: {config[key]}\n"
        format_string += ")"
        return format_string
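
# Serialization sketch (hypothetical path; assumes masks already exist, i.e.
# forward passes and step() have run on the same model):
#
#   sd = sparsifier.state_dict()  # masks are converted to sparse COO here
#   torch.save(sd, "act_sparsifier.pt")
#
#   restored = ActivationSparsifier(model)
#   restored.load_state_dict(torch.load("act_sparsifier.pt"))  # re-attaches hooks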