• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Benchmark for control flow ops."""
16
17import time
18
19from tensorflow.python.client import session
20from tensorflow.python.eager import context
21from tensorflow.python.eager import function
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import control_flow_util
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import random_ops
29from tensorflow.python.platform import test
30
31
class CondWithManyIntermediatesBenchmark(test.Benchmark):
  """Checks the runtime performance of outputting all intermediates.

  Each benchmark builds one `cond` whose taken branch adds together
  `NUM_INTERMEDIATES` random values, runs it `NUM_WARM_UP_ITERS` times untimed,
  then times `NUM_ITERS` runs. The public `benchmark_*` methods cover control
  flow v1/v2 crossed with defun/legacy-graph execution.
  """

  NUM_INTERMEDIATES = 1000
  NUM_ITERS = 500
  NUM_WARM_UP_ITERS = 50

  def _create_cond(self, x):
    """Builds a cond whose taken branch creates many intermediate tensors.

    Args:
      x: Float input; callers pass `0.0` (defun) or a float32 placeholder
        (graph mode).

    Returns:
      The output of `control_flow_ops.cond`.
    """

    def branch_fn():
      # Use a random value so the adds can't be constant folded.
      return x + sum(random_ops.random_normal([])
                     for _ in range(self.NUM_INTERMEDIATES))

    # Use a dynamic predicate to make sure the cond isn't constant folded.
    # Callers always feed 0.0, so the predicate is true and `branch_fn` runs.
    return control_flow_ops.cond(math_ops.not_equal(x, -1),
                                 branch_fn, lambda: 0.0)

  def _time_and_report(self, run_once):
    """Warms up, times `NUM_ITERS` calls of `run_once`, and reports the result.

    Args:
      run_once: Zero-argument callable that executes one benchmark iteration.
    """
    # Warm up
    for _ in range(self.NUM_WARM_UP_ITERS):
      run_once()

    start_time = time.time()

    for _ in range(self.NUM_ITERS):
      run_once()

    # NOTE(review): `wall_time` is the total for all NUM_ITERS runs, not a
    # per-iteration figure; confirm this matches how results are consumed.
    self.report_benchmark(
        wall_time=time.time() - start_time,
        iters=self.NUM_ITERS)

  def _benchmark_defun(self):
    """Benchmarks cond in a defun."""

    @function.defun
    def cond_fn(x):
      return self._create_cond(x)

    self._time_and_report(lambda: cond_fn(0.0))

  def _benchmark_graph(self):
    """Benchmarks cond in legacy graph mode."""
    with context.graph_mode():
      with ops.Graph().as_default():
        x = array_ops.placeholder(dtypes.float32)
        cond_val = self._create_cond(x)

        with session.Session() as sess:
          cond_fn = sess.make_callable(cond_val, [x])
          self._time_and_report(lambda: cond_fn(0.0))

  def _run_with_control_flow_v2(self, enabled, benchmark_fn):
    """Runs `benchmark_fn` with control flow v2 set to `enabled`.

    `ENABLE_CONTROL_FLOW_V2` is a module-level attribute shared by every
    benchmark in the process, so it is restored in a `finally` block; a
    failing benchmark must not leak its setting into later benchmarks.

    Args:
      enabled: Value to assign to `control_flow_util.ENABLE_CONTROL_FLOW_V2`
        while `benchmark_fn` runs.
      benchmark_fn: Zero-argument callable that runs the benchmark.
    """
    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = enabled
    try:
      benchmark_fn()
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val

  def benchmark_cond_v1_defun(self):
    self._run_with_control_flow_v2(False, self._benchmark_defun)

  def benchmark_cond_v2_defun(self):
    self._run_with_control_flow_v2(True, self._benchmark_defun)

  def benchmark_cond_v1_graph(self):
    self._run_with_control_flow_v2(False, self._benchmark_graph)

  def benchmark_cond_v2_graph(self):
    self._run_with_control_flow_v2(True, self._benchmark_graph)
116
if __name__ == "__main__":
  # The defun benchmarks call their function directly with no graph context,
  # while the legacy-graph benchmarks explicitly enter `context.graph_mode()`,
  # so eager execution is enabled for the whole process before running tests.
  ops.enable_eager_execution()
  test.main()
120