1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for pprof_profiler.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gzip 22 23from proto import profile_pb2 24from tensorflow.core.framework import step_stats_pb2 25from tensorflow.core.protobuf import config_pb2 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import test 31from tensorflow.python.profiler import pprof_profiler 32 33 34class PprofProfilerTest(test.TestCase): 35 36 def testDataEmpty(self): 37 output_dir = test.get_temp_dir() 38 run_metadata = config_pb2.RunMetadata() 39 graph = test.mock.MagicMock() 40 graph.get_operations.return_value = [] 41 42 profiles = pprof_profiler.get_profiles(graph, run_metadata) 43 self.assertEqual(0, len(profiles)) 44 profile_files = pprof_profiler.profile( 45 graph, run_metadata, output_dir) 46 self.assertEqual(0, len(profile_files)) 47 48 def testRunMetadataEmpty(self): 49 output_dir = test.get_temp_dir() 50 run_metadata = config_pb2.RunMetadata() 51 graph = test.mock.MagicMock() 52 op1 = test.mock.MagicMock() 53 op1.name = 'Add/123' 54 op1.traceback = [('a/b/file1', 10, 'some_var')] 55 op1.type = 'add' 56 graph.get_operations.return_value = [op1] 57 58 profiles = pprof_profiler.get_profiles(graph, run_metadata) 59 self.assertEqual(0, len(profiles)) 60 profile_files = pprof_profiler.profile( 61 graph, run_metadata, output_dir) 62 self.assertEqual(0, len(profile_files)) 63 64 def testValidProfile(self): 65 output_dir = test.get_temp_dir() 66 run_metadata = config_pb2.RunMetadata() 67 68 node1 = step_stats_pb2.NodeExecStats( 69 node_name='Add/123', 70 op_start_rel_micros=3, 71 op_end_rel_micros=5, 72 all_end_rel_micros=4) 73 74 run_metadata = config_pb2.RunMetadata() 75 device1 = run_metadata.step_stats.dev_stats.add() 76 device1.device = 'deviceA' 77 device1.node_stats.extend([node1]) 78 79 graph = test.mock.MagicMock() 80 op1 = test.mock.MagicMock() 81 op1.name = 'Add/123' 82 op1.traceback = [ 83 ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')] 84 op1.type = 'add' 85 graph.get_operations.return_value = [op1] 86 87 expected_proto = """sample_type { 88 type: 5 89 unit: 5 90} 91sample_type { 92 type: 6 93 unit: 7 94} 95sample_type { 96 type: 8 97 unit: 7 98} 99sample { 100 value: 1 101 value: 4 102 value: 2 103 label { 104 key: 1 105 str: 2 106 } 107 label { 108 key: 3 109 str: 4 110 } 111} 112string_table: "" 113string_table: "node_name" 114string_table: "Add/123" 115string_table: "op_type" 116string_table: "add" 117string_table: "count" 118string_table: "all_time" 119string_table: "nanoseconds" 120string_table: "op_time" 121string_table: "Device 1 of 1: deviceA" 122comment: 9 123""" 124 # Test with protos 125 profiles = pprof_profiler.get_profiles(graph, run_metadata) 126 self.assertEqual(1, len(profiles)) 127 self.assertTrue('deviceA' in profiles) 128 self.assertEqual(expected_proto, str(profiles['deviceA'])) 129 # Test with files 130 profile_files = pprof_profiler.profile( 131 graph, run_metadata, output_dir) 132 self.assertEqual(1, len(profile_files)) 133 with gzip.open(profile_files[0]) as profile_file: 134 profile_contents = profile_file.read() 135 profile = profile_pb2.Profile() 136 profile.ParseFromString(profile_contents) 137 self.assertEqual(expected_proto, str(profile)) 138 139 @test_util.run_v1_only('b/120545219') 140 def testProfileWithWhileLoop(self): 141 options = config_pb2.RunOptions() 142 options.trace_level = config_pb2.RunOptions.FULL_TRACE 143 run_metadata = config_pb2.RunMetadata() 144 145 num_iters = 5 146 with self.cached_session() as sess: 147 i = constant_op.constant(0) 148 c = lambda i: math_ops.less(i, num_iters) 149 b = lambda i: math_ops.add(i, 1) 150 r = control_flow_ops.while_loop(c, b, [i]) 151 sess.run(r, options=options, run_metadata=run_metadata) 152 profiles = pprof_profiler.get_profiles(sess.graph, run_metadata) 153 self.assertEqual(1, len(profiles)) 154 profile = next(iter(profiles.values())) 155 add_samples = [] # Samples for the while/Add node 156 for sample in profile.sample: 157 if profile.string_table[sample.label[0].str] == 'while/Add': 158 add_samples.append(sample) 159 # Values for same nodes are aggregated. 160 self.assertEqual(1, len(add_samples)) 161 # Value of "count" should be equal to number of iterations. 162 self.assertEqual(num_iters, add_samples[0].value[0]) 163 164 165if __name__ == '__main__': 166 test.main() 167