1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Register flops statistics for various TensorFlow operations. 16""" 17from tensorflow.python.framework import graph_util 18from tensorflow.python.framework import ops 19 20 21# List of all ops which have implemented flops statistics. 22IMPLEMENTED_OPS = set([ 23 # Unary ops 24 "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd", 25 "L2Loss", "Softmax", 26 # Binary ops 27 "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad", 28 "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual", 29 "SquaredDifference", 30 # Reduction ops 31 "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad", 32 # Convolution and pooling 33 "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput", 34 "Conv2DBackpropFilter", 35 # Other ops 36 "AddN", 37 # Ops implemented in core tensorflow: 38 "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D", 39]) 40 41 42def _zero_flops(graph, node): 43 """Returns zero flops.""" 44 del graph, node # graph and node are unused 45 return ops.OpStats("flops", 0) 46 47 48def _list_product(lst): 49 """Computes product of element of the list.""" 50 result = 1 51 for item in lst: 52 result *= item 53 return result 54 55################################################################################ 56# Unary operations 57################################################################################ 58 59 60def _unary_op_flops(graph, node, ops_per_element=1): 61 """Common code which compute flops for unary operations.""" 62 in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) 63 in_shape.assert_is_fully_defined() 64 return ops.OpStats("flops", in_shape.num_elements() * ops_per_element) 65 66 67@ops.RegisterStatistics("Reciprocal", "flops") 68def _reciprocal_flops(graph, node): 69 """Compute flops for Reciprocal operation.""" 70 return _unary_op_flops(graph, node) 71 72 73@ops.RegisterStatistics("Square", "flops") 74def _square_flops(graph, node): 75 """Compute flops for Square operation.""" 76 return _unary_op_flops(graph, node) 77 78 79@ops.RegisterStatistics("Rsqrt", "flops") 80def _rsqrt_flops(graph, node): 81 """Compute flops for Rsqrt operation.""" 82 # Rsqrt(x) = 1 / sqrt(x) 83 return _unary_op_flops(graph, node, ops_per_element=2) 84 85 86@ops.RegisterStatistics("Log", "flops") 87def _log_flops(graph, node): 88 """Compute flops for Log operation.""" 89 return _unary_op_flops(graph, node) 90 91 92@ops.RegisterStatistics("Neg", "flops") 93def _neg_flops(graph, node): 94 """Compute flops for Neg operation.""" 95 return _unary_op_flops(graph, node) 96 97 98@ops.RegisterStatistics("AssignSub", "flops") 99def _assign_sub_flops(graph, node): 100 """Compute flops for AssignSub operation.""" 101 return _unary_op_flops(graph, node) 102 103 104@ops.RegisterStatistics("AssignAdd", "flops") 105def _assign_add_flops(graph, node): 106 """Compute flops for AssignAdd operation.""" 107 return _unary_op_flops(graph, node) 108 109 110@ops.RegisterStatistics("L2Loss", "flops") 111def _l2_loss_flops(graph, node): 112 """Compute flops for L2Loss operation.""" 113 in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) 114 in_shape.assert_is_fully_defined() 115 # Tensorflow uses inefficient implementation, with (3*N-1) flops: 116 # Optimal implementation is 2*N flops 117 return ops.OpStats("flops", in_shape.num_elements() * 3 - 1) 118 119 120@ops.RegisterStatistics("Softmax", "flops") 121def _softmax_flops(graph, node): 122 """Compute flops for Softmax operation.""" 123 # Softmax implemetation: 124 # 125 # Approximate flops breakdown: 126 # 2*n -- compute shifted logits 127 # n -- exp of shifted logits 128 # 2*n -- compute softmax from exp of shifted logits 129 return _unary_op_flops(graph, node, ops_per_element=5) 130 131################################################################################ 132# Binary operations 133################################################################################ 134 135 136def _binary_per_element_op_flops(graph, node, ops_per_element=1): 137 """Common code which compute flops for binary operations.""" 138 out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) 139 out_shape.assert_is_fully_defined() 140 return ops.OpStats("flops", out_shape.num_elements() * ops_per_element) 141 142 143@ops.RegisterStatistics("Add", "flops") 144def _add_flops(graph, node): 145 """Compute flops for Add operation.""" 146 return _binary_per_element_op_flops(graph, node) 147 148 149@ops.RegisterStatistics("Sub", "flops") 150def _sub_flops(graph, node): 151 """Compute flops for Sub operation.""" 152 return _binary_per_element_op_flops(graph, node) 153 154 155@ops.RegisterStatistics("Mul", "flops") 156def _mul_flops(graph, node): 157 """Compute flops for Mul operation.""" 158 return _binary_per_element_op_flops(graph, node) 159 160 161@ops.RegisterStatistics("RealDiv", "flops") 162def _real_div_flops(graph, node): 163 """Compute flops for RealDiv operation.""" 164 return _binary_per_element_op_flops(graph, node) 165 166 167@ops.RegisterStatistics("Maximum", "flops") 168def _maximum_flops(graph, node): 169 """Compute flops for Maximum operation.""" 170 return _binary_per_element_op_flops(graph, node) 171 172 173@ops.RegisterStatistics("Minimum", "flops") 174def _minimum_flops(graph, node): 175 """Compute flops for Minimum operation.""" 176 return _binary_per_element_op_flops(graph, node) 177 178 179@ops.RegisterStatistics("Pow", "flops") 180def _pow_flops(graph, node): 181 """Compute flops for Pow operation.""" 182 return _binary_per_element_op_flops(graph, node) 183 184 185@ops.RegisterStatistics("RsqrtGrad", "flops") 186def _rsqrt_grad_flops(graph, node): 187 """Compute flops for RsqrtGrad operation.""" 188 return _binary_per_element_op_flops(graph, node, ops_per_element=4) 189 190 191@ops.RegisterStatistics("GreaterEqual", "flops") 192def _greater_equal_flops(graph, node): 193 """Compute flops for GreaterEqual operation.""" 194 return _binary_per_element_op_flops(graph, node) 195 196 197@ops.RegisterStatistics("Greater", "flops") 198def _greater_flops(graph, node): 199 """Compute flops for Greater operation.""" 200 return _binary_per_element_op_flops(graph, node) 201 202 203@ops.RegisterStatistics("LessEqual", "flops") 204def _less_equal_flops(graph, node): 205 """Compute flops for LessEqual operation.""" 206 return _binary_per_element_op_flops(graph, node) 207 208 209@ops.RegisterStatistics("Less", "flops") 210def _less_flops(graph, node): 211 """Compute flops for Less operation.""" 212 return _binary_per_element_op_flops(graph, node) 213 214 215@ops.RegisterStatistics("Equal", "flops") 216def _equal_flops(graph, node): 217 """Compute flops for Equal operation.""" 218 return _binary_per_element_op_flops(graph, node) 219 220 221@ops.RegisterStatistics("NotEqual", "flops") 222def _not_equal_flops(graph, node): 223 """Compute flops for NotEqual operation.""" 224 return _binary_per_element_op_flops(graph, node) 225 226 227@ops.RegisterStatistics("SquaredDifference", "flops") 228def _squared_difference_flops(graph, node): 229 """Compute flops for SquaredDifference operation.""" 230 return _binary_per_element_op_flops(graph, node, ops_per_element=2) 231 232################################################################################ 233# Reduction ops 234################################################################################ 235 236 237def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0): 238 """Common code which compute flops for reduction operations.""" 239 in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) 240 in_shape.assert_is_fully_defined() 241 out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) 242 out_shape.assert_is_fully_defined() 243 num_flops = (in_shape.num_elements() * reduce_flops 244 + out_shape.num_elements() * (finalize_flops - reduce_flops)) 245 return ops.OpStats("flops", num_flops) 246 247 248@ops.RegisterStatistics("Mean", "flops") 249def _mean_flops(graph, node): 250 """Compute flops for Mean operation.""" 251 # reduction - sum, finalization - divide 252 return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1) 253 254 255@ops.RegisterStatistics("Sum", "flops") 256def _sum_flops(graph, node): 257 """Compute flops for Sum operation.""" 258 # reduction - sum, no finalization 259 return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) 260 261 262@ops.RegisterStatistics("ArgMax", "flops") 263def _arg_max_flops(graph, node): 264 """Compute flops for ArgMax operation.""" 265 # reduction - comparison, no finalization 266 return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) 267 268 269@ops.RegisterStatistics("ArgMin", "flops") 270def _arg_min_flops(graph, node): 271 """Compute flops for ArgMin operation.""" 272 # reduction - comparison, no finalization 273 return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) 274 275 276@ops.RegisterStatistics("BiasAddGrad", "flops") 277def _bias_add_grad_flops(graph, node): 278 """Compute flops for BiasAddGrad operation.""" 279 # Implementation of BiasAddGrad, essentially it's a reduce sum and reshaping: 280 # So computing flops same way as for "Sum" 281 return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) 282 283################################################################################ 284# Convolution and pooling 285# Note: all flops statistics are implemented only for NHWC data format 286################################################################################ 287 288 289def _verify_conv_data_format(node): 290 """Verifies data format for pooling and convolutional operations.""" 291 # TODO(xpan): P1: Support NCHW 292 if node.attr["data_format"].s != b"NHWC": 293 raise ValueError("Only NHWC format is supported in flops computations") 294 295 296def _pool_flops(graph, node): 297 """Common code which compute flops for pooling operations.""" 298 # compute flops for average and max pooling 299 _verify_conv_data_format(node) 300 # 301 # Pooling declaration: 302 # Inputs: 303 # - value 304 # Outputs: 305 # - output 306 # Attributes: 307 # - ksize 308 # - strides 309 # - padding 310 # - data_format 311 # 312 # Pooling implemetation: 313 out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) 314 out_shape.assert_is_fully_defined() 315 kernel_shape = list(node.attr["ksize"].list.i) 316 kernel_area = _list_product(kernel_shape) 317 return ops.OpStats("flops", kernel_area * out_shape.num_elements()) 318 319 320@ops.RegisterStatistics("AvgPool", "flops") 321def _avg_pool_flops(graph, node): 322 """Compute flops for AvgPool operation.""" 323 return _pool_flops(graph, node) 324 325 326@ops.RegisterStatistics("MaxPool", "flops") 327def _max_pool_flops(graph, node): 328 """Compute flops for MaxPool operation.""" 329 return _pool_flops(graph, node) 330 331 332@ops.RegisterStatistics("AvgPoolGrad", "flops") 333def _avg_pool_grad_flops(graph, node): 334 """Compute flops for AvgPoolGrad operation.""" 335 _verify_conv_data_format(node) 336 # Pooling gradient implementation: 337 out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph, 338 node.input[1]) 339 out_backprop_shape.assert_is_fully_defined() 340 kernel_shape = list(node.attr["ksize"].list.i) 341 kernel_area = _list_product(kernel_shape) 342 # TensorFlow multiply each element of pooling window by coefficient, 343 # then sum up all of them, thus we have 2 flops per element: 344 # More optimal implementation - if division is done after. 345 return ops.OpStats("flops", 346 kernel_area * out_backprop_shape.num_elements() * 2) 347 348 349@ops.RegisterStatistics("MaxPoolGrad", "flops") 350def _max_pool_grad_flops(graph, node): 351 """Compute flops for MaxPoolGrad operation.""" 352 _verify_conv_data_format(node) 353 # 354 # MaxPoolGrad declaration: 355 # Inputs: 356 # - orig_input -- original input tensor (of max_pool) 357 # - orig_output -- original output tensor (of max_pool) 358 # - grad -- gradient with respect to output of max_pool 359 # Outputs: 360 # - output -- gradient with respect to input of max_pool 361 # Attributes: 362 # - ksize 363 # - strides 364 # - padding 365 # - data_format 366 # It computes MaxPool first, then one flop per each element of original output 367 # 368 kernel_shape = list(node.attr["ksize"].list.i) 369 kernel_area = _list_product(kernel_shape) 370 orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph, 371 node.input[1]) 372 orig_out_shape.assert_is_fully_defined() 373 max_pool_ops = kernel_area * orig_out_shape.num_elements() 374 return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements()) 375 376 377@ops.RegisterStatistics("Conv2DBackpropInput", "flops") 378def _conv_2d_backprop_input_flops(graph, node): 379 """Compute flops for Conv2DBackpropInput operation.""" 380 # Formula: 381 # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim 382 # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride) 383 # 384 # Where: 385 # image_x_dim, image_y_dim and input_depth --- size of input to source (no 386 # backprop) convolution, in other words they are sizes of backprop output. 387 # output_depth --- number of filters in the original convolution, thus 388 # depth of backprop input. 389 # kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension 390 # image_x_stride and image_x_stride --- strides of the convolution 391 # 392 _verify_conv_data_format(node) 393 # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth] 394 out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) 395 out_shape.assert_is_fully_defined() 396 # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth] 397 kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, 398 node.input[1]) 399 kernel_shape.assert_is_fully_defined() 400 # strides 401 strides_shape = list(node.attr["strides"].list.i) 402 strides_product = strides_shape[1] * strides_shape[2] 403 return ops.OpStats("flops", 404 (2 * out_shape.num_elements() 405 * kernel_shape.num_elements() 406 / (out_shape.dims[-1].value * strides_product))) 407 408 409@ops.RegisterStatistics("Conv2DBackpropFilter", "flops") 410def _conv_2d_backprop_filter_flops(graph, node): 411 """Compute flops for Conv2DBackpropFilter operation.""" 412 # Formula same as for Conv2DBackpropInput: 413 # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim 414 # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride) 415 # 416 _verify_conv_data_format(node) 417 # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth] 418 image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) 419 image_shape.assert_is_fully_defined() 420 # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth] 421 kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) 422 kernel_shape.assert_is_fully_defined() 423 # strides 424 strides_shape = list(node.attr["strides"].list.i) 425 strides_product = strides_shape[1] * strides_shape[2] 426 return ops.OpStats("flops", 427 (2 * image_shape.num_elements() 428 * kernel_shape.num_elements() 429 / (image_shape.dims[-1].value * strides_product))) 430 431################################################################################ 432# Other ops 433################################################################################ 434 435 436@ops.RegisterStatistics("AddN", "flops") 437def _add_n_flops(graph, node): 438 """Compute flops for AddN operation.""" 439 if not node.input: 440 return _zero_flops(graph, node) 441 in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) 442 in_shape.assert_is_fully_defined() 443 return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1)) 444