# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for folding batch norm layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import saver as saver_lib

batch_norm = layers.batch_norm
conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d


# TODO(suharshs): Use parameterized test once OSS TF supports it.
class FoldBatchNormsTest(test_util.TensorFlowTestCase):

  def _RunTestOverParameters(self, test_fn):
    parameters_list = [
        # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm,
        #  freeze_batch_norm_delay, insert_identity_node)
        (nn_ops.relu6, 'Relu6', False, False, False, 100, False),
        (nn_ops.relu, 'Relu', False, False, False, None, False),
        (nn_ops.relu6, 'Relu6', True, False, False, 100, False),
        (nn_ops.relu, 'Relu', True, False, False, None, False),
        (nn_ops.relu6, 'Relu6', False, True, False, 100, False),
        (nn_ops.relu, 'Relu', False, True, False, None, False),
        (nn_ops.relu6, 'Relu6', True, True, False, 100, False),
        (nn_ops.relu, 'Relu', True, True, False, None, False),
        # Fused batch norm always has scaling enabled.
        (nn_ops.relu6, 'Relu6', False, True, True, None, False),
        (nn_ops.relu, 'Relu', False, True, True, 100, False),
        (nn_ops.relu6, 'Relu6', True, True, True, None, False),
        (nn_ops.relu, 'Relu', True, True, True, 100, False),
        (nn_ops.relu6, 'Relu6', False, True, True, None, True),
        (nn_ops.relu, 'Relu', False, True, True, 100, True),
        (nn_ops.relu6, 'Relu6', True, True, True, None, True),
        (nn_ops.relu, 'Relu', True, True, True, 100, True),
    ]
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4], params[5],
              params[6])

  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                      fused_batch_norm, freeze_batch_norm_delay,
                      insert_identity_node):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = conv2d(
              inputs,
              out_depth, [5, 5],
              stride=stride,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          conv_out = array_ops.identity(node, name='conv_out')

          node = batch_norm(
              conv_out,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          conv_name = name + '/Conv'
      else:
        node = conv2d(
            inputs,
            out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        relu(node, name='test/' + relu_op_name)

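      # Fold the batch norms into the preceding conv. The rewrite is expected
      # to produce <conv_name>/mul_fold, <conv_name>/Conv2D_Fold and
      # <conv_name>/add_fold ops, which the assertions below verify.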
      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        conv_name + '/correction_mult',
        self._BatchNormMultiplierName(conv_name, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [conv_name + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(conv_name + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    self._AssertInputOpsAre(folded_conv,
                            [conv_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [conv_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        conv_name + '/correction_add',
        self._BatchNormBiasName(conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldConv2d(self):
    self._RunTestOverParameters(self._TestFoldConv2d)

  def testMultipleLayerConv2d(self,
                              relu=nn_ops.relu,
                              relu_op_name='Relu',
                              has_scaling=True,
                              fused_batch_norm=False,
                              freeze_batch_norm_delay=None,
                              insert_identity_node=False):
    """Tests folding cases for a network with multiple layers.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3
      stride = 1
      activation_fn = relu
      scope = 'topnet/testnet'
      with variable_scope.variable_scope(scope, [inputs]):
        layer1 = conv2d(
            inputs,
            out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            normalizer_fn=None,
            scope='testnet/layer1')
        # Add batch norm and relu with a different scope.
        layer1 = batch_norm(
            layer1, scale=has_scaling, fused=fused_batch_norm, scope='layer1')
        layer1 = activation_fn(layer1)
        layer2 = conv2d(
            layer1,
            2 * out_depth, [5, 5],
            stride=stride,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope='testnet/layer2')
        # Add batch norm and relu with a different scope.
        layer2 = batch_norm(
            layer2, scale=has_scaling, fused=fused_batch_norm, scope='layer2')
        _ = activation_fn(layer2)

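      # Verify folding on the second conv layer; its full name scope combines
      # the variable scope 'topnet/testnet' with the conv scope
      # 'testnet/layer2'.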
      scope = 'topnet/testnet/testnet/layer2'

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        scope + '/correction_mult',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    # Strip the trailing ':0' from the tensor name before comparing.
    self._AssertInputOpsAre(folded_conv,
                            [scope + '/mul_fold', layer1.name[:-2]])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        scope + '/correction_add',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = [scope + '/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, scope)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def _TestFoldConv2dUnknownShape(self,
                                  relu,
                                  relu_op_name,
                                  with_bypass,
                                  has_scaling,
                                  fused_batch_norm,
                                  freeze_batch_norm_delay,
                                  insert_identity_node=False):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Tests that folding works even with an input shape where some dimensions are
    not known (i.e. None).

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=(5, None, None, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        scope + '/correction_mult',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        scope + '/correction_add',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, scope)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldConv2dUnknownShape(self):
    self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)

  def _TestFoldFullyConnectedLayer(
      self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm,
      freeze_batch_norm_delay, insert_identity_node):
    """Tests folding cases: inputs -> FC with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
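      # The identity-node path is only exercised together with fused batch
      # norm for the fully connected layer, so the passed-in flag is
      # overridden here.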
      insert_identity_node = fused_batch_norm
      if insert_identity_node:
        with g.name_scope(name):
          node = fully_connected(
              inputs,
              out_depth,
              weights_initializer=self._WeightInit(0.03),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='fc_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          fc_name = name + '/fully_connected'
      else:
        node = fully_connected(
            inputs,
            out_depth,
            weights_initializer=self._WeightInit(0.03),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        fc_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(fc_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        fc_name + '/correction_mult',
        self._BatchNormMultiplierName(fc_name, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [fc_name + '/MatMul_Fold'])

    folded_conv = g.get_operation_by_name(fc_name + '/MatMul_Fold')
    self.assertEqual(folded_conv.type, 'MatMul')
    self._AssertInputOpsAre(folded_conv,
                            [fc_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [fc_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(fc_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        fc_name + '/correction_add',
        self._BatchNormBiasName(fc_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldFullyConnectedLayer(self):
    self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)

  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
                               has_scaling, fused_batch_norm,
                               freeze_batch_norm_delay, insert_identity_node):
    """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = separable_conv2d(
              inputs,
              None, [5, 5],
              stride=stride,
              depth_multiplier=1.0,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='sep_conv_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          sep_conv_name = name + '/SeparableConv2d'
      else:
        node = separable_conv2d(
            inputs,
            None, [5, 5],
            stride=stride,
            depth_multiplier=1.0,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        sep_conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(sep_conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
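    # For depthwise convolution the batch norm multiplier is reshaped before
    # being multiplied into the depthwise weights; the reshape op lives under
    # a different scope depending on whether the batch norm was fused.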
    if fused_batch_norm:
      scale_reshape_op_name = sep_conv_name + '/BatchNorm_Fold/scale_reshape'
    else:
      scale_reshape_op_name = sep_conv_name + '/scale_reshape'
    self._AssertInputOpsAre(
        folded_mul, [sep_conv_name + '/correction_mult', scale_reshape_op_name])
    self._AssertOutputGoesToOps(folded_mul, g,
                                [sep_conv_name + '/depthwise_Fold'])

    scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape, [
        self._BatchNormMultiplierName(sep_conv_name, has_scaling,
                                      fused_batch_norm),
        scale_reshape_op_name + '/shape'
    ])
    self._AssertOutputGoesToOps(scale_reshape, g, [sep_conv_name + '/mul_fold'])

    folded_conv = g.get_operation_by_name(sep_conv_name + '/depthwise_Fold')
    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
    self._AssertInputOpsAre(folded_conv,
                            [sep_conv_name + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g,
                                [sep_conv_name + '/post_conv_mul'])

    folded_add = g.get_operation_by_name(sep_conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        sep_conv_name + '/correction_add',
        self._BatchNormBiasName(sep_conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldDepthwiseConv2d(self):
    self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)

  def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                            fused_batch_norm, freeze_batch_norm_delay,
                            insert_identity_node):
    """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      dilation_rate = 2
      activation_fn = None if with_bypass else relu
      name = 'test/test2' if with_bypass else 'test'
      if insert_identity_node:
        with g.name_scope(name):
          node = separable_conv2d(
              inputs,
              None, [3, 3],
              rate=dilation_rate,
              depth_multiplier=1.0,
              padding='SAME',
              weights_initializer=self._WeightInit(0.09),
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=None)
          node = array_ops.identity(node, name='sep_conv_out')

          node = batch_norm(
              node,
              center=True,
              scale=has_scaling,
              decay=1.0 - 0.003,
              fused=fused_batch_norm)
          if activation_fn is not None:
            node = activation_fn(node)
          sep_conv_name = name + '/SeparableConv2d'
      else:
        node = separable_conv2d(
            inputs,
            None, [3, 3],
            rate=dilation_rate,
            depth_multiplier=1.0,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=activation_fn,
            normalizer_fn=batch_norm,
            normalizer_params=self._BatchNormParams(
                scale=has_scaling, fused=fused_batch_norm),
            scope=name)
        sep_conv_name = name
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        relu(node, name='test/' + relu_op_name)

      fold_batch_norms.FoldBatchNorms(
          g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)

    folded_mul = g.get_operation_by_name(sep_conv_name + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    if fused_batch_norm:
      scale_reshape_op_name = sep_conv_name + '/BatchNorm_Fold/scale_reshape'
    else:
      scale_reshape_op_name = sep_conv_name + '/scale_reshape'
    self._AssertInputOpsAre(
        folded_mul, [sep_conv_name + '/correction_mult', scale_reshape_op_name])
    self._AssertOutputGoesToOps(folded_mul, g,
                                [sep_conv_name + '/depthwise_Fold'])

    scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape, [
        self._BatchNormMultiplierName(sep_conv_name, has_scaling,
                                      fused_batch_norm),
        scale_reshape_op_name + '/shape'
    ])
    self._AssertOutputGoesToOps(scale_reshape, g, [sep_conv_name + '/mul_fold'])

    folded_conv = g.get_operation_by_name(sep_conv_name + '/depthwise_Fold')
    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
    self._AssertInputOpsAre(folded_conv, [
        sep_conv_name + '/mul_fold', sep_conv_name + '/depthwise/SpaceToBatchND'
    ])
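    # Atrous (dilated) convolution is implemented with SpaceToBatchND and
    # BatchToSpaceND, so the folded depthwise op reads the SpaceToBatchND
    # output and feeds a folded BatchToSpaceND.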
    if fused_batch_norm:
      self._AssertOutputGoesToOps(folded_conv, g,
                                  [sep_conv_name + '/BatchToSpaceND_Fold'])
    else:
      self._AssertOutputGoesToOps(
          folded_conv, g, [sep_conv_name + '/depthwise/BatchToSpaceND_Fold'])

    folded_add = g.get_operation_by_name(sep_conv_name + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        sep_conv_name + '/correction_add',
        self._BatchNormBiasName(sep_conv_name, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
    if freeze_batch_norm_delay is not None:
      self._AssertMovingAveragesAreFrozen(g, name)

    for op in g.get_operations():
      self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)

  def testFoldAtrousConv2d(self):
    self._RunTestOverParameters(self._TestFoldAtrousConv2d)

  def _TestCompareFoldAndUnfolded(self,
                                  relu,
                                  relu_op_name,
                                  with_bypass,
                                  has_scaling,
                                  fused_batch_norm,
                                  freeze_batch_norm_delay,
                                  insert_identity_node=False):
    """Tests that running folded and unfolded BN returns the same results.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
      freeze_batch_norm_delay: None or the number of steps after which training
        switches to using frozen mean and variance.
      insert_identity_node: Bool, insert identity node between conv and batch
        norm.
    """
    random_seed.set_random_seed(1234)
    unfolded_g = ops.Graph()
    with unfolded_g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = random_ops.random_uniform(
          (batch_size, height, width, 3), dtype=dtypes.float32, seed=1234)
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu_node = relu(node, name='test/' + relu_op_name)
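    # Fold a copy of the graph, then run both graphs and check that the
    # forward activations and the gradients with respect to the inputs agree
    # to within a small tolerance.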
    folded_g = self._CopyGraph(unfolded_g)
    with folded_g.as_default():
      fold_batch_norms.FoldBatchNorms(
          folded_g,
          is_training=True,
          freeze_batch_norm_delay=freeze_batch_norm_delay)
    with session.Session(graph=unfolded_g) as sess:
      sess.run(variables.global_variables_initializer())
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      unfolded_forward, unfolded_backward = results[0], results[1]

    with session.Session(graph=folded_g) as sess:
      sess.run(variables.global_variables_initializer())
      relu_node = folded_g.get_tensor_by_name(relu_node.name)
      inputs = folded_g.get_tensor_by_name(inputs.name)
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      folded_forward, folded_backward = results[0], results[1]

    # Check that the folded and unfolded results match.
    self.assertAllClose(unfolded_forward, folded_forward, atol=1e-3)
    self.assertAllClose(unfolded_backward, folded_backward, atol=1e-3)

  def testCompareFoldAndUnfolded(self):
    self._RunTestOverParameters(self._TestCompareFoldAndUnfolded)

  def _BatchNormParams(self, scale=True, fused=False):
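    """Returns the batch norm parameters used by the layers under test."""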
    return {
        'center': True,
        'scale': scale,
        'decay': 1.0 - 0.003,
        'fused': fused
    }

  def _BatchNormMultiplierName(self, scope, has_scaling, fused):
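    """Returns the expected name of the batch norm multiplier op."""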
    if has_scaling:
      if fused:
        return scope + '/BatchNorm_Fold/mul'
      return scope + '/BatchNorm/batchnorm_1/mul'
    return scope + '/BatchNorm/batchnorm_1/Rsqrt'

  def _BatchNormBiasName(self, scope, fused):
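    """Returns the expected name of the batch norm bias op."""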
    if fused:
      return scope + '/BatchNorm_Fold/bias'
    return scope + '/BatchNorm/batchnorm_1/sub'

  def _WeightInit(self, stddev):
    """Returns a truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all op's inputs should
        come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])

  def _AssertMovingAveragesAreFrozen(self, graph, scope):
    """Asserts that the moving mean and variance are frozen.

    Args:
      graph: Graph where the operations are located.
      scope: Scope of the batch norm op.
    """
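    # When the batch norm is frozen, the moving-average updates are routed
    # through freeze_moving_mean/Merge and freeze_moving_var/Merge nodes.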
    moving_average_mult = graph.get_operation_by_name(
        scope + '/BatchNorm/AssignMovingAvg/mul')
    self.assertTrue(
        moving_average_mult.inputs[1].name.find('freeze_moving_mean/Merge') > 0)
    moving_var_mult = graph.get_operation_by_name(
        scope + '/BatchNorm/AssignMovingAvg_1/mul')
    self.assertTrue(
        moving_var_mult.inputs[1].name.find('freeze_moving_var/Merge') > 0)

  def _CopyGraph(self, graph):
    """Return a copy of graph."""
    meta_graph = saver_lib.export_meta_graph(
        graph=graph, collection_list=graph.get_all_collection_keys())
    graph_copy = ops.Graph()
    with graph_copy.as_default():
      _ = saver_lib.import_meta_graph(meta_graph)
    return graph_copy


if __name__ == '__main__':
  googletest.main()