# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to fold batch norm into preceding convolution or FC layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat


def FoldBatchNorms(graph, is_training, freeze_batch_norm_delay=None):
  """Finds batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization. This value is used
      only when is_training is True.
  Raises:
    ValueError: When batch norm folding fails.
  """
  _FoldFusedBatchNorms(
      graph, is_training, freeze_batch_norm_delay=freeze_batch_norm_delay)
  _FoldUnfusedBatchNorms(
      graph,
      is_training=is_training,
      freeze_batch_norm_delay=freeze_batch_norm_delay)


def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
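    # For example, a layer op named 'MobilenetV1/Conv2d_0/Conv2D' gives
    # scope='MobilenetV1/Conv2d_0' and sep='/', so the folded ops below are
    # created directly under 'MobilenetV1/Conv2d_0/'.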
    with graph.as_default(), graph.name_scope(scope + sep):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        correction_scale, correction_recip, correction_offset = None, None, None
        if is_training:
          correction_scale, correction_recip, correction_offset = (
              _ComputeBatchNormCorrections(
                  context='',
                  match=match,
                  freeze_batch_norm_delay=freeze_batch_norm_delay))
        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        weights = match.weight_tensor
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

          if correction_scale is not None:
            correction_scale = array_ops.reshape(
                correction_scale, new_shape, name='correction_reshape')

      if correction_scale is not None:
        weights = math_ops.multiply(
            correction_scale, weights, name='correction_mult')

      scaled_weight_tensor = math_ops.multiply(
          weights, multiplier_tensor, name='mul_fold')

      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor,
          match.batch_to_space_op)

      if correction_recip is not None:
        new_layer_tensor = math_ops.multiply(
            correction_recip, new_layer_tensor, name='post_conv_mul')
        new_layer_tensor = math_ops.add(new_layer_tensor, correction_offset,
                                        'correction_add')

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      nodes_modified_count = common.RerouteTensor(bias_add_tensor,
                                                  match.output_tensor)
      if nodes_modified_count == 0:
        raise ValueError('Folding batch norms failed, %s had no outputs.' %
                         match.output_tensor.name)


def _FindFusedBatchNorms(graph):
  """Finds all ops and tensors related to found FusedBatchNorms.

  Args:
    graph: Graph to inspect.

  Returns:
    _FusedBatchNormMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  # In practice, the weight pattern can match a Variable or a SpaceToBatchND
  # operation that follows a variable for atrous convolutions.
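  # (The BatchToSpaceND pattern below matches the corresponding output path of
  # such atrous convolutions.)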
  weight_pattern = graph_matcher.OpTypePattern('*')
  gamma_pattern = graph_matcher.OpTypePattern('*')
  beta_pattern = graph_matcher.OpTypePattern('*')
  mean_pattern = graph_matcher.OpTypePattern('*')
  variance_pattern = graph_matcher.OpTypePattern('*')

  moving_average_pattern = graph_matcher.OpTypePattern('*')
  bn_decay_pattern = graph_matcher.OpTypePattern('*')
  layer_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative|MatMul',
      inputs=[input_pattern, weight_pattern])
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])
  # Identity between conv/matmul and bn
  layer_pattern_with_identity = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([batch_to_space_pattern, layer_pattern])
      ])
  layer_output_pattern = graph_matcher.OneofPattern(
      [layer_pattern_with_identity, layer_pattern, batch_to_space_pattern])

  # MatMul has a Reshape between it and FusedBatchNorm.
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[layer_output_pattern,
              graph_matcher.OpTypePattern('*')])

  batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          graph_matcher.OneofPattern(
              [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern,
          beta_pattern, mean_pattern, variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[batch_norm_pattern,
                         graph_matcher.OpTypePattern('*')])

  batch_norm_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[batch_norm_pattern, matmul_bn_output_reshape_pattern])

  bn_identity_matcher = graph_matcher.GraphMatcher(batch_norm_identity_pattern)

  bn_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern(
          [matmul_bn_output_reshape_pattern, batch_norm_pattern]))

  moving_average_sub_pattern = graph_matcher.OpTypePattern(
      'Sub', inputs=[moving_average_pattern, batch_norm_pattern])
  moving_average_mul_pattern = graph_matcher.OpTypePattern(
      'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern])

  moving_avg_mul_matcher = graph_matcher.GraphMatcher(
      moving_average_mul_pattern)

  def _GetLayerMatch(match_result):
    """Populates a layer match object containing ops/tensors for folding BNs.

    Args:
      match_result: Matched result from graph matcher

    Returns:
      layer_op: Matching conv/fc op prior to batch norm
      BatchNormMatch: _BatchNormMatch containing all required batch norm
        parameters.
    """
    moving_mean_tensor = None
    moving_variance_tensor = None
    bn_decay_mean_tensor = None
    bn_decay_var_tensor = None
    batch_to_space_op = None
    layer_op = match_result.get_op(layer_pattern)
    layer_tensor = match_result.get_tensor(layer_pattern)
    bn_id_op = match_result.get_op(batch_norm_identity_pattern)
    bn_op = match_result.get_op(batch_norm_pattern)
    if bn_id_op is None:
      bn_id_op = bn_op

    batch_epsilon = bn_op.get_attr('epsilon')

    # In the MatMul case, the output of batch norm is reshaped back into a
    # 2D tensor, so the output_tensor is the output of the Reshape op.
    output_tensor = bn_op.outputs[0]
    if layer_op.type == 'MatMul':
      output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
      # If the matcher didn't match matmul_bn_output_reshape, there will be
      # another match for this 'MatMul' later, so we can skip this one.
      if output_reshape_op is None:
        return None, None
      output_tensor = output_reshape_op.outputs[0]

    # Ensure that the output tensor has consumers, otherwise this is a dangling
    # node and not a match.
    if not output_tensor.consumers():
      return None, None

    batch_to_space_op = match_result.get_op(batch_to_space_pattern)
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    is_training = bn_op.get_attr('is_training')
    if is_training:
      # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
      # batch_variance outputs, so we need to substitute our own custom
      # gradient.
      # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
      # pylint: disable=protected-access
      bn_op._set_attr(
          '_gradient_op_type',
          attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad')))
      # pylint: enable=protected-access
      mean_tensor = bn_op.outputs[1]
      # The batch variance used during forward and backward prop is biased,
      # i.e. it is calculated as V = sum((x(k) - mu)^2) / N. For the moving
      # average calculation, the variance is corrected by the term N/(N-1)
      # (Bessel's correction). The variance tensor read from FusedBatchNorm has
      # Bessel's correction applied, so we undo it here.
      scope, sep, _ = bn_op.name.rpartition('/')
      g = ops.get_default_graph()
      with g.as_default(), g.name_scope(scope + sep):
        n = math_ops.cast(
            array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
            dtypes.float32)
        variance_tensor = math_ops.multiply(
            bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
      # TODO(suharshs): Find a way to get rid of this inner match.
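      # The Sub -> Mul subgraph matched here is the moving-average update term
      # (moving_stat - batch_stat) * decay; it lets us recover the moving
      # mean/variance tensors and their decay factors, which
      # _ComputeBatchNormCorrections needs.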
      for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
        sub_op = mul_match_result.get_op(moving_average_sub_pattern)
        if sub_op.inputs[1].name == bn_op.outputs[1].name:
          # During training: Batch Mean is bn_op.outputs[1]
          moving_mean_tensor = sub_op.inputs[0]
          bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
        if sub_op.inputs[1].name == bn_op.outputs[2].name:
          # During training: Batch Var is bn_op.outputs[2]
          moving_variance_tensor = sub_op.inputs[0]
          bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)

    return layer_op, _BatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor,
        moving_mean_tensor=moving_mean_tensor,
        moving_variance_tensor=moving_variance_tensor,
        bn_decay_mean_tensor=bn_decay_mean_tensor,
        bn_decay_var_tensor=bn_decay_var_tensor,
        batch_epsilon=batch_epsilon,
        batch_to_space_op=batch_to_space_op)

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()
  for match_result in bn_identity_matcher.match_graph(graph):
    layer_op, layer_match = _GetLayerMatch(match_result)
    if layer_op is not None:
      if layer_op not in matched_layer_set:
        matched_layer_set.add(layer_op)
        layer_matches.append(layer_match)

  for match_result in bn_matcher.match_graph(graph):
    layer_op, layer_match = _GetLayerMatch(match_result)
    if layer_op is not None:
      if layer_op not in matched_layer_set:
        matched_layer_set.add(layer_op)
        layer_matches.append(layer_match)

  return layer_matches


def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
  """Computes batch norm correction params.

  Before batch normalization is frozen:
    We use batch statistics for batch norm.
      correction_scale = sigma_b / sigma_mv
      correction_recip = 1 / correction_scale
      correction_offset = 0

  After batch normalization is frozen:
      correction_scale = sigma_b / sigma_mv
      correction_recip = 1
      correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)

  Batch norm is frozen if global_step > freeze_batch_norm_delay.
  The corrections ensure that:
  a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather
     than jumping across mini-batches.
  b) Changing the values of the corrections allows one to switch from using
     batch statistics to using the moving mean and variance, without requiring
     changes to batch_norm.

  Args:
    context: The scope under which we look for batch norm params.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset.
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')

    common.RerouteTensor(
        bn_decay_mean_out,
        match.bn_decay_mean_tensor,
        can_modify=bn_decay_mean_consumers)

    bn_decay_var_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_var_tensor,
        name='freeze_moving_var')
    common.RerouteTensor(
        bn_decay_var_out,
        match.bn_decay_var_tensor,
        can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset


def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor,
                          batch_to_space_op):
  """Clones layer_op with input_tensor and weight_tensor as new inputs."""
  new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
  if layer_op.type == 'Conv2D':
    return nn_ops.conv2d(
        input_tensor,
        weight_tensor,
        strides=layer_op.get_attr('strides'),
        padding=layer_op.get_attr('padding'),
        use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
        data_format=layer_op.get_attr('data_format').decode(),
        name=new_layer_name)
  elif layer_op.type == 'MatMul':
    return math_ops.matmul(
        input_tensor,
        weight_tensor,
        transpose_a=layer_op.get_attr('transpose_a'),
        transpose_b=layer_op.get_attr('transpose_b'),
        name=new_layer_name)
  elif layer_op.type == 'DepthwiseConv2dNative':
    # We don't copy dilation rate because we reuse the input SpaceToBatch
    # and create our own BatchToSpace operation below.
    conv = nn.depthwise_conv2d(
        input_tensor,
        weight_tensor,
        strides=layer_op.get_attr('strides'),
        padding=layer_op.get_attr('padding'),
        name=new_layer_name)
    # Copy the batch to space operation if we have an atrous convolution.
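    # The BatchToSpaceND that consumes the original conv output supplies the
    # block_shape and crops inputs that are reused for the cloned op below.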
    if batch_to_space_op:
      batch_to_space_op = layer_op.outputs[0].consumers()[0]
      # TODO(suharshs): It's hard to make this name match with the unfused name.
      # Restructure this code to not rely on scope at all.
      new_batch_to_space_name = batch_to_space_op.name.split('/')[-1] + '_Fold'
      conv = array_ops.batch_to_space_nd(
          conv,
          batch_to_space_op.inputs[1],
          batch_to_space_op.inputs[2],
          name=new_batch_to_space_name)
    return conv
  else:
    raise ValueError('Cannot handle operation of type: %s' % layer_op.type)


@ops.RegisterGradient('FoldFusedBatchNormGrad')
def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
                            unused_2):
  x = op.inputs[0]
  n = math_ops.cast(
      array_ops.size(x) / array_ops.size(grad_mean), dtypes.float32)
  dmean_dx = grad_mean / n
  dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
  return (dmean_dx + dvar_dx), None, None, None, None


def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    if not _IsValidUnfusedBatchNorm(graph, bn):
      continue

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(
        graph,
        bn,
        has_scaling=has_scaling,
        freeze_batch_norm_delay=freeze_batch_norm_delay,
        is_training=is_training)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = common.RerouteTensor(
          folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above. We need to select the correct scope
    # for the bypass add: if the batch norm's scope has the form str1/str2,
    # the bypass add is at scope str1; if the batch norm's scope is just str1,
    # the bypass add is at scope ''. If there is no batch norm, there is no
    # bypass add.
    add_bypass_ctx = ''
    if bn:
      try:
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
      except AttributeError:
        add_bypass_ctx = ''

    if add_bypass_ctx:
      add_bypass_ctx = add_bypass_ctx + '/'

    add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
    nodes_modified_count = common.RerouteTensor(
        folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)


def _IsValidUnfusedBatchNorm(graph, context):
  """Checks that the output of the unfused batch norm has consumers."""
  add_shift = graph.get_operation_by_name(context +
                                          'BatchNorm/batchnorm_1/add_1')
  # Ensure that the output tensor of batch norm has consumers, otherwise this
  # is a dangling node and not a match.
  return bool(add_shift.outputs[0].consumers())


def _FindMatchingTensor(graph, match_pattern, scope):
  """Finds best match of ops matching match_pattern with scope.

  Example: _FindMatchingTensor(graph, '/BatchNorm/moments/Squeeze',
  'MobilenetV1/MobilenetV1/Conv2d_0/') returns:
    Tensor('MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze')

  Args:
    graph: Graph to inspect.
    match_pattern: Part of the name of the op that we need to match, should
      be present in the op's name.
    scope: The scope of the op. All the elements of the scope need not be
      present in the op's name.

  Returns:
    Tensor from graph that provides the best match to the match_pattern and
    scope.
  """

  oplist = graph.get_operations()
  split_context = set(scope.split('/'))
  match_dict = {}
  for op in oplist:
    if op.name.endswith(match_pattern):
      split_name = op.name.split('/')
      num_matches = len(set(split_name) & split_context)

      if num_matches > 0 or not scope:
        match_dict[op.name] = num_matches
  # match_dict contains matching op names from graph with values being
  # number of matches to scope. We pick the key with the most matches.
  if match_dict:
    max_key = max(match_dict, key=match_dict.get)
    return graph.get_tensor_by_name(max_key + ':0')
  else:
    return None


def _GetBatchNormParams(graph, context, has_scaling):
  """Extracts relevant tensors for folding batch norms.

  Args:
    graph: Graph to inspect.
    context: The scope under which we look for batch norm params.
    has_scaling: Bool that specifies if scaling is done as part of batch norm.

  Returns:
    _BatchNormMatch containing all required batch norm parameters.
  """
  gamma_tensor = None
  batch_mean_tensor = None
  batch_variance_tensor = None
  moving_mean_tensor = None
  moving_variance_tensor = None
  batch_epsilon = None
  bn_decay_mean_tensor = None
  bn_decay_var_tensor = None

  # TODO(raghuramank): This code relies on string matching and needs to be
  # updated if unfused batch norm continues to be widely used.
  # Matching variable names is brittle and relies on scoping
  # conventions. Fused batch norm folding is more robust. Support for unfused
  # batch norms will be deprecated as we move forward. Fused batch norms allow
  # for faster training and should be used whenever possible.
  # context contains part of the names of the tensors we are interested in.
  # For MobilenetV1, the context has repetitions:
  #   MobilenetV1/MobilenetV1/Conv2d_3_depthwise
  # while the moving_mean tensor has the name:
  #   MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read
  # To pick the correct variable name, it is necessary to ignore the repeating
  # header.

  # For MobilenetV2, this problem does not exist:
  # The context is MobilenetV2/expanded_conv_3/depthwise
  # and the names of the tensors start with a single MobilenetV2.
  # The moving mean, for example, has the name:
  #   MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
  # We identify the best match for an op by checking for:
  # 1. The suffix of the op is exactly matched.
  # 2. Maximum number of matches with the context. The matching score is given
  #    by the number of parts of context (split by /) that are present in the
  #    parts of the tensor name (again split by /).
  # For example: scope = MobilenetV2/MobilenetV2/expanded_conv_3 and
  # op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
  # will have 2 matches; a scope with a different conv layer will have one
  # match.

  op_suffix_mean = 'BatchNorm/moments/Squeeze'
  op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
  op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
  op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
  op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'

  if variable_scope.get_variable_scope().use_resource:
    op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
    op_suffix_moving_variance = (
        'BatchNorm/moving_variance/Read/ReadVariableOp')
    op_suffix_moving_mean = 'BatchNorm/moving_mean/Read/ReadVariableOp'
  else:
    op_suffix_gamma = 'BatchNorm/gamma'
    op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
    op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
  # Parse the list of ops to find the relevant ones.

  batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
  batch_variance_tensor = _FindMatchingTensor(graph, op_suffix_variance,
                                              context)
  moving_mean_tensor = _FindMatchingTensor(graph, op_suffix_moving_mean,
                                           context)
  moving_variance_tensor = _FindMatchingTensor(graph, op_suffix_moving_variance,
                                               context)
  batch_epsilon = _FindMatchingTensor(graph, op_suffix_epsilon, context)
  bn_decay_mean_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_mean,
                                             context)
  bn_decay_var_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_var,
                                            context)
  if batch_mean_tensor is None and moving_mean_tensor is None:
    raise ValueError('Error folding unfused batch norms')
  if has_scaling:
    gamma_tensor = _FindMatchingTensor(graph, op_suffix_gamma, context)

  if not has_scaling:
    gamma_tensor = array_ops.ones(moving_mean_tensor.shape)

  return _BatchNormMatch(
      layer_op=None,
      bn_op=None,
      output_tensor=None,
      input_tensor=None,
      weight_tensor=None,
      gamma_tensor=gamma_tensor,
      beta_tensor=None,
      mean_tensor=batch_mean_tensor,
      variance_tensor=batch_variance_tensor,
      moving_mean_tensor=moving_mean_tensor,
      moving_variance_tensor=moving_variance_tensor,
      bn_decay_mean_tensor=bn_decay_mean_tensor,
      bn_decay_var_tensor=bn_decay_var_tensor,
      batch_epsilon=batch_epsilon,
      batch_to_space_op=None)


def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
                    is_training):
  """Folds a batch norm layer into the preceding convolution or FC layer.

  Creates 3 new nodes, connects their inputs and adds them to the graph:
  mul is cloned into mul_fold; Conv2D, MatMul or DepthwiseConv2dNative is
  cloned into the corresponding *_Fold op; add is cloned into add_fold.

  Args:
    graph: Graph to modify.
    context: String, batch norm context, i.e. node into which BatchNorm is
      nested.
    has_scaling: Whether the batch norm has scaling enabled.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.
    is_training: Bool, true if training.

  Raises:
    ValueError: When operation type is not supported, or input and output tensor
      shapes mismatch for created operations: mul_fold, add_fold.

  Returns:
    A pair of Operations, the first is the original consumer node of the batch
    norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of
    the folded graph (add_fold).
  """
  mul_scale_name = 'mul_1' if has_scaling else 'mul'
  mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                          mul_scale_name)
  op_below = mul_scale.inputs[0].op
  # Skip over the BatchToSpace operation in the case of atrous convolutions.
  batch_to_space_op = None
  if op_below.type == 'BatchToSpaceND':
    batch_to_space_op = op_below
    op_below = op_below.inputs[0].op
  weights = op_below.inputs[1]
  match = _GetBatchNormParams(
      graph=graph, context=context, has_scaling=has_scaling)
  correction_scale, correction_recip, correction_offset = None, None, None
  if is_training:
    correction_scale, correction_recip, correction_offset = (
        _ComputeBatchNormCorrections(
            context=context,
            match=match,
            freeze_batch_norm_delay=freeze_batch_norm_delay))
  # Special handling for weights of depthwise convolution.
  if op_below.type == 'DepthwiseConv2dNative':
    new_shape = [
        weights.get_shape().as_list()[2],
        weights.get_shape().as_list()[3]
    ]
    scale_name = 'mul' if has_scaling else 'Rsqrt'
    scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                        scale_name)
    scale = array_ops.reshape(scale.outputs[0], new_shape,
                              context + 'scale_reshape')

    if correction_scale is not None:
      correction_scale = array_ops.reshape(correction_scale, new_shape,
                                           context + 'correction_reshape')
      with ops.device(mul_scale.device):
        weights = math_ops.multiply(correction_scale, weights,
                                    context + 'correction_mult')

    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
                                                          (1, scale)])
  elif op_below.type in ['Conv2D', 'MatMul']:

    if correction_scale is not None:
      with ops.device(mul_scale.device):
        weights = math_ops.multiply(correction_scale, weights,
                                    context + 'correction_mult')
    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
  else:
    raise ValueError('Cannot handle operation of type: %s' % op_below.type)
  _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])

  conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
                               [(1, mul_fold.outputs[0])])

  add_shift = graph.get_operation_by_name(context +
                                          'BatchNorm/batchnorm_1/add_1')

  corrected_output = conv_or_fc_folded.outputs[0]
  # Copy the batch to space operation if we have an atrous convolution.
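  # As in the fused case, reuse the original BatchToSpaceND's block_shape and
  # crops inputs for the folded output.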
  if batch_to_space_op:
    corrected_output = array_ops.batch_to_space_nd(
        corrected_output,
        batch_to_space_op.inputs[1],
        batch_to_space_op.inputs[2],
        name=batch_to_space_op.name + '_Fold')
  if correction_offset is not None:
    with ops.device(conv_or_fc_folded.device):
      corrected_output = math_ops.multiply(correction_recip, corrected_output,
                                           context + 'post_conv_mul')
      corrected_output = math_ops.add(corrected_output, correction_offset,
                                      context + 'correction_add')
  add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
  _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
  return add_shift, add_fold


def _CloneOp(op, new_name, new_inputs):
  """Clones a given op, replaces its name and some of its inputs.

  Args:
    op: Operation to modify.
    new_name: String, a new name to set on cloned op.
    new_inputs: A list of tuples (idx, tensor), each input with corresponding
      index will be replaced by the given Tensor in the cloned op.

  Returns:
    Operation, the cloned op.

  Raises:
    TypeError: When Operation type is not supported.
    ValueError: When input shapes are incompatible.
  """
  inputs = list(op.inputs)
  for new_input in new_inputs:
    inputs[new_input[0]] = new_input[1]
  return _OP_CLONER.Clone(op, inputs, new_name)


class _OpCloner(object):
  """Helper class that clones tf.Operations based on their type."""

  def __init__(self):
    self.op_type_to_action = {
        'Mul': self._CloneMul,
        'Add': self._CloneAdd,
        'Conv2D': self._CloneConv2d,
        'DepthwiseConv2dNative': self._CloneDepthwiseConv2d,
        'MatMul': self._CloneMatMul,
    }

  def _CloneMul(self, op, inputs, new_name):
    del op  # Unused.
    return math_ops.multiply(inputs[0], inputs[1], name=new_name).op

  def _CloneAdd(self, op, inputs, new_name):
    del op  # Unused.
    return math_ops.add(inputs[0], inputs[1], name=new_name).op

  def _CloneConv2d(self, op, inputs, new_name):
    input_tensor = inputs[0]
    weights = inputs[1]
    self._AssertConvShapes(op.name, input_tensor, weights)
    return nn_ops.conv2d(
        input_tensor,
        weights,
        strides=op.get_attr('strides'),
        padding=op.get_attr('padding'),
        use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'),
        data_format=op.get_attr('data_format').decode(),
        name=new_name).op

  def _CloneDepthwiseConv2d(self, op, inputs, new_name):
    input_tensor = inputs[0]
    weights = inputs[1]
    self._AssertConvShapes(op.name, input_tensor, weights)
    return nn.depthwise_conv2d(
        input_tensor,
        weights,
        strides=op.get_attr('strides'),
        padding=op.get_attr('padding'),
        name=new_name).op

  def _CloneMatMul(self, op, inputs, new_name):
    weights = inputs[0]
    input_tensor = inputs[1]
    self._AssertFCShapes(op.name, weights, input_tensor)
    return math_ops.matmul(
        weights,
        input_tensor,
        transpose_a=op.get_attr('transpose_a'),
        transpose_b=op.get_attr('transpose_b'),
        name=new_name).op

  def Clone(self, op, inputs, new_name):
    try:
      return self.op_type_to_action[op.type](op, inputs, new_name)
    except KeyError:
      raise TypeError('Unsupported operation type: %s' % op.type)

  def _AssertConvShapes(self, op_name, input_tensor, weights):
    """Makes sure that convolution inputs have compatible shapes.

    Args:
      op_name: Operation name, only used in error message.
      input_tensor: Input that is convolved.
      weights: Weights of the convolution filter.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    input_shape = input_tensor.get_shape()
    weights_shape = weights.get_shape()
    if (len(input_shape) != 4 or len(weights_shape) != 4 or
        input_shape[3] != weights_shape[2]):
      raise ValueError('Incompatible shapes for op %s inputs: %s and %s' %
                       (op_name, input_shape, weights_shape))

  def _AssertFCShapes(self, op_name, weights, input_tensor):
    """Makes sure that FC layer inputs have compatible shapes.

    Args:
      op_name: Operation name, only used in error message.
      weights: Weights used in FC layer.
      input_tensor: Input into FC layer.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    weights_shape = weights.get_shape()
    input_shape = input_tensor.get_shape()
    if (len(weights_shape) != 2 or len(input_shape) != 2 or
        weights_shape[1] != input_shape[0]):
      raise ValueError('Incompatible shapes for op %s inputs: %s and %s' %
                       (op_name, weights_shape, input_shape))


_OP_CLONER = _OpCloner()


def _AssertShapesMatch(op_name, in_tensor, out_tensor):
  """Makes sure that shapes of input and output tensors are compatible.

  Args:
    op_name: String, operation name, only used in error message.
    in_tensor: Tensor, input tensor.
    out_tensor: Tensor, output tensor.

  Raises:
    ValueError: When input and output tensors have different shapes.
  """
  in_shape = in_tensor.get_shape()
  out_shape = out_tensor.get_shape()

  if not in_shape.is_compatible_with(out_shape):
    raise ValueError('%s should not change tensor shape: input %s, '
                     'output %s' % (op_name, in_shape, out_shape))


def _HasScaling(graph, input_to_ops_map, bn):
  r"""Checks if batch norm has scaling enabled.

  Difference between batch norm with scaling and without is that with scaling:

    Rsqrt -> mul -> mul_1
                \-> mul_2

  where
  mul multiplies gamma by inverse square root of EMA of batch variance,
  mul_1 multiplies output of mul with output from the base operation
  (convolution, FC or depthwise convolution),
  mul_2 multiplies output of mul with EMA of batch mean,
  and without scaling:

    Rsqrt -> mul
         \-> mul_1

  where
  mul multiplies the inverse square root of EMA of batch variance with output
  from the base operation,
  mul_1 multiplies inverse square root of EMA of batch variance with EMA
  of batch mean.

  Args:
    graph: Graph to inspect.
    input_to_ops_map: InputToOps object containing mapping from tensor's name
      to ops that take it as input.
    bn: Batch norm layer prefix string.

  Returns:
    A boolean indicating whether this batch norm layer has scaling enabled.
991 """ 992 rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt') 993 rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) 994 995 return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 996 997 998class _BatchNormMatch(object): 999 """Contains all information related to a found Fused/UnfusedBatchNorm.""" 1000 1001 def __init__(self, layer_op, bn_op, output_tensor, input_tensor, 1002 weight_tensor, gamma_tensor, beta_tensor, mean_tensor, 1003 variance_tensor, moving_mean_tensor, moving_variance_tensor, 1004 bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon, 1005 batch_to_space_op): 1006 self._layer_op = layer_op 1007 self._bn_op = bn_op 1008 self._output_tensor = output_tensor 1009 self._input_tensor = input_tensor 1010 self._weight_tensor = weight_tensor 1011 self._gamma_tensor = gamma_tensor 1012 self._beta_tensor = beta_tensor 1013 self._mean_tensor = mean_tensor 1014 self._variance_tensor = variance_tensor 1015 self._moving_mean_tensor = moving_mean_tensor 1016 self._moving_variance_tensor = moving_variance_tensor 1017 self._bn_decay_mean_tensor = bn_decay_mean_tensor 1018 self._bn_decay_var_tensor = bn_decay_var_tensor 1019 self._batch_epsilon = batch_epsilon 1020 self._batch_to_space_op = batch_to_space_op 1021 1022 @property 1023 def layer_op(self): 1024 return self._layer_op 1025 1026 @property 1027 def bn_op(self): 1028 return self._bn_op 1029 1030 @property 1031 def output_tensor(self): 1032 return self._output_tensor 1033 1034 @property 1035 def input_tensor(self): 1036 return self._input_tensor 1037 1038 @property 1039 def weight_tensor(self): 1040 return self._weight_tensor 1041 1042 @property 1043 def gamma_tensor(self): 1044 return self._gamma_tensor 1045 1046 @property 1047 def beta_tensor(self): 1048 return self._beta_tensor 1049 1050 @property 1051 def mean_tensor(self): 1052 return self._mean_tensor 1053 1054 @property 1055 def variance_tensor(self): 1056 return self._variance_tensor 1057 1058 @property 1059 def moving_mean_tensor(self): 1060 return self._moving_mean_tensor 1061 1062 @property 1063 def moving_variance_tensor(self): 1064 return self._moving_variance_tensor 1065 1066 @property 1067 def batch_epsilon(self): 1068 return self._batch_epsilon 1069 1070 @property 1071 def bn_decay_mean_tensor(self): 1072 return self._bn_decay_mean_tensor 1073 1074 @property 1075 def bn_decay_var_tensor(self): 1076 return self._bn_decay_var_tensor 1077 1078 @property 1079 def batch_to_space_op(self): 1080 return self._batch_to_space_op 1081