# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""End-to-end benchmark for batch normalization."""

import argparse
import sys
import time

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
  """Fused kernel for batch normalization."""
  # _batch_norm_with_global_normalization is deprecated in v9
  test_util.set_producer_version(ops.get_default_graph(), 8)
  # pylint: disable=protected-access
  return gen_nn_ops._batch_norm_with_global_normalization(
      tensor, mean, variance, beta, gamma, 0.001, scale)
  # pylint: enable=protected-access


# Note that the naive implementation is much slower:
# batch_norm = (tensor - mean) * tf.math.rsqrt(variance + 0.001)
# if scale:
#   batch_norm *= gamma
# return batch_norm + beta
def batch_norm_py(tensor, mean, variance, beta, gamma, scale):
  """Python implementation of batch normalization."""
  return nn_impl.batch_normalization(tensor, mean, variance, beta,
                                     gamma if scale else None, 0.001)


def batch_norm_slow(tensor, mean, variance, beta, gamma, scale):
  batch_norm = (tensor - mean) * math_ops.rsqrt(variance + 0.001)
  if scale:
    batch_norm *= gamma
  return batch_norm + beta


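# Illustrative cross-check (a sketch only, not executed by the benchmark): the
# "py" and "slow" implementations above compute the same transformation, so
# for example inputs (the [8, 4, 4, 2] shape below is hypothetical) their
# outputs should agree to within floating-point tolerance:
#
#   x = random_ops.truncated_normal([8, 4, 4, 2])
#   zeros, ones = array_ops.zeros([2]), array_ops.ones([2])
#   with session_lib.Session() as sess:
#     py_out, slow_out = sess.run([
#         batch_norm_py(x, zeros, ones, zeros, ones, True),
#         batch_norm_slow(x, zeros, ones, zeros, ones, True),
#     ])
#   # numpy.allclose(py_out, slow_out, atol=1e-5) is expected to hold.

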
def build_graph(device, input_shape, axes, num_layers, mode, scale, train):
  """Build a graph containing a sequence of batch normalizations.

  Args:
    device: string, the device to run on.
    input_shape: shape of the input tensor.
    axes: axes that are to be normalized across.
    num_layers: number of batch normalization layers in the graph.
    mode: "op", "py" or "slow" depending on the implementation.
    scale: scale after normalization.
    train: if true, also run backprop.

  Returns:
    An array of tensors to run()
  """
  moment_shape = []
  keep_dims = mode == "py" or mode == "slow"
  if keep_dims:
    for axis in range(len(input_shape)):
      if axis in axes:
        moment_shape.append(1)
      else:
        moment_shape.append(input_shape[axis])
  else:
    for axis in range(len(input_shape)):
      if axis not in axes:
        moment_shape.append(input_shape[axis])
  with ops.device("/%s:0" % device):
    tensor = variables.Variable(random_ops.truncated_normal(input_shape))
    for _ in range(num_layers):
      if train:
        mean, variance = nn_impl.moments(tensor, axes, keep_dims=keep_dims)
      else:
        mean = array_ops.zeros(moment_shape)
        variance = array_ops.ones(moment_shape)
      beta = variables.Variable(array_ops.zeros(moment_shape))
      gamma = variables.Variable(constant_op.constant(1.0, shape=moment_shape))
      if mode == "py":
        tensor = batch_norm_py(tensor, mean, variance, beta, gamma, scale)
      elif mode == "op":
        tensor = batch_norm_op(tensor, mean, variance, beta, gamma, scale)
      elif mode == "slow":
        tensor = batch_norm_slow(tensor, mean, variance, beta, gamma, scale)
    if train:
      return gradients_impl.gradients([tensor],
                                      variables.trainable_variables())
    else:
      return [tensor]


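# Shape bookkeeping, for orientation (values mirror the convolutional cases
# benchmarked below): with input_shape=[8, 128, 128, 32] and axes=[0, 1, 2],
# build_graph uses moment_shape=[1, 1, 1, 32] for the "py" and "slow" modes
# (keep_dims=True) and moment_shape=[32] for the fused "op" mode, which takes
# 1-D per-channel mean/variance/beta/gamma.

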
def print_difference(mode, t1, t2):
  """Print the difference in timing between two runs."""
  difference = (t2 - t1) / t1 * 100.0
  print("=== %s: %.1f%% ===" % (mode, difference))


class BatchNormBenchmark(test.Benchmark):
  """Benchmark batch normalization."""

  def _run_graph(self, device, input_shape, axes, num_layers, mode, scale,
                 train, num_iters):
    """Run the graph and print its execution time.

    Args:
      device: string, the device to run on.
      input_shape: shape of the input tensor.
      axes: axes that are to be normalized across.
      num_layers: number of batch normalization layers in the graph.
      mode: "op", "py" or "slow" depending on the implementation.
      scale: scale after normalization.
      train: if true, also run backprop.
      num_iters: number of steps to run.

    Returns:
      The duration of the run in seconds.
    """
    graph = ops.Graph()
    with graph.as_default():
      outputs = build_graph(device, input_shape, axes, num_layers, mode, scale,
                            train)
    with session_lib.Session(graph=graph) as session:
      variables.global_variables_initializer().run()
      _ = session.run([out.op for out in outputs])  # warm up.
      start_time = time.time()
      for _ in range(num_iters):
        _ = session.run([out.op for out in outputs])
      duration = time.time() - start_time
      print("%s shape:%d/%d #layers:%d mode:%s scale:%r train:%r - %f secs" %
            (device, len(input_shape), len(axes), num_layers, mode, scale,
             train, duration / num_iters))

    name_template = (
        "batch_norm_{device}_input_shape_{shape}_axes_{axes}_mode_{mode}_"
        "layers_{num_layers}_scale_{scale}_"
        "train_{train}")

    self.report_benchmark(
        name=name_template.format(
            device=device,
            mode=mode,
            num_layers=num_layers,
            scale=scale,
            train=train,
            shape=str(input_shape).replace(" ", ""),
            axes=str(axes)).replace(" ", ""),
        iters=num_iters,
        wall_time=duration / num_iters)

    return duration

  def benchmark_batch_norm(self):
    print("Forward convolution (lower layers).")
    shape = [8, 128, 128, 32]
    axes = [0, 1, 2]
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward/backward convolution (lower layers).")
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward convolution (higher layers).")
    shape = [256, 17, 17, 32]
    axes = [0, 1, 2]
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward/backward convolution (higher layers).")
    t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
    t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
    print_difference("op vs py", t1, t2)
    print_difference("py vs slow", t2, t3)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
      print_difference("op vs py", t1, t2)
      print_difference("py vs slow", t2, t3)
    print("Forward fully-connected.")
    shape = [1024, 32]
    axes = [0]
    t1 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
    t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
    print_difference("py vs slow", t1, t2)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
      t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
      print_difference("py vs slow", t1, t2)
    print("Forward/backward fully-connected.")
    t1 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 50)
    t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 50)
    print_difference("py vs slow", t1, t2)
    if FLAGS.use_gpu:
      t1 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 5)
      t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 5)
      print_difference("py vs slow", t1, t2)


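# Example invocation (the file name is assumed for illustration). TensorFlow's
# test.main() runs the benchmark_* methods above only when a --benchmarks
# filter is supplied; the --use_gpu flag below is parsed either way:
#
#   python batch_norm_benchmark.py --use_gpu=false --benchmarks=BatchNormBenchmark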
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--use_gpu",
      type="bool",
      nargs="?",
      const=True,
      default=True,
      help="Run GPU benchmarks.")
  global FLAGS  # pylint:disable=global-at-module-level
  FLAGS, unparsed = parser.parse_known_args()
  test.main(argv=[sys.argv[0]] + unparsed)