# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameterized unit tests for quantizing a TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import googletest

batch_norm = layers.batch_norm
conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d


class QuantizeTest(test_util.TensorFlowTestCase):

  def _RunWithoutBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay)
        (nn_ops.relu6, 'Relu6', False, None),
        (nn_ops.relu, 'Relu', False, None),
        (array_ops.identity, 'Identity', False, None),
        (nn_ops.relu6, 'Relu6', False, 5000),
        (nn_ops.relu, 'Relu', False, 5000),
        (array_ops.identity, 'Identity', False, 5000),
        (nn_ops.relu6, 'Relu6', True, None),
        (nn_ops.relu, 'Relu', True, None),
        (array_ops.identity, 'Identity', True, None),
        (nn_ops.relu6, 'Relu6', True, 5000),
        (nn_ops.relu, 'Relu', True, 5000),
        (array_ops.identity, 'Identity', True, 5000),
    ]
    for params in parameters_list:
      # Test everything with resource variables and normal variables.
      test_fn(params[0], params[1], params[2], params[3], False, None)
      test_fn(params[0], params[1], params[2], params[3], True, None)
      # Test with both empty scope and an example scope.
      test_fn(params[0], params[1], params[2], params[3], False, 'test')
      test_fn(params[0], params[1], params[2], params[3], True, 'test')

  def _AssertCorrectQuantizedGraphWithoutBatchNorm(
      self, graph, scope, layer, activation_op_name, with_bypass, delay,
      use_resource):
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    conv_scope = self._GetConvScope(scope, with_bypass)
    delim = '/' if conv_scope else ''

    if scope:
      scope = scope + '/'
    weights_quant = graph.get_operation_by_name(
        conv_scope + delim + 'weights_quant/' + quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)

    # Assemble the expected inputs.
    if use_resource:
      expected_inputs = [
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
      if layer == 'DepthwiseConv2dNative':
        expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
      else:
        expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
    else:
      expected_inputs = [
          conv_scope + delim + 'weights_quant/AssignMinLast',
          conv_scope + delim + 'weights_quant/AssignMaxLast',
      ]
      if layer == 'DepthwiseConv2dNative':
        expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
      else:
        expected_inputs.append(conv_scope + delim + 'weights/read')

    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if delay and delay > 0:
      output_op_name = (
          conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
    else:
      if layer == 'DepthwiseConv2dNative':
        output_op_name = conv_scope + delim + 'depthwise'
      else:
        output_op_name = conv_scope + delim + layer

    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(
          conv_scope + delim + 'conv_quant/' + quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      if use_resource:
        expected_inputs = [
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
            conv_scope + delim + 'BiasAdd',
        ]
      else:
        expected_inputs = [
            conv_scope + delim + 'conv_quant/AssignMinEma',
            conv_scope + delim + 'conv_quant/AssignMaxEma',
            conv_scope + delim + 'BiasAdd'
        ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)

      output_op_name = (
          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
          if delay else scope + 'Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    if use_resource:
      expected_inputs = [
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
          scope + activation_op_name,
      ]
    else:
      expected_inputs = [
          scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
          scope + activation_op_name
      ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = (
        scope + 'act_quant/delayed_quant/Switch_1'
        if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
    self._AssertIdempotent(graph)

  def testQuantize_Conv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_Conv2dWithoutBatchNorm)

  def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                           with_bypass, delay, use_resource,
                                           scope):
    """Tests quantization: inputs -> Conv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, True, quant_delay=delay)

    if conv_scope is None:
      conv_scope = ''

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_FCWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_FCWithoutBatchNorm)

  def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
                                       with_bypass, delay, use_resource,
                                       scope):
    """Tests quantization: inputs -> FC no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else activation
      fc_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=activation_fn,
          scope=fc_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'MatMul', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
      self, activation, activation_op_name, with_bypass, delay, use_resource,
      scope):
    """Tests quantization: inputs -> DWConv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
        delay, use_resource)

  def testQuantize_AtrousConvWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_AtrousConvWithoutBatchNorm)

  def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
                                               activation_op_name, with_bypass,
                                               delay, use_resource, scope):
    """Tests quantization: inputs -> atrous conv no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      dilation_rate = 2
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [3, 3],
          rate=dilation_rate,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
        delay, use_resource)

  def _RunBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False, None, False),
        (nn_ops.relu, 'Relu', False, None, False),
        (array_ops.identity, 'Identity', False, None, False),
        (nn_ops.relu6, 'Relu6', False, 5000, False),
        (nn_ops.relu, 'Relu', False, 5000, False),
        (array_ops.identity, 'Identity', False, 5000, False),
        (nn_ops.relu6, 'Relu6', True, None, False),
        (nn_ops.relu, 'Relu', True, None, False),
        (array_ops.identity, 'Identity', True, None, False),
        (nn_ops.relu6, 'Relu6', True, 5000, False),
        (nn_ops.relu, 'Relu', True, 5000, False),
        (array_ops.identity, 'Identity', True, 5000, False),
        (nn_ops.relu6, 'Relu6', False, None, True),
        (nn_ops.relu, 'Relu', False, None, True),
        (array_ops.identity, 'Identity', False, None, True),
        (nn_ops.relu6, 'Relu6', False, 5000, True),
        (nn_ops.relu, 'Relu', False, 5000, True),
        (array_ops.identity, 'Identity', False, 5000, True),
        (nn_ops.relu6, 'Relu6', True, None, True),
        (nn_ops.relu, 'Relu', True, None, True),
        (array_ops.identity, 'Identity', True, None, True),
        (nn_ops.relu6, 'Relu6', True, 5000, True),
        (nn_ops.relu, 'Relu', True, 5000, True),
        (array_ops.identity, 'Identity', True, 5000, True)
    ]
    for params in parameters_list:
      # Test everything with resource variables and normal variables.
      test_fn(params[0], params[1], params[2], params[3], params[4], False,
              None)
      test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
      test_fn(params[0], params[1], params[2], params[3], params[4], False,
              'test')
      test_fn(params[0], params[1], params[2], params[3], params[4], True,
              'test')

  def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
                                                activation_op_name, with_bypass,
                                                delay, use_resource):
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    conv_scope = self._GetConvScope(scope, with_bypass)
    delim = '/' if conv_scope else ''

    if scope:
      scope = scope + '/'

    weights_quant = graph.get_operation_by_name(
        conv_scope + delim + 'weights_quant/' + quantization_node_name)

    self.assertEqual(weights_quant.type, quantization_node_name)
    if use_resource:
      expected_inputs = [
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
    else:
      expected_inputs = [
          conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
          conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
      ]
    expected_inputs.append(conv_scope + delim + 'mul_fold')

    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if layer == 'DepthwiseConv2dNative':
      output_op_name = conv_scope + delim + (
          'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
    else:
      output_op_name = conv_scope + delim + (
          'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(
          conv_scope + delim + 'conv_quant/' + quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)

      if use_resource:
        expected_inputs = [
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
        ]
      else:
        expected_inputs = [
            conv_scope + delim + 'conv_quant/AssignMinEma',
            conv_scope + delim + 'conv_quant/AssignMaxEma',
        ]
      expected_inputs.append(conv_scope + delim + 'add_fold')

      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (
          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
          if delay else scope + 'Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)

    if use_resource:
      expected_inputs = [
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
    else:
      expected_inputs = [
          scope + 'act_quant/AssignMinEma',
          scope + 'act_quant/AssignMaxEma',
      ]
    expected_inputs.append(scope + activation_op_name)

    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = (
        scope + 'act_quant/delayed_quant/Switch_1'
        if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
    self._AssertIdempotent(graph)

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, fused_batch_norm,
                                        use_resource, scope):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithBatchNorm(
        graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_FCWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, fused_batch_norm,
                                    use_resource, scope):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)

      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithBatchNorm(
        graph, scope, 'MatMul', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_resource, scope):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name,
        with_bypass, delay, use_resource)

  def testQuantize_AtrousConvWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_AtrousConvWithBatchNorm)

  def _TestQuantize_AtrousConvWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_resource, scope):
    """Tests quantization: inputs -> atrous conv with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      dilation_rate = 2
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [3, 3],
          rate=dilation_rate,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name,
        with_bypass, delay, use_resource)

  def _AssertIdempotent(self, graph):
    # Ensure that calling the rewrite again doesn't change the graph.
    graph_def_before = str(graph.as_graph_def())
    with graph.as_default():
      # Ensuring that calling the rewrite again doesn't add more nodes.
      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True)
    graph_def_after = str(graph.as_graph_def())
    self.assertEqual(graph_def_before, graph_def_after)

  def testBatchNormForcedUpdates(self):
    parameter_list = [
        # (activation, activation_op_name, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False),
        (nn_ops.relu, 'Relu', False),
        (array_ops.identity, 'Identity', False),
        (nn_ops.relu6, 'Relu6', True),
        (nn_ops.relu, 'Relu', True),
        (array_ops.identity, 'Identity', True),
    ]
    for params in parameter_list:
      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False)
      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True)

  def _TestBatchNormForcedUpdates(self, activation, activation_op_name,
                                  fused_batch_norm, use_resource):
    """post_activation bypass quantization should happen with forced updates."""
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
      # Setting updates_collections to None forces updates, adding an extra
      # identity operation following batch norms.
      bn_params = self._BatchNormParams(
          fused=fused_batch_norm, force_updates=True)
      conv = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation,
          normalizer_fn=batch_norm,
          normalizer_params=bn_params,
          scope='test/test')
      bypass_tensor = math_ops.add(conv, input2, name='test/add')
      # The output of the post_activation bypass will be another layer.
      _ = conv2d(
          bypass_tensor,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          normalizer_fn=batch_norm,
          normalizer_params=bn_params,
          activation_fn=activation,
          scope='test/unused')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, is_training=True)

      # Ensure that the bypass node is preceded by and followed by a
      # FakeQuantWithMinMaxVars operation, since the output of the Add isn't an
      # activation.
      self.assertTrue('FakeQuantWithMinMaxVars' in
                      [c.type for c in bypass_tensor.consumers()])
      self.assertTrue('FakeQuantWithMinMaxVars' in
                      [i.op.type for i in bypass_tensor.op.inputs])

      with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
        f.write(str(graph.as_graph_def()))

  def _GetConvScope(self, scope, with_bypass):
    if scope is None:
      scope = ''
    delim = '/' if scope else ''

    if with_bypass:
      conv_scope = scope + delim + 'test2'
    else:
      conv_scope = scope

    return conv_scope

  def _BatchNormParams(self, fused=False, force_updates=False):
    params = {
        'center': True,
        'scale': True,
        'decay': 1.0 - 0.003,
        'fused': fused
    }
    if force_updates:
      params['updates_collections'] = None
    return params

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all op's inputs should
        come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])


if __name__ == '__main__':
  googletest.main()