# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameterized unit tests for quantizing a TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import googletest

batch_norm = layers.batch_norm
conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d


class QuantizeTest(test_util.TensorFlowTestCase):
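  """Tests that quantize.Quantize rewrites common layer patterns correctly."""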
  def _RunWithoutBatchNormTestOverParameters(self, test_fn):
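    """Runs test_fn over all parameter combinations without batch norm.

    Args:
      test_fn: Callable invoked with (activation, activation_op_name,
        with_bypass, delay, use_resource, scope) for each combination.
    """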
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay)
        (nn_ops.relu6, 'Relu6', False, None),
        (nn_ops.relu, 'Relu', False, None),
        (array_ops.identity, 'Identity', False, None),
        (nn_ops.relu6, 'Relu6', False, 5000),
        (nn_ops.relu, 'Relu', False, 5000),
        (array_ops.identity, 'Identity', False, 5000),
        (nn_ops.relu6, 'Relu6', True, None),
        (nn_ops.relu, 'Relu', True, None),
        (array_ops.identity, 'Identity', True, None),
        (nn_ops.relu6, 'Relu6', True, 5000),
        (nn_ops.relu, 'Relu', True, 5000),
        (array_ops.identity, 'Identity', True, 5000),
    ]
    for params in parameters_list:
      # Test everything with resource variables and normal variables.
      test_fn(params[0], params[1], params[2], params[3], False, None)
      test_fn(params[0], params[1], params[2], params[3], True, None)
      # Test with both empty scope and an example scope
      test_fn(params[0], params[1], params[2], params[3], False, 'test')
      test_fn(params[0], params[1], params[2], params[3], True, 'test')

  def _AssertCorrectQuantizedGraphWithoutBatchNorm(
      self, graph, scope, layer, activation_op_name, with_bypass, delay,
      use_resource):
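    """Asserts the quantized graph contains the expected FakeQuant ops.

    Checks the weights_quant, (optional) conv_quant and act_quant nodes, their
    inputs and their outputs, for a layer built without batch norm.
    """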
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    conv_scope = self._GetConvScope(scope, with_bypass)
    delim = '/' if conv_scope else ''

    if scope:
      scope = scope + '/'
    weights_quant = graph.get_operation_by_name(
        conv_scope + delim + 'weights_quant/' + quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)

    # Assemble the expected inputs.
    if use_resource:
      expected_inputs = [
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
      if layer == 'DepthwiseConv2dNative':
        expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
      else:
        expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
    else:
      expected_inputs = [
          conv_scope + delim + 'weights_quant/AssignMinLast',
          conv_scope + delim + 'weights_quant/AssignMaxLast',
      ]
      if layer == 'DepthwiseConv2dNative':
        expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
      else:
        expected_inputs.append(conv_scope + delim + 'weights/read')

    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if delay and delay > 0:
      output_op_name = (
          conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
    else:
      if layer == 'DepthwiseConv2dNative':
        output_op_name = conv_scope + delim + 'depthwise'
      else:
        output_op_name = conv_scope + delim + layer

    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(
          conv_scope + delim + 'conv_quant/' + quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      if use_resource:
        expected_inputs = [
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
            conv_scope + delim + 'BiasAdd',
        ]
      else:
        expected_inputs = [
            conv_scope + delim + 'conv_quant/AssignMinEma',
            conv_scope + delim + 'conv_quant/AssignMaxEma',
            conv_scope + delim + 'BiasAdd'
        ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)

      output_op_name = (
          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
          if delay else scope + 'Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    if use_resource:
      expected_inputs = [
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
          scope + activation_op_name,
      ]
    else:
      expected_inputs = [
          scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
          scope + activation_op_name
      ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = (
        scope + 'act_quant/delayed_quant/Switch_1'
        if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
    self._AssertIdempotent(graph)

  def testQuantize_Conv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_Conv2dWithoutBatchNorm)

  def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                           with_bypass, delay, use_resource,
                                           scope):
    """Tests quantization: inputs -> Conv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, True, quant_delay=delay)

    if conv_scope is None:
      conv_scope = ''

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_FCWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_FCWithoutBatchNorm)

  def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
                                       with_bypass, delay, use_resource, scope):
    """Tests quantization: inputs -> FC no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else activation
      fc_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=activation_fn,
          scope=fc_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'MatMul', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
      self, activation, activation_op_name, with_bypass, delay, use_resource,
      scope):
    """Tests quantization: inputs -> DWConv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
        delay, use_resource)

  def testQuantize_AtrousConvWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_AtrousConvWithoutBatchNorm)

  def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
                                               activation_op_name, with_bypass,
                                               delay, use_resource, scope):
    """Tests quantization: inputs -> atrous conv no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      dilation_rate = 2
      activation_fn = None if with_bypass else activation
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [3, 3],
          rate=dilation_rate,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          scope=conv_scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
        node = activation(node, name=scope + delim + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')
      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithoutBatchNorm(
        graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
        delay, use_resource)

  def _RunBatchNormTestOverParameters(self, test_fn):
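    """Runs test_fn over all parameter combinations with batch norm.

    Args:
      test_fn: Callable invoked with (activation, activation_op_name,
        with_bypass, delay, fused_batch_norm, use_resource, scope) for each
        combination.
    """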
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False, None, False),
        (nn_ops.relu, 'Relu', False, None, False),
        (array_ops.identity, 'Identity', False, None, False),
        (nn_ops.relu6, 'Relu6', False, 5000, False),
        (nn_ops.relu, 'Relu', False, 5000, False),
        (array_ops.identity, 'Identity', False, 5000, False),
        (nn_ops.relu6, 'Relu6', True, None, False),
        (nn_ops.relu, 'Relu', True, None, False),
        (array_ops.identity, 'Identity', True, None, False),
        (nn_ops.relu6, 'Relu6', True, 5000, False),
        (nn_ops.relu, 'Relu', True, 5000, False),
        (array_ops.identity, 'Identity', True, 5000, False),
        (nn_ops.relu6, 'Relu6', False, None, True),
        (nn_ops.relu, 'Relu', False, None, True),
        (array_ops.identity, 'Identity', False, None, True),
        (nn_ops.relu6, 'Relu6', False, 5000, True),
        (nn_ops.relu, 'Relu', False, 5000, True),
        (array_ops.identity, 'Identity', False, 5000, True),
        (nn_ops.relu6, 'Relu6', True, None, True),
        (nn_ops.relu, 'Relu', True, None, True),
        (array_ops.identity, 'Identity', True, None, True),
        (nn_ops.relu6, 'Relu6', True, 5000, True),
        (nn_ops.relu, 'Relu', True, 5000, True),
        (array_ops.identity, 'Identity', True, 5000, True)
    ]
    for params in parameters_list:
      # Test everything with resource variables and normal variables.
      test_fn(params[0], params[1], params[2], params[3], params[4], False,
              None)
      test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
      test_fn(params[0], params[1], params[2], params[3], params[4], False,
              'test')
      test_fn(params[0], params[1], params[2], params[3], params[4], True,
              'test')

  def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
                                                activation_op_name, with_bypass,
                                                delay, use_resource):
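    """Asserts the quantized, batch-norm-folded graph has the expected ops.

    Checks the weights_quant, (optional) conv_quant and act_quant nodes and
    their connections for a layer whose batch norm has been folded.
    """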
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    conv_scope = self._GetConvScope(scope, with_bypass)
    delim = '/' if conv_scope else ''

    if scope:
      scope = scope + '/'

    weights_quant = graph.get_operation_by_name(
        conv_scope + delim + 'weights_quant/' + quantization_node_name)

    self.assertEqual(weights_quant.type, quantization_node_name)
    if use_resource:
      expected_inputs = [
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          conv_scope + delim +
          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
    else:
      expected_inputs = [
          conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
          conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
      ]
    expected_inputs.append(conv_scope + delim + 'mul_fold')

    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if layer == 'DepthwiseConv2dNative':
      output_op_name = conv_scope + delim + (
          'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
    else:
      output_op_name = conv_scope + delim + (
          'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(
          conv_scope + delim + 'conv_quant/' + quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)

      if use_resource:
        expected_inputs = [
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
            conv_scope + delim +
            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
        ]
      else:
        expected_inputs = [
            conv_scope + delim + 'conv_quant/AssignMinEma',
            conv_scope + delim + 'conv_quant/AssignMaxEma',
        ]
      expected_inputs.append(conv_scope + delim + 'add_fold')

      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (
          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
          if delay else scope + 'Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)

    if use_resource:
      expected_inputs = [
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
    else:
      expected_inputs = [
          scope + 'act_quant/AssignMinEma',
          scope + 'act_quant/AssignMaxEma',
      ]
    expected_inputs.append(scope + activation_op_name)

    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = (
        scope + 'act_quant/delayed_quant/Switch_1'
        if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
    self._AssertIdempotent(graph)

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, fused_batch_norm,
                                        use_resource, scope):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

      self._AssertCorrectQuantizedGraphWithBatchNorm(
          graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
          use_resource)

  def testQuantize_FCWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, fused_batch_norm,
                                    use_resource, scope):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)

      quantize.Quantize(graph, True, quant_delay=delay)

    self._AssertCorrectQuantizedGraphWithBatchNorm(
        graph, scope, 'MatMul', activation_op_name, with_bypass, delay,
        use_resource)

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_resource, scope):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

      self._AssertCorrectQuantizedGraphWithBatchNorm(
          graph, scope, 'DepthwiseConv2dNative', activation_op_name,
          with_bypass, delay, use_resource)

  def testQuantize_AtrousConvWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_AtrousConvWithBatchNorm)

  def _TestQuantize_AtrousConvWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_resource, scope):
    """Tests quantization: inputs -> atrous conv with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_resource: Bool, when true uses resource variables.
      scope: String, specifies top level scope for the graph
    """
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      dilation_rate = 2
      conv_scope = self._GetConvScope(scope, with_bypass)
      scope = '' if scope is None else scope
      delim = '/' if scope else ''

      node = separable_conv2d(
          inputs,
          None, [3, 3],
          rate=dilation_rate,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=conv_scope)

      # Manually add a bypass (optional) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name=scope + delim + 'Add')

      node = activation(node, name=scope + delim + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True, quant_delay=delay)

      self._AssertCorrectQuantizedGraphWithBatchNorm(
          graph, scope, 'DepthwiseConv2dNative', activation_op_name,
          with_bypass, delay, use_resource)

  def _AssertIdempotent(self, graph):
    # Ensure that calling the rewrite again doesn't change the graph.
    graph_def_before = str(graph.as_graph_def())
    with graph.as_default():
      # Ensuring that calling the rewrite again doesn't add more nodes.
      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, True)
    graph_def_after = str(graph.as_graph_def())
    self.assertEqual(graph_def_before, graph_def_after)

  def testBatchNormForcedUpdates(self):
    parameter_list = [
        # (activation, activation_op_name, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False),
        (nn_ops.relu, 'Relu', False),
        (array_ops.identity, 'Identity', False),
        (nn_ops.relu6, 'Relu6', True),
        (nn_ops.relu, 'Relu', True),
        (array_ops.identity, 'Identity', True),
    ]
    for params in parameter_list:
      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False)
      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True)

  def _TestBatchNormForcedUpdates(self, activation, activation_op_name,
                                  fused_batch_norm, use_resource):
    """post_activation bypass quantization should happen with forced updates."""
    graph = ops.Graph()
    with graph.as_default():
      variable_scope.get_variable_scope().set_use_resource(use_resource)
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
      # Setting updates_collections to None forces updates, adding an extra
      # identity operation following batch norms.
      bn_params = self._BatchNormParams(
          fused=fused_batch_norm, force_updates=True)
      conv = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation,
          normalizer_fn=batch_norm,
          normalizer_params=bn_params,
          scope='test/test')
      bypass_tensor = math_ops.add(conv, input2, name='test/add')
      # The output of the post_activation bypass will be another layer.
      _ = conv2d(
          bypass_tensor,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          normalizer_fn=batch_norm,
          normalizer_params=bn_params,
          activation_fn=activation,
          scope='test/unused')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)
      quantize.Quantize(graph, is_training=True)

      # Ensure that the bypass node is preceded by and followed by a
      # FakeQuantWithMinMaxVars operation, since the output of the Add isn't an
      # activation.
      self.assertTrue('FakeQuantWithMinMaxVars' in
                      [c.type for c in bypass_tensor.consumers()])
      self.assertTrue('FakeQuantWithMinMaxVars' in
                      [i.op.type for i in bypass_tensor.op.inputs])

    with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
      f.write(str(graph.as_graph_def()))

  def _GetConvScope(self, scope, with_bypass):
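    """Returns the conv scope; appends 'test2' when with_bypass is True."""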
    if scope is None:
      scope = ''
    delim = '/' if scope else ''

    if with_bypass:
      conv_scope = scope + delim + 'test2'
    else:
      conv_scope = scope

    return conv_scope

  def _BatchNormParams(self, fused=False, force_updates=False):
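    """Returns normalizer_params for batch_norm with the given options."""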
    params = {
        'center': True,
        'scale': True,
        'decay': 1.0 - 0.003,
        'fused': fused
    }
    if force_updates:
      params['updates_collections'] = None
    return params

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all op's inputs should
        come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])


if __name__ == '__main__':
  googletest.main()