# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.x profiler."""

import os
import socket

from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler import trace


class ProfilerTest(test_util.TensorFlowTestCase):
  """Tests for the start/stop and context-manager APIs of `profiler_v2`."""

  def test_profile_exceptions(self):
    """Checks error handling for double start, double stop and a bad logdir."""
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    # A second start while a session is active must be rejected.
    with self.assertRaises(errors.AlreadyExistsError):
      profiler.start(logdir)

    profiler.stop()
    # Stopping when no session is active must be rejected.
    with self.assertRaises(errors.UnavailableError):
      profiler.stop()

    # Test with a bad logdir, and it correctly raises exception and deletes
    # profiler. A raw string yields the same characters as the original
    # literal ('\/' kept verbatim) without relying on an invalid escape
    # sequence, which newer Python versions warn about.
    profiler.start(r'/dev/null/\/\/:123')
    with self.assertRaises(Exception):
      profiler.stop()
    # If the failed session was released, a fresh start/stop must succeed.
    profiler.start(logdir)
    profiler.stop()

  def test_save_profile(self):
    """Profiles a small eager computation and checks the written artifacts."""
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    with trace.Trace('three_times_five'):
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
    self.assertAllEqual(15, product)

    profiler.stop()
    file_list = gfile.ListDirectory(logdir)
    self.assertLen(file_list, 2)
    # Reuse the listing taken above rather than querying the directory again.
    for file_name in file_list:
      if gfile.IsDirectory(os.path.join(logdir, file_name)):
        self.assertEqual(file_name, 'plugins')
      else:
        self.assertTrue(file_name.endswith('.profile-empty'))

    profile_dir = os.path.join(logdir, 'plugins', 'profile')
    runs = gfile.ListDirectory(profile_dir)
    # Fail with a clear assertion instead of an IndexError if no run was saved.
    self.assertNotEmpty(runs)
    run = runs[0]
    hostname = socket.gethostname()
    # Each converted tool output is written as <hostname><suffix> in the run.
    for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                   '.tensorflow_stats.pb', '.kernel_stats.pb',
                   '.trace.json.gz'):
      path = os.path.join(profile_dir, run, hostname + suffix)
      self.assertTrue(gfile.Exists(path), msg='Missing profile file: ' + path)

  def test_profile_with_options(self):
    """Checks that start() accepts ProfilerOptions and writes output."""
    logdir = self.get_temp_dir()
    options = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    profiler.start(logdir, options)
    with trace.Trace('three_times_five'):
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
    self.assertAllEqual(15, product)

    profiler.stop()
    file_list = gfile.ListDirectory(logdir)
    self.assertLen(file_list, 2)

  def test_context_manager_with_options(self):
    """Checks that the Profile context manager accepts options and writes."""
    logdir = self.get_temp_dir()
    options = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    with profiler.Profile(logdir, options):
      with trace.Trace('three_times_five'):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
      self.assertAllEqual(15, product)

    file_list = gfile.ListDirectory(logdir)
    self.assertLen(file_list, 2)


if __name__ == '__main__':
  test.main()