# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to add support for magnitude-based model pruning.

  # Adds variables and ops to the graph to enable
  # elementwise masking of weights
  apply_mask(weights)

  # Returns a list containing the sparsity of each of the weight tensors
  get_weight_sparsity()

  # Returns a list of all the masked weight tensorflow variables
  get_masked_weights()

  # Returns a list of all the mask tensorflow variables
  get_masks()

  # Returns a list of all the thresholds
  get_thresholds()

  # Returns a list of all the weight tensors that have been masked
  get_weights()

  The Pruning class uses a tf.hparams object to set up the
  parameters for a model pruning. Here's a typical usage:

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning_hparams
  p = pruning.Pruning(pruning_hparams)

  # Add mask update ops to the graph
  mask_update_op = p.conditional_mask_update_op()

  # Add the summaries
  p.add_pruning_summaries()

  # Run the op
  session.run(mask_update_op)

  # An object of the pruning also accepts externally defined sparsity:
  sparsity = tf.Variable(0.5, name = "ConstantSparsity")
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util

_MASK_COLLECTION = core.MASK_COLLECTION
_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME


def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".
  Returns:
    Tensor representing masked_weights
  """

  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Make sure the mask for a given variable are not added multiple times to the
  # collection. This is particularly important when applying mask to RNN's
  # weight variables
  if mask not in ops.get_collection_ref(_MASK_COLLECTION):
    ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
    ops.add_to_collection(_MASK_COLLECTION, mask)
    ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
    ops.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked_weights


def get_masked_weights():
  """Returns the contents of the masked-weight graph collection."""
  return ops.get_collection(_MASKED_WEIGHT_COLLECTION)


def get_masks():
  """Returns the contents of the mask graph collection."""
  return ops.get_collection(_MASK_COLLECTION)


def get_thresholds():
  """Returns the contents of the threshold graph collection."""
  return ops.get_collection(_THRESHOLD_COLLECTION)


def get_weights():
  """Returns the contents of the (unmasked) weight graph collection."""
  return ops.get_collection(_WEIGHT_COLLECTION)


def get_weight_sparsity():
  """Get sparsity of the weights.

  Args:
    None

  Returns:
    A list containing the sparsity of each of the weight tensors
  """
  masks = get_masks()
  return [nn_impl.zero_fraction(mask) for mask in masks]


def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops under
      a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1 implying
      that pruning continues till the training stops
    weight_sparsity_map: list of strings
       comma separated list of weight variable name:target sparsity pairs.
       For layers/weights not in this list, sparsity as specified by the
       target_sparsity hyperparameter is used.
       Eg. [conv1:0.9,conv2/kernel:0.8]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      How often should the masks be updated? (in # of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1)
    block_width: integer
      number of cols in a block (defaults to 1)
    block_pooling_function: string
      Whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 is linearly varying sparsity between initial and final.
      exponent > 1 varies more slowly towards the end than the beginning
    use_tpu: boolean
      Indicates whether to use TPU

    We use the following sparsity function (see Pruning._setup_sparsity):

    p = min(1, max(0, (global_step - sparsity_function_begin_step) /
                      (sparsity_function_end_step -
                       sparsity_function_begin_step)))
    sparsity(global_step) = (initial_sparsity - target_sparsity) *
                            (1 - p)**sparsity_function_exponent +
                            target_sparsity

  Args:
    None

  Returns:
    tf.HParams object initialized to default values

  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      weight_sparsity_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)


class Pruning(object):
  """Builds the graph ops that update pruning masks and thresholds."""

  def __init__(self, spec=None, global_step=None, sparsity=None):
    """Set up the specification for model pruning.

    If a spec is provided, the sparsity is set up based on the sparsity_function
    in the spec. The effect of sparsity_function is overridden if the sparsity
    variable is passed to the constructor. This enables setting up arbitrary
    sparsity profiles externally and passing it to this pruning functions.

    Args:
      spec: Pruning hyperparameters (an HParams object such as the one
        returned by get_pruning_hparams()). When falsy, the defaults from
        get_pruning_hparams() are used.
      global_step: A tensorflow variable that is used while setting up the
        sparsity function
      sparsity: A tensorflow scalar variable storing the sparsity

    Raises:
      ValueError: if the spec fails validation or no global step is available.
    """
    # Pruning specification
    self._spec = spec if spec else get_pruning_hparams()

    # Sanity check for pruning hparams
    self._validate_spec()

    # A tensorflow variable that tracks the sparsity function.
    # If not provided as input, the graph must already contain the global_step
    # variable before calling this constructor.
    self._global_step = self._setup_global_step(global_step)

    # Stores the tensorflow sparsity variable.
    # Built using self._setup_sparsity() or provided externally
    self._sparsity = (sparsity
                      if sparsity is not None else self._setup_sparsity())

    # List of tensorflow assignments ops for new masks and thresholds
    self._assign_ops = []

    # Tensorflow variable keeping track of the last global step when the masks
    # were updated
    self._last_update_step = self._setup_last_update_step()

    # Block dimensions
    self._block_dim = [self._spec.block_height, self._spec.block_width]

    # Block pooling function
    self._block_pooling_function = self._spec.block_pooling_function

    # Mapping of weight names and target sparsity
    self._weight_sparsity_map = self._get_weight_sparsity_map()

  def _validate_spec(self):
    """Checks the pruning hyperparameters for illegal values.

    Raises:
      ValueError: if any step bound, decay or sparsity value is out of range.
    """
    spec = self._spec
    if spec.begin_pruning_step < 0:
      raise ValueError('Illegal value for begin_pruning_step')

    if spec.begin_pruning_step >= spec.end_pruning_step:
      # end_pruning_step == -1 means "prune until training stops".
      if spec.end_pruning_step != -1:
        raise ValueError(
            'Pruning must begin before it can end. begin_step=%d, end_step=%d. '
            'Set end_pruning_step to -1 if pruning is required till training '
            'stops' % (spec.begin_pruning_step, spec.end_pruning_step))

    if spec.sparsity_function_begin_step < 0:
      raise ValueError('Illegal value for sparsity_function_begin_step')

    if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
      raise ValueError(
          'Sparsity function requires begin_step < end_step')

    if not 0.0 <= spec.threshold_decay < 1.0:
      raise ValueError('threshold_decay must be in range [0,1)')

    if not 0.0 <= spec.initial_sparsity < 1.0:
      raise ValueError('initial_sparsity must be in range [0,1)')

    if not 0.0 <= spec.target_sparsity < 1.0:
      raise ValueError('target_sparsity must be in range [0,1)')

  def _setup_global_step(self, global_step):
    """Returns the global step cast to int32.

    Args:
      global_step: The global step tensor/variable, or None to look it up via
        training_util.get_global_step().

    Returns:
      The global step cast to dtypes.int32.

    Raises:
      ValueError: if no global step was passed in and none could be found.
    """
    graph_global_step = global_step
    if graph_global_step is None:
      graph_global_step = training_util.get_global_step()

    # Fail with an actionable message instead of letting the cast below choke
    # on None.
    if graph_global_step is None:
      raise ValueError(
          'global_step was not provided and no global step variable was '
          'found; create the global step before constructing Pruning')

    return math_ops.cast(graph_global_step, dtypes.int32)

  def _setup_sparsity(self):
    """Builds the gradual sparsity schedule as a function of the global step.

    Returns:
      A scalar tensor named 'sparsity' that moves from initial_sparsity to
      target_sparsity as the global step goes from
      sparsity_function_begin_step to sparsity_function_end_step.
    """
    begin_step = self._spec.sparsity_function_begin_step
    end_step = self._spec.sparsity_function_end_step
    initial_sparsity = self._spec.initial_sparsity
    target_sparsity = self._spec.target_sparsity
    exponent = self._spec.sparsity_function_exponent

    with ops.name_scope(self._spec.name):
      # Fraction of the schedule completed, clipped to [0, 1].
      p = math_ops.minimum(
          1.0,
          math_ops.maximum(
              0.0,
              math_ops.div(
                  math_ops.cast(self._global_step - begin_step, dtypes.float32),
                  end_step - begin_step)))
      sparsity = math_ops.add(
          math_ops.multiply(initial_sparsity - target_sparsity,
                            math_ops.pow(1 - p, exponent)),
          target_sparsity,
          name='sparsity')

    return sparsity

  def _setup_last_update_step(self):
    """Creates (or reuses) the int32 'last_mask_update_step' variable."""
    with variable_scope.variable_scope(
        self._spec.name, use_resource=self._spec.use_tpu) as scope:
      try:
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', [],
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            dtype=dtypes.int32)
      except ValueError:
        # Creation failed (presumably because the variable already exists in
        # this scope); switch the scope to reuse mode and fetch it instead.
        scope.reuse_variables()
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', dtype=dtypes.int32)
    return last_update_step

  def _get_weight_sparsity_map(self):
    """Return the map of weight_name:sparsity parsed from the hparams."""
    weight_sparsity_map = {}
    val_list = self._spec.weight_sparsity_map
    # Skip empty entries such as the default [''].
    filtered_val_list = [l for l in val_list if l]
    for val in filtered_val_list:
      weight_name, sparsity = val.split(':')
      sparsity_value = float(sparsity)
      if sparsity_value >= 1.0:
        raise ValueError('Weight sparsity can not exceed 1.0')
      weight_sparsity_map[weight_name] = sparsity_value

    return weight_sparsity_map

  def _get_sparsity(self, weight_name):
    """Return target sparsity for the given layer/weight name."""
    # An entry matches when its name is a substring of the weight's op name.
    target_sparsity = [
        sparsity for name, sparsity in self._weight_sparsity_map.items()
        if name in weight_name
    ]
    if not target_sparsity:
      return self._sparsity

    if len(target_sparsity) > 1:
      raise ValueError(
          'Multiple matches in weight_sparsity_map for weight %s' % weight_name)
    # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
    # to handle other cases as well.
    return math_ops.mul(
        self._sparsity,
        math_ops.div(target_sparsity[0], self._spec.target_sparsity))

  def _update_mask(self, weights, threshold):
    """Updates the mask for a given weight tensor.

    This function sorts the absolute values of the weight tensor and picks the
    threshold value such that approximately 'desired_sparsity' fraction of the
    weights have magnitude less than the threshold.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A float32 tensor of the same shape as weights containing
        0 or 1 to indicate which of the values in weights falls below
        the threshold

    Raises:
      ValueError: if sparsity is not defined
    """
    if self._sparsity is None:
      raise ValueError('Sparsity variable undefined')

    sparsity = self._get_sparsity(weights.op.name)
    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(weights)
      # Number of weights to keep. Clamp to at least 1 so that (k - 1) below
      # is never -1, which would be an out-of-range gather index for small
      # tensors at high sparsity.
      k = math_ops.maximum(
          math_ops.cast(
              math_ops.round(
                  math_ops.cast(array_ops.size(abs_weights), dtypes.float32) *
                  (1 - sparsity)), dtypes.int32), 1)
      # Sort the entire array
      values, _ = nn_ops.top_k(
          array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights))
      # Grab the (k-1) th value
      current_threshold = array_ops.gather(values, k - 1)
      smoothed_threshold = math_ops.add_n([
          math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
          math_ops.multiply(threshold, self._spec.threshold_decay)
      ])

      new_mask = math_ops.cast(
          math_ops.greater_equal(abs_weights, smoothed_threshold),
          dtypes.float32)

    return smoothed_threshold, new_mask

  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs.
    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor of the same shape as weights containing
        0 or 1 to indicate which of the values in weights falls below
        the threshold

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
    squeezed_weights = array_ops.squeeze(weights)
    if squeezed_weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(squeezed_weights)

      pool_window = [self._block_dim[0], self._block_dim[1]]
      pool_fn = pruning_utils.factorized_pool
      squeeze_axis = None
      if not self._spec.use_tpu:
        # nn_ops.pool is fed a 4-D tensor here, so wrap the 2-D weights with
        # leading and trailing singleton dimensions and strip them afterwards.
        pool_fn = nn_ops.pool
        abs_weights = array_ops.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0],
             abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)

      # Expand the per-block mask back to element granularity, then slice it
      # down to the shape of the squeezed weights.
      updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
      sliced_mask = array_ops.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                 array_ops.shape(weights))

  def _get_mask_assign_ops(self):
    """Populates self._assign_ops with threshold and mask assignment ops.

    Raises:
      ValueError: if the assign op list is already populated, or if the number
        of masks and thresholds in the graph collections differ.
    """
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))

    for index, mask in enumerate(masks):
      threshold = thresholds[index]
      weight = weights[index]
      is_partitioned = isinstance(weight, variables.PartitionedVariable)
      if is_partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
      self._assign_ops.append(
          pruning_utils.variable_assign(threshold, new_threshold))

      self._assign_ops.append(
          pruning_utils.partitioned_variable_assign(mask, new_mask)
          if is_partitioned else pruning_utils.variable_assign(mask, new_mask))

  def mask_update_op(self):
    """Returns an op that updates all masks and thresholds unconditionally.

    The returned no_op has control dependencies on the assignment of the
    current global step to last_mask_update_step and on every mask/threshold
    assign op.
    """
    with ops.name_scope(self._spec.name):
      if not self._assign_ops:
        self._get_mask_assign_ops()
      with ops.control_dependencies([
          state_ops.assign(
              self._last_update_step,
              self._global_step,
              name='last_mask_update_step_assign')
      ]):
        with ops.control_dependencies(self._assign_ops):
          logging.info('Updating masks.')
          return control_flow_ops.no_op('mask_update')

  def conditional_mask_update_op(self):
    """Returns an op that updates the masks only on pruning steps.

    The masks are updated when the global step lies within
    [begin_pruning_step, end_pruning_step] (no upper bound when
    end_pruning_step is negative) and at least pruning_frequency steps have
    passed since the last update. Otherwise the op does nothing.
    """

    def maybe_update_masks():
      with ops.name_scope(self._spec.name):
        is_step_within_pruning_range = math_ops.logical_and(
            math_ops.greater_equal(self._global_step,
                                   self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            math_ops.logical_or(
                math_ops.less_equal(self._global_step,
                                    self._spec.end_pruning_step),
                math_ops.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = math_ops.less_equal(
            math_ops.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return math_ops.logical_and(is_step_within_pruning_range,
                                    is_pruning_step)

    def mask_update_op():
      return self.mask_update_op()

    def no_update_op():
      return control_flow_ops.no_op()

    return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
                                 no_update_op)

  def add_pruning_summaries(self):
    """Adds summaries of weight sparsities and thresholds."""
    with ops.name_scope(self._spec.name + '_summaries'):
      summary.scalar('sparsity', self._sparsity)
      summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
      for mask, threshold in zip(masks, thresholds):
        summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
        summary.scalar(threshold.op.name + '/threshold', threshold)

  def print_hparams(self):
    """Logs the pruning hyperparameters as JSON."""
    logging.info(self._spec.to_json())