1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Benchmarks for remote worker eager execution. 16 17To run CPU benchmarks: 18 bazel run -c opt remote_benchmarks_test -- --benchmarks=. 19 20To run GPU benchmarks: 21 bazel run --config=cuda -c opt --copt="-mavx" remote_benchmarks_test -- \ 22 --benchmarks=. 23""" 24 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28 29import gc 30import time 31 32from six.moves import xrange # pylint: disable=redefined-builtin 33 34from tensorflow.python.eager import context 35from tensorflow.python.eager import def_function 36from tensorflow.python.eager import remote 37from tensorflow.python.eager import test 38from tensorflow.python.framework import ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import random_ops 41from tensorflow.python.ops import variables 42from tensorflow.python.training import server_lib 43 44 45def run_benchmark(func, num_iters, execution_mode=None): 46 ctx = context.context() 47 with context.execution_mode(execution_mode): 48 # call func to maybe warm up the GPU 49 func() 50 if execution_mode == context.ASYNC: 51 ctx.executor.wait() 52 start = time.time() 53 for _ in xrange(num_iters): 54 func() 55 if execution_mode == context.ASYNC: 56 ctx.executor.wait() 57 end = time.time() 58 59 return end - start 60 61 62class Foo(object): 63 64 def __init__(self, num_vars): 65 self._num_vars = num_vars 66 self._v = [] 67 68 def __call__(self, inputs): 69 if not self._v: 70 for _ in range(self._num_vars): 71 self._v.append(variables.Variable( 72 random_ops.random_uniform([]), shape=[])) 73 for v in self._v: 74 inputs = inputs * v 75 return inputs 76 77 78class RemoteWorkerMicroBenchmarks(test.Benchmark): 79 80 def __init__(self): 81 # used for remote benchmarks 82 self._cached_server1 = server_lib.Server.create_local_server() 83 self._cached_server_target1 = self._cached_server1.target[len("grpc://"):] 84 self._cached_server2 = server_lib.Server.create_local_server() 85 self._cached_server_target2 = self._cached_server2.target[len("grpc://"):] 86 87 def _run(self, func, num_iters=1000, execution_mode=context.ASYNC): 88 total_time = run_benchmark(func, num_iters, execution_mode) 89 mean_us = total_time * 1e6 / num_iters 90 self.report_benchmark( 91 iters=num_iters, 92 wall_time=mean_us, 93 extras={"examples_per_sec": num_iters / total_time}) 94 95 def benchmark_send_mirroring_off(self): 96 remote.connect_to_remote_host(self._cached_server_target1) 97 98 x = random_ops.random_uniform((2, 2)).cpu() 99 100 @def_function.function 101 def remote_func(m): 102 return math_ops.matmul(m, m) 103 104 def func(m): 105 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 106 return remote_func(m) 107 108 context.context().mirroring_policy = context.MIRRORING_NONE 109 self._run(lambda: func(x)) 110 # NOTE(b/136184459): Force garbage collecting hanging resources before 111 # subsequent calls to set_server_def, to ensure the destroy resource ops are 112 # executed when their corresponding device and manager are still available. 113 gc.collect() 114 115 def benchmark_send_mirroring_on(self): 116 remote.connect_to_remote_host(self._cached_server_target1) 117 118 x = random_ops.random_uniform((2, 2)).cpu() 119 120 @def_function.function 121 def remote_func(m): 122 return math_ops.matmul(m, m) 123 124 def func(m): 125 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 126 return remote_func(m) 127 128 context.context().mirroring_policy = context.MIRRORING_ALL 129 self._run(lambda: func(x)) 130 # NOTE(b/136184459): Force garbage collecting hanging resources before 131 # subsequent calls to set_server_def, to ensure the destroy resource ops are 132 # executed when their corresponding device and manager are still available. 133 gc.collect() 134 135 def benchmark_worker_mirroring_off(self): 136 remote.connect_to_remote_host( 137 [self._cached_server_target1, self._cached_server_target2]) 138 139 with ops.device("job:worker/replica:0/task:1/device:CPU:0"): 140 v = variables.Variable(1.0) 141 142 @def_function.function 143 def remote_func(): 144 return 1.0 + v 145 146 def func(): 147 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 148 return remote_func() 149 150 context.context().mirroring_policy = context.MIRRORING_NONE 151 self._run(func) 152 # NOTE(b/136184459): Force garbage collecting hanging resources before 153 # subsequent calls to set_server_def, to ensure the destroy resource ops are 154 # executed when their corresponding device and manager are still available. 155 gc.collect() 156 157 def benchmark_worker_mirroring_on(self): 158 remote.connect_to_remote_host( 159 [self._cached_server_target1, self._cached_server_target2]) 160 161 with ops.device("job:worker/replica:0/task:1/device:CPU:0"): 162 v = variables.Variable(1.0) 163 164 @def_function.function 165 def remote_func(): 166 return 1.0 + v 167 168 def func(): 169 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 170 return remote_func() 171 172 context.context().mirroring_policy = context.MIRRORING_ALL 173 self._run(func) 174 # NOTE(b/136184459): Force garbage collecting hanging resources before 175 # subsequent calls to set_server_def, to ensure the destroy resource ops are 176 # executed when their corresponding device and manager are still available. 177 gc.collect() 178 179 def benchmark_create_vars_inside_function(self): 180 remote.connect_to_remote_host(self._cached_server_target1) 181 182 def func(): 183 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 184 layer = Foo(50) 185 186 @def_function.function 187 def remote_func(): 188 with ops.device("job:worker/replica:0/task:0/device:CPU:0"): 189 return layer(random_ops.random_uniform([])) 190 191 return remote_func() 192 193 self._run(func, execution_mode=context.ASYNC, num_iters=100) 194 # NOTE(b/136184459): Force garbage collecting hanging resources before 195 # subsequent calls to set_server_def, to ensure the destroy resource ops are 196 # executed when their corresponding device and manager are still available. 197 gc.collect() 198 199 200if __name__ == "__main__": 201 test.main() 202