• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import os
17
18from tensorflow.python.client import session
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import test_util
21from tensorflow.python.ops import variables
22from tensorflow.python.platform import gfile
23from tensorflow.python.platform import test
24from tensorflow.python.profiler import option_builder
25
26# pylint: disable=g-bad-import-order
27from tensorflow.python.profiler import profile_context
28from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
29
# Short module-level alias for the profiler option builder; the tests below
# call it to assemble profiling options.
builder = option_builder.ProfileOptionBuilder
31
32
class ProfilerContextTest(test.TestCase):
  """Tests for `profile_context.ProfileContext`.

  Covers scheduled auto-profiling with file output, debug-mode run-metadata
  dumps, and enabling/disabling the context.
  """

  @test_util.run_deprecated_v1
  def testBasics(self):
    """Auto-profiles the "op" view at steps 15/50/100 and replays step 100.

    Checks that the report file `outfile` appears at each scheduled step, that
    a `profile_100` file is written, and that profiling again from that file
    reproduces the same report text.
    """
    ops.reset_default_graph()
    temp_dir = test.get_temp_dir()
    outfile = os.path.join(temp_dir, "dump")
    opts = builder(builder.time_and_memory()).with_file_output(outfile).build()

    x = lib.BuildFullModel()

    profile_str = None
    profile_step100 = os.path.join(temp_dir, "profile_100")
    with profile_context.ProfileContext(temp_dir) as pctx:
      pctx.add_auto_profiling("op", options=opts, profile_steps=[15, 50, 100])
      # The session is only needed as the default session for self.evaluate.
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        total_steps = 101
        for i in range(total_steps):
          self.evaluate(x)
          # i is 0-based; the checks at i == 14, 49, 99 pair with
          # profile_steps 15, 50, 100 (presumably 1-based step counting).
          if i in (14, 49):
            self.assertTrue(gfile.Exists(outfile))
            gfile.Remove(outfile)
          if i == 99:
            self.assertTrue(gfile.Exists(profile_step100))
            # Keep the step-100 report to compare against the replay below.
            with gfile.Open(outfile, "r") as f:
              profile_str = f.read()
            gfile.Remove(outfile)

      self.assertEqual({15, 50, 100}, set(pctx.get_profiles("op").keys()))

    # Re-profile from the saved step-100 profile; the report must match.
    with lib.ProfilerFromFile(profile_step100) as profiler:
      profiler.profile_operations(options=opts)
      with gfile.Open(outfile, "r") as f:
        self.assertEqual(profile_str, f.read())

  @test_util.run_deprecated_v1
  def testAutoTracingInDeubMode(self):
    """Debug mode writes `run_meta_11` after a 10-run warm-up, then stops.

    NOTE(review): "Deub" is a typo for "Debug"; the name is kept unchanged so
    existing test filters that reference it keep working.
    """
    ops.reset_default_graph()
    temp_dir = test.get_temp_dir()
    x = lib.BuildFullModel()

    with profile_context.ProfileContext(temp_dir, debug=True):
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
          # Warm up, no tracing.
          for f in gfile.ListDirectory(temp_dir):
            self.assertNotIn("run_meta", f)
        self.evaluate(x)
        run_meta_11 = os.path.join(temp_dir, "run_meta_11")
        self.assertTrue(gfile.Exists(run_meta_11))
        gfile.Remove(run_meta_11)
        # fetched already.
        self.evaluate(x)
        for f in gfile.ListDirectory(temp_dir):
          self.assertNotIn("run_meta", f)

  @test_util.run_deprecated_v1
  def testDisabled(self):
    """With enabled=False no profiler is attached; by default one is."""
    ops.reset_default_graph()
    temp_dir = test.get_temp_dir()
    x = lib.BuildFullModel()
    with profile_context.ProfileContext(temp_dir, enabled=False) as pctx:
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
      self.assertIsNone(pctx.profiler)
      self.assertIsNone(getattr(session.BaseSession, "profile_context", None))

    with profile_context.ProfileContext(temp_dir) as pctx:
      with session.Session():
        self.evaluate(variables.global_variables_initializer())
        for _ in range(10):
          self.evaluate(x)
      self.assertIsNotNone(pctx.profiler)
      self.assertIsNotNone(
          getattr(session.BaseSession, "profile_context", None))
113
114
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
117