# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flops statistics for various TensorFlow operations."""
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops


# Set of all ops which have implemented flops statistics.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core tensorflow:
    "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])


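# A minimal sketch of how these registered statistics can be consumed. It
# assumes ops.get_stats_for_node_def, which dispatches to the function
# registered for a node's op type under the "flops" statistic; the helper
# name below is illustrative only and not part of this module.
def _total_flops(graph):
  """Sums the registered flops estimates over all nodes of `graph`."""
  total = 0
  for node in graph.as_graph_def().node:
    if node.op in IMPLEMENTED_OPS:
      stats = ops.get_stats_for_node_def(graph, node, "flops")
      if stats.value is not None:
        total += stats.value
  return total

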
def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result

################################################################################
# Unary operations
################################################################################


def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)


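# A worked example with illustrative shapes: for an input of shape
# [32, 224, 224, 3], a unary op with ops_per_element=2 (e.g. Rsqrt below)
# is charged 32 * 224 * 224 * 3 * 2 = 9,633,792 flops.

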
@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # TensorFlow uses an inefficient implementation with (3*N-1) flops;
  # an optimal implementation would need only 2*N flops.
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)


@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Softmax implementation:
  #
  # Approximate flops breakdown:
  #   2*n          -- compute shifted logits
  #   n            -- exp of shifted logits
  #   2*n          -- compute softmax from exp of shifted logits
  return _unary_op_flops(graph, node, ops_per_element=5)

################################################################################
# Binary operations
################################################################################


def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)


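# Design note: the binary helper above uses the *output* shape, whereas the
# unary helper uses the input shape. With broadcasting (e.g. adding a
# [32, 128] tensor to a [128] vector), the output shape determines how many
# element-wise operations actually execute.

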
@ops.RegisterStatistics("Add", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)

################################################################################
# Reduction ops
################################################################################


def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)


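# A worked example of the formula above (illustrative shapes): Mean of a
# [32, 10] tensor along axis 1 produces a [32] output, so with reduce_flops=1
# and finalize_flops=1 it reports 320 * 1 + 32 * (1 - 1) = 320 flops,
# i.e. nine adds plus one divide for each of the 32 reduction groups.

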
@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # BiasAddGrad is essentially a reduce-sum plus a reshape,
  # so its flops are computed the same way as for "Sum".
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format.
################################################################################


def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which computes flops for pooling operations."""
  # Computes flops for average and max pooling.
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())


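# A worked example with illustrative shapes: a 2x2 MaxPool with stride 2 over
# a [32, 224, 224, 64] NHWC input yields a [32, 112, 112, 64] output; with
# ksize = [1, 2, 2, 1] the kernel area is 4, so the estimate is
# 4 * 32 * 112 * 112 * 64 = 102,760,448 flops.

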
@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                  node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them all up, hence 2 flops per element. A more optimal
  # implementation would perform the division once, after the summation.
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then adds one flop per element of the original
  # output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                              node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())


@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  # image_x_dim, image_y_dim and input_depth --- sizes of the input to the
  #   source (non-backprop) convolution; in other words, the sizes of the
  #   backprop output.
  # output_depth --- number of filters in the original convolution, and thus
  #   the depth of the backprop input.
  # kernel_x_dim and kernel_y_dim --- sizes of the filter in the spatial
  #   dimensions.
  # image_x_stride and image_y_stride --- strides of the convolution.
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                            node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape.dims[-1].value * strides_product)))


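# A worked example with illustrative shapes: for a stride-1 convolution with
# input [32, 56, 56, 64] and kernel [3, 3, 64, 128], the gradient w.r.t. the
# input has out_shape = [32, 56, 56, 64], so the estimate is
# 2 * (32*56*56*64) * (3*3*64*128) / (64 * 1) = 14,797,504,512 flops.

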
@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Same formula as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape.dims[-1].value * strides_product)))

################################################################################
# Other ops
################################################################################


@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # AddN sums N equally-shaped inputs element-wise, i.e. N-1 additions per
  # output element.
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))
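

# A minimal sketch of how a new op could be added to this registry, following
# the decorator pattern used throughout this file. "MyElementwiseOp" is a
# hypothetical op name used purely for illustration, with an assumed cost of
# one flop per element; a real registration would also add the op name to
# IMPLEMENTED_OPS above.
@ops.RegisterStatistics("MyElementwiseOp", "flops")
def _my_elementwise_op_flops(graph, node):
  """Compute flops for the hypothetical MyElementwiseOp operation."""
  return _unary_op_flops(graph, node)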