# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flops statistics for various TensorFlow operations.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops


# List of all ops for which flops statistics are implemented.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core TensorFlow:
    "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])


def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result

################################################################################
# Unary operations
################################################################################


def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)

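# A quick worked example of the accounting above: an elementwise op on a
# fully defined [8, 16] input with ops_per_element=1 is counted as
# 8 * 16 = 128 flops.
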

@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # TensorFlow uses an inefficient implementation with (3*N-1) flops;
  # an optimal implementation needs only 2*N flops (N multiplies, N-1 adds
  # and one final division by 2).
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)


@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Approximate flops breakdown of the Softmax implementation:
  #   2*n          -- compute shifted logits
  #   n            -- exp of shifted logits
  #   2*n          -- compute softmax from exp of shifted logits
  # for a total of 5 flops per element.
  return _unary_op_flops(graph, node, ops_per_element=5)

################################################################################
# Binary operations
################################################################################


def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Add", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)

################################################################################
# Reduction ops
################################################################################


def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)

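# A quick sanity check of the formula above: summing a [4, 3] tensor over its
# last axis has in=12 and out=4 elements; with reduce_flops=1 and
# finalize_flops=0 that gives 12 * 1 + 4 * (0 - 1) = 8 flops, i.e.
# 3 - 1 = 2 additions per output element.
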

@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # BiasAddGrad is essentially a reduce-sum plus a reshape, so its flops are
  # computed the same way as for "Sum".
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format
################################################################################


def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which computes flops for average and max pooling."""
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())

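# A quick worked example of the pooling cost above: a 2x2 pool
# (ksize = [1, 2, 2, 1]) producing a [1, 14, 14, 64] output is counted as
# (1 * 2 * 2 * 1) * (1 * 14 * 14 * 64) = 50176 flops.
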

@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                  node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them up, so we count 2 flops per element. (A more efficient
  # implementation would divide once, after the summation.)
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then one flop per element of the original
  # output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                              node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())


@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  # image_x_dim, image_y_dim and input_depth --- sizes of the input to the
  #   original (forward) convolution, in other words the sizes of the
  #   backprop output.
  # output_depth --- number of filters in the original convolution, thus
  #   depth of the backprop input.
  # kernel_x_dim and kernel_y_dim --- sizes of the filter in the spatial
  #   dimensions.
  # image_x_stride and image_y_stride --- strides of the convolution.
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                            node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape.dims[-1].value * strides_product)))

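# A quick worked example of the formula above: batch_size=1, an 8x8 image,
# a 3x3 kernel, input_depth=16, output_depth=32 and stride 1 give
# 1 * 8 * 8 * 3 * 3 * 16 * 32 * 2 / (1 * 1) = 589824 flops.
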

@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Formula same as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape.dims[-1].value * strides_product)))

################################################################################
# Other ops
################################################################################


@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
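  # Summing len(node.input) tensors of n elements each takes
  # (len(node.input) - 1) * n additions.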
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))

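################################################################################
# Usage sketch
################################################################################


# A minimal sketch of how the registrations above are consumed (assumes
# TensorFlow 1.x graph mode). ops.get_stats_for_node_def is the framework
# hook that dispatches to the functions registered in this file:
#
#   from tensorflow.python.framework import constant_op
#
#   g = ops.Graph()
#   with g.as_default():
#     a = constant_op.constant(1.0, shape=[32, 32])
#     b = constant_op.constant(2.0, shape=[32, 32])
#     c = a * b  # a "Mul" node: one flop per output element
#   print(ops.get_stats_for_node_def(g, c.op.node_def, "flops").value)  # 1024
#
# To aggregate flops over a whole graph, the TF 1.x profiler builds on these
# same registrations, along the lines of:
#
#   from tensorflow.python.profiler import model_analyzer, option_builder
#   opts = option_builder.ProfileOptionBuilder.float_operation()
#   total_flops = model_analyzer.profile(g, options=opts).total_float_ops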