# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.x profiler."""

import glob
import os
import threading

import portpicker

from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler.integration_test import mnist_testing_utils


33def _model_setup():
34  """Set up a MNIST Keras model for testing purposes.
35
36  Builds a MNIST Keras model and returns model information.
37
38  Returns:
39    A tuple of (batch_size, steps, train_dataset, mode)
40  """
41  context.set_log_device_placement(True)
42  batch_size = 64
43  steps = 2
44  with collective_strategy.CollectiveAllReduceStrategy().scope():
45    # TODO(b/142509827): In rare cases this errors out at C++ level with the
46    # "Connect failed" error message.
47    train_ds, _ = mnist_testing_utils.mnist_synthetic_dataset(batch_size, steps)
48    model = mnist_testing_utils.get_mnist_model((28, 28, 1))
49  return batch_size, steps, train_ds, model
50
51
52def _make_temp_log_dir(test_obj):
53  return test_obj.get_temp_dir()
54
55
56class ProfilerApiTest(test_util.TensorFlowTestCase):
57
58  def setUp(self):
59    super().setUp()
60    self.worker_start = threading.Event()
61    self.profile_done = False
62
63  def _check_tools_pb_exist(self, logdir):
64    expected_files = [
65        'overview_page.pb',
66        'input_pipeline.pb',
67        'tensorflow_stats.pb',
68        'kernel_stats.pb',
69    ]
70    for file in expected_files:
71      path = os.path.join(logdir, 'plugins', 'profile', '*', '*{}'.format(file))
72      self.assertEqual(1, len(glob.glob(path)),
73                       'Expected one path match: ' + path)
74
75  def _check_xspace_pb_exist(self, logdir):
76    path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
77    self.assertEqual(1, len(glob.glob(path)),
78                     'Expected one path match: ' + path)
79
80  def test_single_worker_no_profiling(self):
81    """Test single worker without profiling."""
82
83    _, steps, train_ds, model = _model_setup()
84
85    model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
86
87  def test_single_worker_sampling_mode(self, delay_ms=None):
88    """Test single worker sampling mode."""
89
90    def on_worker(port, worker_start):
91      logging.info('worker starting server on {}'.format(port))
92      profiler.start_server(port)
93      _, steps, train_ds, model = _model_setup()
94      worker_start.set()
95      while True:
96        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
97        if self.profile_done:
98          break
99
100    def on_profile(port, logdir, worker_start):
101      # Request for 30 milliseconds of profile.
102      duration_ms = 30
103
104      worker_start.wait()
105      options = profiler.ProfilerOptions(
106          host_tracer_level=2,
107          python_tracer_level=0,
108          device_tracer_level=1,
109          delay_ms=delay_ms,
110      )
111
112      profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
113                            '', 100, options)
114
115      self.profile_done = True
116
117    logdir = self.get_temp_dir()
118    port = portpicker.pick_unused_port()
119    thread_profiler = threading.Thread(
120        target=on_profile, args=(port, logdir, self.worker_start))
121    thread_worker = threading.Thread(
122        target=on_worker, args=(port, self.worker_start))
123    thread_worker.start()
124    thread_profiler.start()
125    thread_profiler.join()
126    thread_worker.join(120)
127    self._check_xspace_pb_exist(logdir)
128
129  def test_single_worker_sampling_mode_short_delay(self):
130    """Test single worker sampling mode with a short delay.
131
132    Expect that requested delayed start time will arrive late, and a subsequent
133    retry will issue an immediate start.
134    """
135
136    self.test_single_worker_sampling_mode(delay_ms=1)
137
138  def test_single_worker_sampling_mode_long_delay(self):
139    """Test single worker sampling mode with a long delay."""
140
141    self.test_single_worker_sampling_mode(delay_ms=1000)
142
143  def test_single_worker_programmatic_mode(self):
144    """Test single worker programmatic mode."""
145    logdir = self.get_temp_dir()
146
147    options = profiler.ProfilerOptions(
148        host_tracer_level=2,
149        python_tracer_level=0,
150        device_tracer_level=1,
151    )
152    profiler.start(logdir, options)
153    _, steps, train_ds, model = _model_setup()
154    model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
155    profiler.stop()
156    self._check_xspace_pb_exist(logdir)
157    self._check_tools_pb_exist(logdir)
158
159
160if __name__ == '__main__':
161  multi_process_runner.test_main()
162