• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for pprof_profiler."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gzip
22
23from proto import profile_pb2
24from tensorflow.core.framework import step_stats_pb2
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31from tensorflow.python.profiler import pprof_profiler
32
33
class PprofProfilerTest(test.TestCase):
  """Tests for `pprof_profiler.get_profiles` and `pprof_profiler.profile`."""

  def testDataEmpty(self):
    """A graph with no operations and empty RunMetadata yields no profiles."""
    # A per-test temp dir keeps files written by one test away from the others.
    output_dir = self.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    graph.get_operations.return_value = []

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testRunMetadataEmpty(self):
    """Graph ops with no step stats recorded in RunMetadata yield no profiles."""
    output_dir = self.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [('a/b/file1', 10, 'some_var')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testValidProfile(self):
    """One node executed on one device produces one matching profile proto."""
    output_dir = self.get_temp_dir()

    node1 = step_stats_pb2.NodeExecStats(
        node_name='Add/123',
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = 'deviceA'
    device1.node_stats.extend([node1])

    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [
        ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    # Integers in sample_type, label and comment are indices into
    # string_table. The three sample values line up with the sample types
    # (count, all_time, op_time). 4 and 2 appear to come from node1's
    # all_end_rel_micros and op_end_rel_micros - op_start_rel_micros.
    expected_proto = """sample_type {
  type: 5
  unit: 5
}
sample_type {
  type: 6
  unit: 7
}
sample_type {
  type: 8
  unit: 7
}
sample {
  value: 1
  value: 4
  value: 2
  label {
    key: 1
    str: 2
  }
  label {
    key: 3
    str: 4
  }
}
string_table: ""
string_table: "node_name"
string_table: "Add/123"
string_table: "op_type"
string_table: "add"
string_table: "count"
string_table: "all_time"
string_table: "nanoseconds"
string_table: "op_time"
string_table: "Device 1 of 1: deviceA"
comment: 9
"""
    # Test with protos. Compare messages structurally instead of comparing
    # str() output, which depends on protobuf's text rendering.
    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(1, len(profiles))
    self.assertIn('deviceA', profiles)
    self.assertProtoEquals(expected_proto, profiles['deviceA'])

    # Test with files: each written file is gzip-compressed and holds a
    # serialized Profile message.
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(1, len(profile_files))
    with gzip.open(profile_files[0]) as profile_file:
      profile = profile_pb2.Profile()
      profile.ParseFromString(profile_file.read())
    self.assertProtoEquals(expected_proto, profile)

  @test_util.run_v1_only('b/120545219')
  def testProfileWithWhileLoop(self):
    """Stats of a node executed on every loop iteration are aggregated."""
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    num_iters = 5
    with self.cached_session() as sess:
      i = constant_op.constant(0)

      def cond(loop_var):
        return math_ops.less(loop_var, num_iters)

      def body(loop_var):
        return math_ops.add(loop_var, 1)

      r = control_flow_ops.while_loop(cond, body, [i])
      sess.run(r, options=options, run_metadata=run_metadata)
      profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
      self.assertEqual(1, len(profiles))
      profile = next(iter(profiles.values()))
      # label[0] holds the node name (the first label emitted per sample, as
      # in testValidProfile's expected proto); keep only the while/Add node.
      add_samples = [
          sample for sample in profile.sample
          if profile.string_table[sample.label[0].str] == 'while/Add'
      ]
      # Values for same nodes are aggregated into a single sample.
      self.assertEqual(1, len(add_samples))
      # value[0] is "count", which should equal the number of iterations.
      self.assertEqual(num_iters, add_samples[0].value[0])
163
164
if __name__ == '__main__':
  # Run directly (not imported): hand control to TensorFlow's test runner.
  test.main()
167