# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `profile_context.ProfileContext`."""

import os

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.profiler import option_builder

# pylint: disable=g-bad-import-order
from tensorflow.python.profiler import profile_context
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib

builder = option_builder.ProfileOptionBuilder


class ProfilerContextTest(test.TestCase):
  """Covers auto-profiling, debug-mode tracing and enabling/disabling."""

  @test_util.run_deprecated_v1
  def testBasics(self):
    """Auto-profiles the "op" view at steps 15, 50 and 100.

    Each profiled step writes its text report to `outfile`. Step 100 must also
    produce a `profile_100` file. Reloading that file with
    `lib.ProfilerFromFile` and profiling again must reproduce the text report
    captured at step 100.
    """
    ops.reset_default_graph()
    # Per-test directory: files written here (or by the other tests in this
    # class) must not leak into another test's existence checks.
    tmp_dir = self.get_temp_dir()
    outfile = os.path.join(tmp_dir, "dump")
    opts = builder(builder.time_and_memory()).with_file_output(outfile).build()

    x = lib.BuildFullModel()

    profile_str = None
    profile_step100 = os.path.join(tmp_dir, "profile_100")
    with profile_context.ProfileContext(tmp_dir) as pctx:
      pctx.add_auto_profiling("op", options=opts, profile_steps=[15, 50, 100])
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        total_steps = 101
        for i in range(total_steps):
          self.evaluate(x)
          # The test expects profile steps 15 / 50 / 100 to have been dumped
          # by the time the loop index reaches 14 / 49 / 99.
          if i in (14, 49):
            self.assertTrue(gfile.Exists(outfile))
            gfile.Remove(outfile)
          if i == 99:
            self.assertTrue(gfile.Exists(profile_step100))
            # Keep the step-100 report to compare with the reloaded profile.
            with gfile.Open(outfile, "r") as f:
              profile_str = f.read()
            gfile.Remove(outfile)

      self.assertEqual({15, 50, 100}, set(pctx.get_profiles("op").keys()))

    # Profiling the saved step-100 file must reproduce the same text report.
    with lib.ProfilerFromFile(profile_step100) as profiler:
      profiler.profile_operations(options=opts)
      with gfile.Open(outfile, "r") as f:
        self.assertEqual(profile_str, f.read())

  @test_util.run_deprecated_v1
  def testAutoTracingInDeubMode(self):
    """Checks which runs dump `run_meta*` files when `debug=True`.

    The first 10 runs of `x` must not dump any run metadata. The next run must
    dump `run_meta_11`. A further run of the same fetch must not dump another.
    """
    ops.reset_default_graph()
    tmp_dir = self.get_temp_dir()
    run_meta_11 = os.path.join(tmp_dir, "run_meta_11")
    x = lib.BuildFullModel()

    with profile_context.ProfileContext(tmp_dir, debug=True):
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
          # Warm up, no tracing.
          for f in gfile.ListDirectory(tmp_dir):
            self.assertNotIn("run_meta", f)
        self.evaluate(x)
        self.assertTrue(gfile.Exists(run_meta_11))
        gfile.Remove(run_meta_11)
        # This fetch was already traced once; the test expects no new
        # run_meta dump for it.
        self.evaluate(x)
        for f in gfile.ListDirectory(tmp_dir):
          self.assertNotIn("run_meta", f)

  @test_util.run_deprecated_v1
  def testDisabled(self):
    """Checks that `enabled=False` creates no profiler and no session hook.

    With `enabled=False`, `pctx.profiler` is None and
    `session.BaseSession.profile_context` is absent or None while the context
    is active. With the default settings, both are non-None.
    """
    ops.reset_default_graph()
    tmp_dir = self.get_temp_dir()
    x = lib.BuildFullModel()
    with profile_context.ProfileContext(tmp_dir, enabled=False) as pctx:
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
      self.assertIsNone(pctx.profiler)
      self.assertIsNone(getattr(session.BaseSession, "profile_context", None))

    with profile_context.ProfileContext(tmp_dir) as pctx:
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
      self.assertIsNotNone(pctx.profiler)
      self.assertIsNotNone(
          getattr(session.BaseSession, "profile_context", None))


if __name__ == "__main__":
  test.main()