# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for control flow ops."""

import time

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test


class CondWithManyIntermediatesBenchmark(test.Benchmark):
  """Checks the runtime performance of outputting all intermediates."""

  # Number of random values summed inside the taken branch of the cond.
  NUM_INTERMEDIATES = 1000
  # Number of timed invocations per benchmark.
  NUM_ITERS = 500
  # Number of untimed invocations executed before timing starts.
  NUM_WARM_UP_ITERS = 50

  def _create_cond(self, x):
    """Builds a cond whose true branch sums `NUM_INTERMEDIATES` random values.

    Args:
      x: Value compared against -1 to form the predicate and added to the
        sum in the true branch.

    Returns:
      The result of `control_flow_ops.cond`; the false branch yields 0.0.
    """

    def branch_fn():
      # Use a random value so the adds can't be constant folded.
      return x + sum(random_ops.random_normal([])
                     for _ in range(self.NUM_INTERMEDIATES))

    # Use a dynamic predicate to make sure the cond isn't constant folded.
    return control_flow_ops.cond(math_ops.not_equal(x, -1),
                                 branch_fn, lambda: 0.0)

  def _benchmark_defun(self):
    """Benchmarks cond in a defun."""

    @function.defun
    def cond_fn(x):
      return self._create_cond(x)

    # Warm up
    for _ in range(self.NUM_WARM_UP_ITERS):
      cond_fn(0.0)

    start_time = time.time()

    for _ in range(self.NUM_ITERS):
      cond_fn(0.0)

    self.report_benchmark(
        wall_time=time.time() - start_time,
        iters=self.NUM_ITERS)

  def _benchmark_graph(self):
    """Benchmarks cond in legacy graph mode."""
    with context.graph_mode():
      with ops.Graph().as_default():
        x = array_ops.placeholder(dtypes.float32)
        cond_val = self._create_cond(x)

        with session.Session() as sess:
          cond_fn = sess.make_callable(cond_val, [x])

          # Warm up
          for _ in range(self.NUM_WARM_UP_ITERS):
            cond_fn(0.0)

          start_time = time.time()

          for _ in range(self.NUM_ITERS):
            cond_fn(0.0)

          self.report_benchmark(
              wall_time=time.time() - start_time,
              iters=self.NUM_ITERS)

  def _run_with_control_flow_v2(self, enabled, benchmark_fn):
    """Runs `benchmark_fn` with `ENABLE_CONTROL_FLOW_V2` temporarily overridden.

    The previous value of the module-level flag is restored in a `finally`
    clause, so an exception raised by the benchmark does not leave the
    override in place for whatever runs next in the same process.

    Args:
      enabled: Value assigned to `control_flow_util.ENABLE_CONTROL_FLOW_V2`
        while `benchmark_fn` runs.
      benchmark_fn: Zero-argument callable that performs the benchmark.
    """
    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = enabled
    try:
      benchmark_fn()
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val

  def benchmark_cond_v1_defun(self):
    """Benchmarks cond in a defun with `ENABLE_CONTROL_FLOW_V2` set to False."""
    self._run_with_control_flow_v2(False, self._benchmark_defun)

  def benchmark_cond_v2_defun(self):
    """Benchmarks cond in a defun with `ENABLE_CONTROL_FLOW_V2` set to True."""
    self._run_with_control_flow_v2(True, self._benchmark_defun)

  def benchmark_cond_v1_graph(self):
    """Benchmarks cond in graph mode with `ENABLE_CONTROL_FLOW_V2` False."""
    self._run_with_control_flow_v2(False, self._benchmark_graph)

  def benchmark_cond_v2_graph(self):
    """Benchmarks cond in graph mode with `ENABLE_CONTROL_FLOW_V2` True."""
    self._run_with_control_flow_v2(True, self._benchmark_graph)


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()