# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the quantize_graph graph rewriting API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
from tensorflow.python import training
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template
from tensorflow.python.platform import googletest


class QuantizeGraphTest(test_util.TensorFlowTestCase):
  # We have a lot of other tests that cover the details of the rewrite; here we
  # just test the specific features of the quantize_graph API.

  def _RunTestOverAllRewrites(self, test_fn):
    rewrite_fns = [
        quantize_graph.create_training_graph,
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverTrainingRewrites(self, test_fn):
    rewrite_fns = [
        quantize_graph.create_training_graph,
        quantize_graph.experimental_create_training_graph,
        functools.partial(
            quantize_graph.experimental_create_training_graph, symmetric=True),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverEvalRewrites(self, test_fn):
    rewrite_fns = [
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_eval_graph,
        functools.partial(
            quantize_graph.experimental_create_eval_graph, symmetric=True),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverExperimentalRewrites(self, test_fn):
    rewrite_fns = [
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverExperimentalRewritesWithScope(self, test_fn, scope):
    def with_absent_scope(fn):
      def fn_with_absent_scope(*args):
        fn(*args, scope=scope)
      return fn_with_absent_scope
    rewrite_fns = [
        with_absent_scope(
            quantize_graph.experimental_create_training_graph),
        with_absent_scope(
            quantize_graph.experimental_create_eval_graph),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def testRewrite(self):
    self._RunTestOverAllRewrites(self._TestRewrite)

  def _TestRewrite(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    orig_variable_names = set(
        [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])

    rewrite_fn(graph)

    q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # Ensure that variables were added.
    self.assertTrue(len(orig_variable_names) < len(q_variables))

  def testDefaultGraph(self):
    self._RunTestOverAllRewrites(self._TestDefaultGraph)

  def _TestDefaultGraph(self, rewrite_fn):
    # Tests that the default graph is correctly used when no args are provided
    # to rewrite_fn.
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      orig_variable_names = set(
          [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
      rewrite_fn()

      q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      # Ensure that variables were added.
      self.assertTrue(len(orig_variable_names) < len(q_variables))

  def testWithPostActivationBypass(self):
    self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)

  def _TestWithPostActivationBypass(self, rewrite_fn):
    # Tests that the post-activation bypass add is quantized: the rewrite
    # should insert ops under a post_activation_bypass_quant scope.
    with ops.Graph().as_default() as g:
      self._ConvLayer(post_activation_bypass=True, scope='scope1')
      rewrite_fn()

      op_names = [op.name for op in g.get_operations()]
      self.assertTrue(any(
          'scope1/post_activation_bypass_quant/' in name for name in op_names))

  def testQuantDelay(self):
    self._RunTestOverTrainingRewrites(self._TestQuantDelay)

  def _TestQuantDelay(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      quant_delay = 100
      rewrite_fn(quant_delay=quant_delay)

      quant_delay_found = False
      for op in g.get_operations():
        # Check to see if the quant_delay is correctly set.
        if 'activate_quant' in op.name and op.type == 'Const':
          quant_delay_found = True
          const_value = str(op.get_attr('value'))
          self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
      self.assertTrue(quant_delay_found)

  def testTrainingOpsCheck(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck)

  def _TestTrainingOpsCheck(self, rewrite_fn):
    with ops.Graph().as_default():
      output = self._ConvLayer()
      output_scalar = math_ops.reduce_sum(output)
      loss = math_ops.square(output_scalar - 1)
      opt = training.gradient_descent.GradientDescentOptimizer(0.0001)
      opt.minimize(loss)
      with self.assertRaisesRegexp(ValueError, 'Training op found in graph'):
        rewrite_fn()

  def testWeightBits(self):
    self._RunTestOverExperimentalRewrites(self._TestWeightBits)

  def _TestWeightBits(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      weight_bits = 4
      rewrite_fn(weight_bits=weight_bits)

      weights_quant_found = False
      for op in g.get_operations():
        # Check to see if FakeQuant operations for weights have the right bits
        # set.
        if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
          weights_quant_found = True
          self.assertEqual(op.get_attr('num_bits'), weight_bits)
      self.assertTrue(weights_quant_found)

  def testActivationBits(self):
    self._RunTestOverExperimentalRewrites(self._TestActivationBits)

  def _TestActivationBits(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      activation_bits = 4
      rewrite_fn(activation_bits=activation_bits)

      act_quant_found = False
      for op in g.get_operations():
        # Check to see if FakeQuant operations for activations have the right
        # bits set.
        act_quant_names = ['act_quant', 'conv_quant', 'add_quant']
        if any(s in op.name
               for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars':
          act_quant_found = True
          self.assertEqual(op.get_attr('num_bits'), activation_bits)
      self.assertTrue(act_quant_found)

  def testTrainingQuantization(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingQuantization)

  def _TestTrainingQuantization(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

      # Ensure that FakeQuant and variable update nodes were found.
      quant_found = False
      assign_min_last_found = False
      assign_min_ema_found = False
      assign_max_last_found = False
      assign_max_ema_found = False
      for op in g.get_operations():
        # Check that FakeQuant operations were added.
        if op.type == 'FakeQuantWithMinMaxVars':
          quant_found = True
        # Check that update operations for the added min max variables exist in
        # the graph.
        if 'AssignMinLast' in op.name:
          assign_min_last_found = True
        elif 'AssignMinEma' in op.name:
          assign_min_ema_found = True
        elif 'AssignMaxLast' in op.name:
          assign_max_last_found = True
        elif 'AssignMaxEma' in op.name:
          assign_max_ema_found = True
      self.assertTrue(assign_min_last_found)
      self.assertTrue(assign_min_ema_found)
      self.assertTrue(assign_max_last_found)
      self.assertTrue(assign_max_ema_found)
      self.assertTrue(quant_found)

  def testEvalQuantization(self):
    self._RunTestOverEvalRewrites(self._TestEvalQuantization)

  def _TestEvalQuantization(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

      # Ensure that FakeQuant nodes were added, but no variable update nodes.
      quant_found = False
      for op in g.get_operations():
        # Check that FakeQuant operations were added.
        if op.type == 'FakeQuantWithMinMaxVars':
          quant_found = True
        # Check that update operations for the added min max variables don't
        # exist in the graph.
        update_names = [
            'AssignMinLast', 'AssignMinEma', 'AssignMaxLast', 'AssignMaxEma'
        ]
        self.assertFalse(any(s in op.name for s in update_names))
      self.assertTrue(quant_found)

  def testIdempotent(self):
    self._RunTestOverAllRewrites(self._TestIdempotent)

  def _TestIdempotent(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()
      graph_def_before = str(g.as_graph_def())
      # Ensure that calling the rewrite again doesn't add more nodes.
      rewrite_fn()
      graph_def_after = str(g.as_graph_def())
      self.assertEqual(graph_def_before, graph_def_after)

  def testIdentityNode(self):
    self._RunTestOverAllRewrites(self._TestIdentityNode)

  def _TestIdentityNode(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._LayerWithIdentity()

    rewrite_fn(graph)
    op_names = [op.name for op in graph.get_operations()]
    self.assertTrue(any('test/Conv/weights_quant' in name for name in op_names))
    self.assertTrue(any('test/Conv/act_quant' in name for name in op_names))
    bn_out_identity = graph.get_operation_by_name('test/bn_out')
    self._AssertInputOpsAre(bn_out_identity, [
        'test/Conv/add_fold',
    ])

    conv_out_identity = graph.get_operation_by_name('test/conv_out')
    self._AssertOutputGoesToOps(conv_out_identity, graph,
                                ['test/BatchNorm/FusedBatchNorm'])

  def testActivationQuantization(self):
    self._RunTestOverAllRewrites(self._TestActivationQuantization)

  def _TestActivationQuantization(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      _ = self._LayerWithActivationProcessing()

    rewrite_fn(graph)
    # Check that outputs of the multiplies and the add are quantized.

    mul_op = graph.get_operation_by_name('test/Mul')
    self._AssertOutputGoesToOps(
        mul_op, graph,
        ['test/Mul/activation_Mul_quant/FakeQuantWithMinMaxVars'])
    mul_op = graph.get_operation_by_name('test/Mul_1')
    self._AssertOutputGoesToOps(
        mul_op, graph,
        ['test/Mul_1/activation_Mul_quant/FakeQuantWithMinMaxVars'])
    add_op = graph.get_operation_by_name('test/add')
    self._AssertOutputGoesToOps(
        add_op, graph,
        ['test/add/activation_Add_quant/FakeQuantWithMinMaxVars'])

  def testRewriteWithScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestRewriteWithScope, 'scope1')

  def _TestRewriteWithScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      scope1_output = self._ConvLayer(scope='scope1')
      self._ConvLayer(input_tensor=scope1_output, scope='scope2')

    rewrite_fn(graph)

    op_names = [op.name for op in graph.get_operations()]
    # The weights and activations of scope1 are quantized, but not scope2.
    self.assertTrue(
        any('scope1/Conv/act_quant' in name for name in op_names))
    self.assertTrue(
        any('scope1/Conv/weights_quant' in name for name in op_names))
    self.assertFalse(
        any('scope2/Conv/act_quant' in name for name in op_names))
    self.assertFalse(
        any('scope2/Conv/weights_quant' in name for name in op_names))

  def testRewriteWithNonMatchingScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestRewriteWithNonMatchingScope, 'NonExistingScope')

  def _TestRewriteWithNonMatchingScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    op_names_before_rewrite = set([op.name for op in graph.get_operations()])
    rewrite_fn(graph)
    op_names_after_rewrite = set([op.name for op in graph.get_operations()])

    # No ops should be inserted or removed.
    self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)

  def testActivationRewriteWithScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestActivationRewriteWithScope, 'scope1')

  def _TestActivationRewriteWithScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      output = self._LayerWithIdentity(scope='scope1')
      with ops.name_scope('scope2'):
        output = nn_ops.relu6(output)
        scaled_output1 = math_ops.mul(2.0, output)
        scaled_output2 = math_ops.mul(3.0, output)
        output = scaled_output1 + scaled_output2
    rewrite_fn(graph)

    op_names = [op.name for op in graph.get_operations()]
    # The weights and activations of scope1 are quantized, but not scope2.
    self.assertTrue(any('scope1/Conv/act_quant' in name for name in op_names))
    self.assertTrue(
        any('scope1/Conv/weights_quant' in name for name in op_names))

    for op_name in op_names:
      if op_name.startswith('scope2'):
        self.assertTrue('FakeQuant' not in op_name)

  def testActivationRewriteWithNonMatchingScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestActivationRewriteWithNonMatchingScope, 'NonExistingScope')

  def _TestActivationRewriteWithNonMatchingScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._LayerWithActivationProcessing()

    rewrite_fn(graph)
    # No fake quant ops should be inserted.
    op_types_after_rewrite = set([op.type for op in graph.get_operations()])
    self.assertNotIn('FakeQuantWithMinMaxVars', op_types_after_rewrite)

  def testWithSharedWeights(self):
    self._RunTestOverAllRewrites(self._TestWithSharedWeights)
    self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)

  def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
    self._TestWithSharedWeights(rewrite_fn, quant_delay)

  def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
    with ops.Graph().as_default() as g:
      conv = template.make_template('shared_weights_conv', self._ConvLayer)
      conv()
      conv()
      if quant_delay is None:
        rewrite_fn()
      else:
        rewrite_fn(quant_delay=quant_delay)

    conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
    weights_quants = [
        op for op in g.get_operations()
        if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
    ]
    # Check that the shared weights variable is not quantized multiple times.
    self.assertTrue(len(weights_quants) == 1)
    weights_quant_tensor = weights_quants[0].outputs[0]
    if quant_delay:
      delayed_weights_quants = [
          op for op in g.get_operations()
          if 'weights_quant' in op.name and op.type == 'Merge'
      ]
      self.assertTrue(len(delayed_weights_quants) == 1)
      weights_quant_tensor = delayed_weights_quants[0].outputs[0]
    # Check that the Conv2D operations get the quantized weights.
    self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))

  def _ConvLayer(
      self, input_tensor=None, scope='test', pre_activation_bypass=False,
      post_activation_bypass=False):
    """Add a basic convolution layer to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None)
      if pre_activation_bypass:
        output += input_tensor
      output = nn_ops.relu6(output)
      if post_activation_bypass:
        output += input_tensor
      return output

  def _LayerWithIdentity(self,
                         input_tensor=None,
                         scope='test',
                         post_activation_bypass=False):
    """Add a basic conv, identity, batch norm with skip to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None,
          normalizer_fn=None,
          biases_initializer=None)
      output = array_ops.identity(output, name='conv_out')

      output = layers.batch_norm(
          output, center=True, scale=True, decay=1.0 - 0.003, fused=True)

      output = array_ops.identity(output, name='bn_out')
      if post_activation_bypass:
        output += input_tensor
      return output

  def _LayerWithActivationProcessing(self,
                                     input_tensor=None,
                                     scope='test',
                                     post_activation_bypass=False):
    """Add a conv, batch norm, relu6 and post-activation math to the graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None,
          normalizer_fn=None,
          biases_initializer=None)

      output = layers.batch_norm(
          output, center=True, scale=True, decay=1.0 - 0.003, fused=True)

      output = nn_ops.relu6(output)
      scaled_output1 = math_ops.mul(2.0, output)
      scaled_output2 = math_ops.mul(3.0, output)
      output = scaled_output1 + scaled_output2
      return output

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all of op's inputs should
        come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])


if __name__ == '__main__':
  googletest.main()