# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Benchmarks for remote worker eager execution.

To run CPU benchmarks:
  bazel run -c opt remote_benchmarks_test -- --benchmarks=.

To run GPU benchmarks:
  bazel run --config=cuda -c opt --copt="-mavx" remote_benchmarks_test -- \
    --benchmarks=.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import time

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib


def run_benchmark(func, num_iters, execution_mode=None):
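  """Times `num_iters` calls to `func` and returns total elapsed seconds.

  `func` is called once before timing starts, as a warm-up (e.g. to trigger
  GPU initialization and function tracing). In ASYNC execution mode the
  executor is drained before and after the timed loop so that pending ops
  are included in the measured time.
  """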
  ctx = context.context()
  with context.execution_mode(execution_mode):
    # call func to maybe warm up the GPU
    func()
    if execution_mode == context.ASYNC:
      ctx.executor.wait()
    start = time.time()
    for _ in xrange(num_iters):
      func()
    if execution_mode == context.ASYNC:
      ctx.executor.wait()
    end = time.time()

    return end - start


class Foo(object):
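  """Callable that lazily creates `num_vars` scalar variables on its first
  call and multiplies its input by each of them."""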

  def __init__(self, num_vars):
    self._num_vars = num_vars
    self._v = []

  def __call__(self, inputs):
    if not self._v:
      for _ in range(self._num_vars):
        self._v.append(variables.Variable(
            random_ops.random_uniform([]), shape=[]))
    for v in self._v:
      inputs = inputs * v
    return inputs


class RemoteWorkerMicroBenchmarks(test.Benchmark):
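  """Microbenchmarks for eager execution against remote workers.

  Each benchmark connects to one or two in-process gRPC servers (created in
  `__init__`) and reports the mean wall time per iteration in microseconds.
  """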

  def __init__(self):
    # used for remote benchmarks
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server_target1 = self._cached_server1.target[len("grpc://"):]
    self._cached_server2 = server_lib.Server.create_local_server()
    self._cached_server_target2 = self._cached_server2.target[len("grpc://"):]

  def _run(self, func, num_iters=1000, execution_mode=context.ASYNC):
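    """Times `func` and reports mean microseconds/call and examples/sec."""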
    total_time = run_benchmark(func, num_iters, execution_mode)
    mean_us = total_time * 1e6 / num_iters
    self.report_benchmark(
        iters=num_iters,
        wall_time=mean_us,
        extras={"examples_per_sec": num_iters / total_time})

  def benchmark_send_mirroring_off(self):
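    # Measures sending a client-side tensor to a remote matmul function with
    # tensor mirroring disabled, so the input is expected to be copied to the
    # worker on every call.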
    remote.connect_to_remote_host(self._cached_server_target1)

    x = random_ops.random_uniform((2, 2)).cpu()

    @def_function.function
    def remote_func(m):
      return math_ops.matmul(m, m)

    def func(m):
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        return remote_func(m)

    context.context().mirroring_policy = context.MIRRORING_NONE
    self._run(lambda: func(x))
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()

  def benchmark_send_mirroring_on(self):
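    # Same as benchmark_send_mirroring_off, but with MIRRORING_ALL so the
    # worker-side copy of the input tensor can be reused across calls.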
    remote.connect_to_remote_host(self._cached_server_target1)

    x = random_ops.random_uniform((2, 2)).cpu()

    @def_function.function
    def remote_func(m):
      return math_ops.matmul(m, m)

    def func(m):
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        return remote_func(m)

    context.context().mirroring_policy = context.MIRRORING_ALL
    self._run(lambda: func(x))
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()

  def benchmark_worker_mirroring_off(self):
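    # Measures calling a function on task:0 that reads a variable placed on
    # task:1, with mirroring of remote tensor handles disabled.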
    remote.connect_to_remote_host(
        [self._cached_server_target1, self._cached_server_target2])

    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
      v = variables.Variable(1.0)

    @def_function.function
    def remote_func():
      return 1.0 + v

    def func():
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        return remote_func()

    context.context().mirroring_policy = context.MIRRORING_NONE
    self._run(func)
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()

  def benchmark_worker_mirroring_on(self):
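    # Same as benchmark_worker_mirroring_off, but with MIRRORING_ALL enabled.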
    remote.connect_to_remote_host(
        [self._cached_server_target1, self._cached_server_target2])

    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
      v = variables.Variable(1.0)

    @def_function.function
    def remote_func():
      return 1.0 + v

    def func():
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        return remote_func()

    context.context().mirroring_policy = context.MIRRORING_ALL
    self._run(func)
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()

  def benchmark_create_vars_inside_function(self):
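    # Measures building a fresh `Foo` layer (50 lazily created variables) on
    # the remote worker inside a traced function on every iteration.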
    remote.connect_to_remote_host(self._cached_server_target1)

    def func():
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        layer = Foo(50)

        @def_function.function
        def remote_func():
          with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            return layer(random_ops.random_uniform([]))

        return remote_func()

    self._run(func, execution_mode=context.ASYNC, num_iters=100)
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()


if __name__ == "__main__":
  test.main()