# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for folding batch norm layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import saver as saver_lib

batch_norm = layers.batch_norm
conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d


# TODO(suharshs): Use parameterized test once OSS TF supports it.
class FoldBatchNormsTest(test_util.TensorFlowTestCase):

  def _RunTestOverParameters(self, test_fn):
    parameters_list = [
        # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm,
        #  freeze_batch_norm_delay, insert identity node)
        (nn_ops.relu6, 'Relu6', False, False, False, 100, False),
        (nn_ops.relu, 'Relu', False, False, False, None, False),
        (nn_ops.relu6, 'Relu6', True, False, False, 100, False),
        (nn_ops.relu, 'Relu', True, False, False, None, False),
        (nn_ops.relu6, 'Relu6', False, True, False, 100, False),
        (nn_ops.relu, 'Relu', False, True, False, None, False),
        (nn_ops.relu6, 'Relu6', True, True, False, 100, False),
        (nn_ops.relu, 'Relu', True, True, False, None, False),
        # Fused batch norm always has scaling enabled.
        (nn_ops.relu6, 'Relu6', False, True, True, None, False),
        (nn_ops.relu, 'Relu', False, True, True, 100, False),
        (nn_ops.relu6, 'Relu6', True, True, True, None, False),
        (nn_ops.relu, 'Relu', True, True, True, 100, False),
        (nn_ops.relu6, 'Relu6', False, True, True, None, True),
        (nn_ops.relu, 'Relu', False, True, True, 100, True),
        (nn_ops.relu6, 'Relu6', True, True, True, None, True),
        (nn_ops.relu, 'Relu', True, True, True, 100, True),
    ]
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4], params[5],
              params[6])

  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                      fused_batch_norm, freeze_batch_norm_delay,
                      insert_identity_node):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = conv2d(
              inputs,
              out_depth, [5, 5],
              stride=stride,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          conv_out = array_ops.identity(node, name='conv_out')

          node = batch_norm(
              conv_out,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          conv_name = name + '/Conv'
      else:
        node = conv2d(
            inputs,
            out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        conv_name + '/correction_mult',
        self._BatchNormMultiplierName(conv_name, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [conv_name + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(conv_name + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    self._AssertInputOpsAre(folded_conv,
                            [conv_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [conv_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        conv_name + '/correction_add',
        self._BathNormBiasName(conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldConv2d(self):
    self._RunTestOverParameters(self._TestFoldConv2d)

  def testMultipleLayerConv2d(self,
                              relu=nn_ops.relu,
                              relu_op_name='Relu',
                              has_scaling=True,
                              fused_batch_norm=False,
                              freeze_batch_norm_delay=None,
                              insert_identity_node=False):
    """Tests folding cases for a network with multiple layers.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3
      stride = 1
      activation_fn = relu
      scope = 'topnet/testnet'
      with variable_scope.variable_scope(scope, [inputs]):
        layer1 = conv2d(
            inputs,
            out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            normalizer_fn=None,
            scope='testnet/layer1')
        # Add bn and relu with a different scope.
        layer1 = batch_norm(
            layer1, scale=has_scaling, fused=fused_batch_norm, scope='layer1')
        layer1 = activation_fn(layer1)
        layer2 = conv2d(
            layer1,
            2 * out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope='testnet/layer2')
        # Add bn and relu with a different scope.
        layer2 = batch_norm(
            layer2, scale=has_scaling, fused=fused_batch_norm, scope='layer2')
        _ = activation_fn(layer2)

      scope = 'topnet/testnet/testnet/layer2'

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        scope + '/correction_mult',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    # Strip the ':0' tensor suffix from the name prior to comparison.
    self._AssertInputOpsAre(folded_conv,
                            [scope + '/mul_fold', layer1.name[:-2]])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        scope + '/correction_add',
        self._BathNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = [scope + '/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, scope)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def _TestFoldConv2dUnknownShape(self,
                                  relu,
                                  relu_op_name,
                                  with_bypass,
                                  has_scaling,
                                  fused_batch_norm,
                                  freeze_batch_norm_delay,
                                  insert_identity_node=False):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Tests that folding works even with an input shape where some dimensions are
    not known (i.e. None).

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=(5, None, None, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        scope + '/correction_mult',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        scope + '/correction_add',
        self._BathNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, scope)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldConv2dUnknownShape(self):
    self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)

  def _TestFoldFullyConnectedLayer(
      self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm,
      freeze_batch_norm_delay, insert_identity_node):
    """Tests folding cases: inputs -> FC with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      # The identity-node path is only exercised when the batch norm is fused.
      insert_identity_node = fused_batch_norm
      if insert_identity_node:
        with g.name_scope(name):
          node = fully_connected(
              inputs,
              out_depth,
              weights_initializer=self._WeightInit(0.03),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='fc_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          fc_name = name + '/fully_connected'
      else:
        node = fully_connected(
            inputs,
            out_depth,
            weights_initializer=self._WeightInit(0.03),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        fc_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(fc_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        fc_name + '/correction_mult',
        self._BatchNormMultiplierName(fc_name, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [fc_name + '/MatMul_Fold'])

    folded_conv = g.get_operation_by_name(fc_name + '/MatMul_Fold')
    self.assertEqual(folded_conv.type, 'MatMul')
    self._AssertInputOpsAre(folded_conv,
                            [fc_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [fc_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(fc_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        fc_name + '/correction_add',
        self._BathNormBiasName(fc_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldFullyConnectedLayer(self):
    self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)

  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
                               has_scaling, fused_batch_norm,
                               freeze_batch_norm_delay, insert_identity_node):
    """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = separable_conv2d(
              inputs,
              None, [5, 5],
              stride=stride,
              depth_multiplier=1.0,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='sep_conv_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          sep_conv_name = name + '/SeparableConv2d'
      else:
        node = separable_conv2d(
            inputs,
            None, [5, 5],
            stride=stride,
            depth_multiplier=1.0,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        sep_conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(sep_conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    if fused_batch_norm:
      scale_reshape_op_name = sep_conv_name + '/BatchNorm_Fold/scale_reshape'
    else:
      scale_reshape_op_name = sep_conv_name + '/scale_reshape'
    self._AssertInputOpsAre(
        folded_mul, [sep_conv_name + '/correction_mult', scale_reshape_op_name])
    self._AssertOutputGoesToOps(folded_mul, g,
                                [sep_conv_name + '/depthwise_Fold'])

    scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape, [
        self._BatchNormMultiplierName(sep_conv_name, has_scaling,
                                      fused_batch_norm),
        scale_reshape_op_name + '/shape'
    ])
    self._AssertOutputGoesToOps(scale_reshape, g, [sep_conv_name + '/mul_fold'])

    folded_conv = g.get_operation_by_name(sep_conv_name + '/depthwise_Fold')
    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
    self._AssertInputOpsAre(folded_conv,
                            [sep_conv_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g,
                                [sep_conv_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(sep_conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        sep_conv_name + '/correction_add',
        self._BathNormBiasName(sep_conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldDepthwiseConv2d(self):
    self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)

  def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                            fused_batch_norm, freeze_batch_norm_delay,
                            insert_identity_node):
    """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      dilation_rate = 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = separable_conv2d(
              inputs,
              None, [3, 3],
              rate=dilation_rate,
              depth_multiplier=1.0,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='sep_conv_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          sep_conv_name = name + '/SeparableConv2d'
      else:
        node = separable_conv2d(
            inputs,
            None, [3, 3],
            rate=dilation_rate,
            depth_multiplier=1.0,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        sep_conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(sep_conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    if fused_batch_norm:
      scale_reshape_op_name = sep_conv_name + '/BatchNorm_Fold/scale_reshape'
    else:
      scale_reshape_op_name = sep_conv_name + '/scale_reshape'
    self._AssertInputOpsAre(
        folded_mul, [sep_conv_name + '/correction_mult', scale_reshape_op_name])
    self._AssertOutputGoesToOps(folded_mul, g,
                                [sep_conv_name + '/depthwise_Fold'])

    scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape, [
        self._BatchNormMultiplierName(sep_conv_name, has_scaling,
                                      fused_batch_norm),
        scale_reshape_op_name + '/shape'
    ])
    self._AssertOutputGoesToOps(scale_reshape, g, [sep_conv_name + '/mul_fold'])

    folded_conv = g.get_operation_by_name(sep_conv_name + '/depthwise_Fold')
    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
    self._AssertInputOpsAre(folded_conv, [
        sep_conv_name + '/mul_fold', sep_conv_name + '/depthwise/SpaceToBatchND'
    ])
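    # Atrous convolution is implemented as SpaceToBatchND -> depthwise conv ->
    # BatchToSpaceND, so the folded depthwise op reads from SpaceToBatchND and
    # its output feeds a folded BatchToSpaceND op.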
    if fused_batch_norm:
      self._AssertOutputGoesToOps(folded_conv, g,
                                  [sep_conv_name + '/BatchToSpaceND_Fold'])
    else:
      self._AssertOutputGoesToOps(
          folded_conv, g, [sep_conv_name + '/depthwise/BatchToSpaceND_Fold'])

    folded_add = g.get_operation_by_name(sep_conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        sep_conv_name + '/correction_add',
        self._BathNormBiasName(sep_conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldAtrousConv2d(self):
    self._RunTestOverParameters(self._TestFoldAtrousConv2d)

  def _TestCompareFoldAndUnfolded(self,
                                  relu,
                                  relu_op_name,
                                  with_bypass,
                                  has_scaling,
                                  fused_batch_norm,
                                  freeze_batch_norm_delay,
                                  insert_identity_node=False):
    """Tests that running folded and unfolded BN returns the same results.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance
      insert_identity_node: Bool, insert identity node between conv and batch
        norm
    """
    random_seed.set_random_seed(1234)
    unfolded_g = ops.Graph()
    with unfolded_g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = random_ops.random_uniform(
          (batch_size, height, width, 3), dtype=dtypes.float32, seed=1234)
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu_node = relu(node, name='test/' + relu_op_name)

    folded_g = self._CopyGraph(unfolded_g)
    with folded_g.as_default():
      fold_batch_norms.FoldBatchNorms(
          folded_g,
          is_training=True,
          freeze_batch_norm_delay=freeze_batch_norm_delay)

    with session.Session(graph=unfolded_g) as sess:
      sess.run(variables.global_variables_initializer())
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      unfolded_forward, unfolded_backward = results[0], results[1]

    with session.Session(graph=folded_g) as sess:
      sess.run(variables.global_variables_initializer())
      relu_node = folded_g.get_tensor_by_name(relu_node.name)
      inputs = folded_g.get_tensor_by_name(inputs.name)
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      folded_forward, folded_backward = results[0], results[1]

    # Check that the folded and unfolded results match.
    self.assertAllClose(unfolded_forward, folded_forward, atol=1e-3)
    self.assertAllClose(unfolded_backward, folded_backward, atol=1e-3)

  def testCompareFoldAndUnfolded(self):
    self._RunTestOverParameters(self._TestCompareFoldAndUnfolded)

  def _BatchNormParams(self, scale=True, fused=False):
    return {
        'center': True,
        'scale': scale,
        'decay': 1.0 - 0.003,
        'fused': fused
    }

  def _BatchNormMultiplierName(self, scope, has_scaling, fused):
    if has_scaling:
      if fused:
        return scope + '/BatchNorm_Fold/mul'
      return scope + '/BatchNorm/batchnorm_1/mul'
    return scope + '/BatchNorm/batchnorm_1/Rsqrt'

  def _BathNormBiasName(self, scope, fused):
    if fused:
      return scope + '/BatchNorm_Fold/bias'
    return scope + '/BatchNorm/batchnorm_1/sub'

  def _WeightInit(self, stddev):
    """Returns a truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, names of the operations that op's inputs
        should come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])

  def _AssertMovingAveragesAreFrozen(self, graph, scope):
    """Asserts that the moving mean and variance updates are frozen.

    Args:
      graph: Graph where the operations are located.
      scope: Scope of the batch norm op.
    """
    moving_average_mult = graph.get_operation_by_name(
        scope + '/BatchNorm/AssignMovingAvg/mul')
    self.assertTrue(
        moving_average_mult.inputs[1].name.find('freeze_moving_mean/Merge') > 0)
    moving_var_mult = graph.get_operation_by_name(
        scope + '/BatchNorm/AssignMovingAvg_1/mul')
    self.assertTrue(
        moving_var_mult.inputs[1].name.find('freeze_moving_var/Merge') > 0)

  def _CopyGraph(self, graph):
    """Returns a copy of graph."""
    meta_graph = saver_lib.export_meta_graph(
        graph=graph, collection_list=graph.get_all_collection_keys())
    graph_copy = ops.Graph()
    with graph_copy.as_default():
      _ = saver_lib.import_meta_graph(meta_graph)
    return graph_copy


if __name__ == '__main__':
  googletest.main()