• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf 2.x profiler."""
16
17import os
18import socket
19
20from tensorflow.python.eager import test
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import errors
23from tensorflow.python.framework import test_util
24from tensorflow.python.platform import gfile
25from tensorflow.python.profiler import profiler_v2 as profiler
26from tensorflow.python.profiler import trace
27
28
class ProfilerTest(test_util.TensorFlowTestCase):
  """Tests for the TF 2.x programmatic profiler API (profiler_v2)."""

  def _run_traced_product(self):
    """Runs a tiny computation (3 * 5) inside a trace and checks the result."""
    with trace.Trace('three_times_five'):
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
    self.assertAllEqual(15, product)

  def test_profile_exceptions(self):
    """Checks errors from double start/stop and from an unusable logdir."""
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    # Starting while a session is already active must be rejected.
    with self.assertRaises(errors.AlreadyExistsError):
      profiler.start(logdir)

    profiler.stop()
    # Stopping when no session is active must be rejected.
    with self.assertRaises(errors.UnavailableError):
      profiler.stop()

    # Test with a bad logdir: stop() is expected to raise, and the failed
    # session must be torn down so that a fresh one can be started afterwards.
    # The raw string keeps the literal backslashes without relying on invalid
    # escape sequences (which Python 3.12+ reports as a SyntaxWarning).
    profiler.start(r'/dev/null/\/\/:123')
    with self.assertRaises(Exception):
      profiler.stop()
    profiler.start(logdir)
    profiler.stop()

  def test_save_profile(self):
    """Checks the files a start()/stop() session writes under logdir."""
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    try:
      self._run_traced_product()
    finally:
      # Always stop, so a failure above cannot leave a session running and
      # cascade into AlreadyExistsError in later tests.
      profiler.stop()

    # logdir should hold exactly a 'plugins' directory plus one
    # '*.profile-empty' file.
    file_list = gfile.ListDirectory(logdir)
    self.assertEqual(len(file_list), 2)
    for file_name in file_list:
      if gfile.IsDirectory(os.path.join(logdir, file_name)):
        self.assertEqual(file_name, 'plugins')
      else:
        self.assertTrue(file_name.endswith('.profile-empty'))

    # Each converted tool output is named '<hostname><suffix>' inside the
    # (first) run directory under plugins/profile.
    profile_dir = os.path.join(logdir, 'plugins', 'profile')
    run = gfile.ListDirectory(profile_dir)[0]
    hostname = socket.gethostname()
    for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                   '.tensorflow_stats.pb', '.kernel_stats.pb',
                   '.trace.json.gz'):
      artifact = os.path.join(profile_dir, run, hostname + suffix)
      self.assertTrue(gfile.Exists(artifact), msg=artifact)

  def test_profile_with_options(self):
    """Checks start() with explicit ProfilerOptions produces two entries."""
    logdir = self.get_temp_dir()
    options = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    profiler.start(logdir, options)
    try:
      self._run_traced_product()
    finally:
      # Same rationale as in test_save_profile: never leak an active session.
      profiler.stop()

    file_list = gfile.ListDirectory(logdir)
    self.assertEqual(len(file_list), 2)

  def test_context_manager_with_options(self):
    """Checks the Profile context manager with explicit ProfilerOptions."""
    logdir = self.get_temp_dir()
    options = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    with profiler.Profile(logdir, options):
      self._run_traced_product()

    file_list = gfile.ListDirectory(logdir)
    self.assertEqual(len(file_list), 2)
113
114
if '__main__' == __name__:
  # Executed directly (not imported): hand control to the TF eager test runner.
  test.main()
117