# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""End-to-end benchmark for batch normalization."""

import argparse
import sys
import time

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
  """Fused kernel for batch normalization."""
  # _batch_norm_with_global_normalization is deprecated in v9
  test_util.set_producer_version(ops.get_default_graph(), 8)
  # pylint: disable=protected-access
  return gen_nn_ops._batch_norm_with_global_normalization(
      tensor, mean, variance, beta, gamma, 0.001, scale)
  # pylint: enable=protected-access


# Note that the naive implementation is much slower:
# batch_norm = (tensor - mean) * tf.math.rsqrt(variance + 0.001)
# if scale:
#   batch_norm *= gamma
# return batch_norm + beta
def batch_norm_py(tensor, mean, variance, beta, gamma, scale):
  """Python implementation of batch normalization."""
  return nn_impl.batch_normalization(tensor, mean, variance, beta, gamma if
                                     scale else None, 0.001)


def batch_norm_slow(tensor, mean, variance, beta, gamma, scale):
  """Naive, unfused implementation of batch normalization."""
  batch_norm = (tensor - mean) * math_ops.rsqrt(variance + 0.001)
  if scale:
    batch_norm *= gamma
  return batch_norm + beta


def build_graph(device, input_shape, axes, num_layers, mode, scale, train):
  """Build a graph containing a sequence of batch normalizations.

  Args:
    device: string, the device to run on.
    input_shape: shape of the input tensor.
    axes: axes that are to be normalized across.
    num_layers: number of batch normalization layers in the graph.
    mode: "op", "py" or "slow" depending on the implementation.
    scale: if True, scale the normalized tensor by gamma after normalization.
    train: if True, also run backprop.

  Returns:
    A list of tensors to run()
  """
  moment_shape = []
  keep_dims = mode == "py" or mode == "slow"
  if keep_dims:
    for axis in range(len(input_shape)):
      if axis in axes:
        moment_shape.append(1)
      else:
        moment_shape.append(input_shape[axis])
  else:
    for axis in range(len(input_shape)):
      if axis not in axes:
        moment_shape.append(input_shape[axis])
  with ops.device("/%s:0" % device):
    tensor = variables.Variable(random_ops.truncated_normal(input_shape))
    for _ in range(num_layers):
      if train:
        mean, variance = nn_impl.moments(tensor, axes, keep_dims=keep_dims)
      else:
        mean = array_ops.zeros(moment_shape)
        variance = array_ops.ones(moment_shape)
      beta = variables.Variable(array_ops.zeros(moment_shape))
      gamma = variables.Variable(constant_op.constant(1.0, shape=moment_shape))
      if mode == "py":
        tensor = batch_norm_py(tensor, mean, variance, beta, gamma, scale)
      elif mode == "op":
        tensor = batch_norm_op(tensor, mean, variance, beta, gamma, scale)
      elif mode == "slow":
        tensor = batch_norm_slow(tensor, mean, variance, beta, gamma, scale)
    if train:
      return gradients_impl.gradients([tensor], variables.trainable_variables())
    else:
      return [tensor]


def print_difference(mode, t1, t2):
  """Print the difference in timing between two runs."""
  difference = (t2 - t1) / t1 * 100.0
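  # A positive percentage means the second timing (t2) was slower than the
  # first (t1).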
  print("=== %s: %.1f%% ===" % (mode, difference))


class BatchNormBenchmark(test.Benchmark):
  """Benchmark batch normalization."""

  def _run_graph(self, device, input_shape, axes, num_layers, mode, scale,
                 train, num_iters):
    """Run the graph and print its execution time.

    Args:
      device: string, the device to run on.
      input_shape: shape of the input tensor.
      axes: axes that are to be normalized across.
      num_layers: number of batch normalization layers in the graph.
      mode: "op", "py" or "slow" depending on the implementation.
      scale: if True, scale the normalized tensor by gamma after normalization.
      train: if True, also run backprop.
      num_iters: number of steps to run.

    Returns:
      The total duration of the num_iters runs, in seconds.
    """
    graph = ops.Graph()
    with graph.as_default():
      outputs = build_graph(device, input_shape, axes, num_layers, mode, scale,
                            train)
    with session_lib.Session(graph=graph) as session:
      variables.global_variables_initializer().run()
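      # Run the output ops (rather than fetching tensor values) once to warm
      # up, then time num_iters runs; reported times are per-iteration.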
      _ = session.run([out.op for out in outputs])  # warm up.
      start_time = time.time()
      for _ in range(num_iters):
        _ = session.run([out.op for out in outputs])
      duration = time.time() - start_time
    print("%s shape:%d/%d #layers:%d mode:%s scale:%r train:%r - %f secs" %
          (device, len(input_shape), len(axes), num_layers, mode, scale, train,
           duration / num_iters))

    name_template = (
        "batch_norm_{device}_input_shape_{shape}_axes_{axes}_mode_{mode}_"
        "layers_{num_layers}_scale_{scale}_"
        "train_{train}")

    self.report_benchmark(
        name=name_template.format(
            device=device,
            mode=mode,
            num_layers=num_layers,
            scale=scale,
            train=train,
            shape=str(input_shape).replace(" ", ""),
            axes=str(axes)).replace(" ", ""),
        iters=num_iters,
        wall_time=duration / num_iters)

    return duration

  def benchmark_batch_norm(self):
    print("Forward convolution (lower layers).")
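    # NHWC convolution input: normalize over batch, height, and width,
    # keeping per-channel statistics.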
    shape = [8, 128, 128, 32]
    axes = [0, 1, 2]
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward/backward convolution (lower layers).")
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward convolution (higher layers).")
    shape = [256, 17, 17, 32]
    axes = [0, 1, 2]
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward/backward convolution (higher layers).")
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward fully-connected.")
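    # Fully-connected input: normalize across the batch dimension only,
    # keeping per-feature statistics.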
    shape = [1024, 32]
    axes = [0]
    t1 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("py vs slow", t1, t2)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("py vs slow", t1, t2)
    print("Forward/backward fully-connected.")
    t1 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 50)
    t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 50)
    print_difference("py vs slow", t1, t2)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 5)
      t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 5)
      print_difference("py vs slow", t1, t2)


if __name__ == "__main__":
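  # Example invocation (the script file name below is an assumption):
  #   python batch_norm_benchmark.py --use_gpu=false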
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--use_gpu",
      type="bool",
      nargs="?",
      const=True,
      default=True,
      help="Run GPU benchmarks."
  )
  global FLAGS  # pylint:disable=global-at-module-level
  FLAGS, unparsed = parser.parse_known_args()
  test.main(argv=[sys.argv[0]] + unparsed)