# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.x profiler."""

import glob
import os
import threading

import portpicker

from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler.integration_test import mnist_testing_utils


def _model_setup():
  """Set up a MNIST Keras model for testing purposes.

  Builds a MNIST Keras model, together with a synthetic training dataset,
  under a collective all-reduce strategy scope.

  Returns:
    A tuple of (batch_size, steps, train_dataset, model).
  """
  context.set_log_device_placement(True)
  batch_size = 64
  steps = 2
  with collective_strategy.CollectiveAllReduceStrategy().scope():
    # TODO(b/142509827): In rare cases this errors out at C++ level with the
    # "Connect failed" error message.
    train_ds, _ = mnist_testing_utils.mnist_synthetic_dataset(batch_size, steps)
    model = mnist_testing_utils.get_mnist_model((28, 28, 1))
  return batch_size, steps, train_ds, model


def _make_temp_log_dir(test_obj):
  """Returns a temporary directory of `test_obj` to use as a profiler logdir."""
  return test_obj.get_temp_dir()


class ProfilerApiTest(test_util.TensorFlowTestCase):
  """Integration tests for the TF 2.x profiler APIs (sampling + programmatic)."""

  def setUp(self):
    super().setUp()
    # Set by the worker thread once its profiler server is up, so the
    # profiler thread knows there is something to connect to.
    self.worker_start = threading.Event()
    # Set by the profiler thread when tracing finished, so the worker
    # thread's training loop can exit.
    self.profile_done = False

  def _check_tools_pb_exist(self, logdir):
    """Asserts exactly one of each expected tool-data .pb file is in `logdir`."""
    expected_files = [
        'overview_page.pb',
        'input_pipeline.pb',
        'tensorflow_stats.pb',
        'kernel_stats.pb',
    ]
    for filename in expected_files:
      path = os.path.join(logdir, 'plugins', 'profile', '*', f'*{filename}')
      self.assertEqual(1, len(glob.glob(path)),
                       'Expected one path match: ' + path)

  def _check_xspace_pb_exist(self, logdir):
    """Asserts exactly one .xplane.pb (raw XSpace proto) exists in `logdir`."""
    path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
    self.assertEqual(1, len(glob.glob(path)),
                     'Expected one path match: ' + path)

  def test_single_worker_no_profiling(self):
    """Test single worker without profiling."""

    _, steps, train_ds, model = _model_setup()

    model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

  def test_single_worker_sampling_mode(self, delay_ms=None):
    """Test single worker sampling mode.

    Starts a profiler server on a worker thread that trains in a loop, then
    remotely traces it from a second thread via `profiler_client.trace`.

    Args:
      delay_ms: Optional requested delay, in milliseconds, before the remote
        profiling session starts; forwarded to `ProfilerOptions`.
    """

    def on_worker(port, worker_start):
      logging.info(f'worker starting server on {port}')
      profiler.start_server(port)
      _, steps, train_ds, model = _model_setup()
      worker_start.set()
      # Keep training until the profiler thread reports it is done, so the
      # trace is guaranteed to overlap with active computation.
      while True:
        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
        if self.profile_done:
          break

    def on_profile(port, logdir, worker_start):
      # Request for 30 milliseconds of profile.
      duration_ms = 30

      worker_start.wait()
      options = profiler.ProfilerOptions(
          host_tracer_level=2,
          python_tracer_level=0,
          device_tracer_level=1,
          delay_ms=delay_ms,
      )

      profiler_client.trace(f'localhost:{port}', logdir, duration_ms,
                            '', 100, options)

      self.profile_done = True

    logdir = self.get_temp_dir()
    port = portpicker.pick_unused_port()
    thread_profiler = threading.Thread(
        target=on_profile, args=(port, logdir, self.worker_start))
    thread_worker = threading.Thread(
        target=on_worker, args=(port, self.worker_start))
    thread_worker.start()
    thread_profiler.start()
    thread_profiler.join()
    # Bounded join: the worker loop exits only after profile_done is set.
    thread_worker.join(120)
    self._check_xspace_pb_exist(logdir)

  def test_single_worker_sampling_mode_short_delay(self):
    """Test single worker sampling mode with a short delay.

    Expect that requested delayed start time will arrive late, and a subsequent
    retry will issue an immediate start.
    """

    self.test_single_worker_sampling_mode(delay_ms=1)

  def test_single_worker_sampling_mode_long_delay(self):
    """Test single worker sampling mode with a long delay."""

    self.test_single_worker_sampling_mode(delay_ms=1000)

  def test_single_worker_programmatic_mode(self):
    """Test single worker programmatic mode."""
    logdir = self.get_temp_dir()

    options = profiler.ProfilerOptions(
        host_tracer_level=2,
        python_tracer_level=0,
        device_tracer_level=1,
    )
    profiler.start(logdir, options)
    _, steps, train_ds, model = _model_setup()
    model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
    profiler.stop()
    self._check_xspace_pb_exist(logdir)
    self._check_tools_pb_exist(logdir)


if __name__ == '__main__':
  multi_process_runner.test_main()