# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.

See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
profile format.
The following needs to be set for profiler to work:
  * trace_level needs to be set to FULL_TRACE
  * run_metadata object should be passed in to session.run call

Sample usage:
  options = tf.compat.v1.RunOptions(
      trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
  run_metadata = tf.compat.v1.RunMetadata()

  with tf.compat.v1.Session() as sess:
    ...
    sess.run(computation, run_metadata=run_metadata, options=options)
  pprof_profiler.profile(sess.graph, run_metadata, output_dir)


  The code above would write one gzipped pprof profile per device, named
  output_dir/<device>_<timestamp>.pb.gz. These files can be passed to pprof
  for formatting.
  For e.g.:
     pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
"""
from collections import defaultdict
from collections import namedtuple
import gzip
import os
import string
import sys
import time

from proto import profile_pb2


if sys.version_info < (3,):
  maketrans = string.maketrans
else:
  maketrans = str.maketrans


ProfileDatum = namedtuple('ProfileDatum', [
    'node_exec_stats', 'op_type', 'traceback'])


class StringTable(object):
  """Keeps track of strings to add to string_table in pprof proto."""

  def __init__(self):
    # Pprof requires first entry in string_table to be ''.
    self._string_table = ['']
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Get index of value_str in the string table.

    If value_str is not in the string table, we will add it at the end
    and then return the new index.
    Args:
      value_str: (string) Value to lookup/add in/to the string table.
        `None` is treated as ''.

    Returns:
      Index of value_str in the string table.
    """
    if value_str is None:
      value_str = ''
    if value_str in self._string_to_index:
      return self._string_to_index[value_str]
    index = len(self._string_table)
    self._string_table.append(value_str)
    self._string_to_index[value_str] = index
    return index

  def next_index(self):
    """Gets index that would be assigned to the next added string.

    Returns:
      Index of the next string if it was added.
    """
    return len(self._string_table)

  def string_table(self):
    """Returns a list of strings to store in pprof's string_table."""
    return self._string_table


class Functions(object):
  """Keeps track of `Function` protos for pprof profile."""

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # Maps tuples in the form (file_path, function_name, start_line_number)
    # to `Function` protos.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns index of the function, adding the function if needed.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function index.
    """
    function_key = (file_path, function_name, function_start_line)
    if function_key in self._function_key_to_function:
      return self._function_key_to_function[function_key].id
    else:
      # Function indexes should start from 1
      function_index = len(self._function_key_to_function) + 1
      function = profile_pb2.Function()
      function.id = function_index
      function.name = self._string_table.index_of(function_name)
      function.filename = self._string_table.index_of(file_path)
      function.start_line = function_start_line
      self._function_key_to_function[function_key] = function
      return function_index

  def function_protos(self):
    """Returns list of `profile_pb2.Function` protos."""
    return self._function_key_to_function.values()


class Locations(object):
  """Keeps track of `Location` protos for pprof profile.

  `Locations` store information about function call locations.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object.
    """
    self._functions = functions
    # Maps tuples in the form (file_path, called_function_name, line_number)
    # to `Location` protos.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns index of the location, adding the location if needed.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location.
    """
    location_key = (file_path, called_function_name, line_number)
    if location_key in self._location_key_to_location:
      location = self._location_key_to_location[location_key]
      return location.id
    else:
      # Location indexes should start from 1
      location_index = len(self._location_key_to_location) + 1
      location = profile_pb2.Location()
      location.id = location_index
      self._location_key_to_location[location_key] = location

      line = location.line.add()
      line.function_id = self._functions.index_of(
          called_file_path, called_function_name, called_function_start_line)
      line.line = line_number
      return location_index

  def location_protos(self):
    """Returns list of `profile_pb2.Location` protos."""
    return self._location_key_to_location.values()


class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time
  (both times are in microseconds, taken from `NodeExecStats` *_micros fields)
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # TODO(annarev): figure out if location is unique for each node name.
    # If not, also key this dictionary based on location ids.
    self._node_name_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    Data points for an already-seen node name are merged into that node's
    existing sample: count and times are accumulated, and the call stack
    recorded for the first data point is kept.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: List of numeric location ids for this sample, ordered from
        the innermost call outwards (as `PprofProfiler` builds them).
    """
    node_name = datum.node_exec_stats.node_name
    sample = self._node_name_to_sample.get(node_name)
    if sample is None:
      sample = profile_pb2.Sample()
      # Record the call stack once, when the sample is created. Extending it
      # on every repeat would concatenate copies into one bogus stack.
      sample.location_id.extend(location_ids)
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._node_name_to_sample[node_name] = sample
    sample.value[0] += 1
    sample.value[1] += datum.node_exec_stats.all_end_rel_micros
    sample.value[2] += (
        datum.node_exec_stats.op_end_rel_micros -
        datum.node_exec_stats.op_start_rel_micros)

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    return self._node_name_to_sample.values()


class PprofProfiler(object):
  """Creates profiles in pprof format."""

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance.
      run_metadata: A `RunMetadata` proto; its `step_stats.dev_stats` entries
        provide the per-device statistics.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    # String, function and location tables are shared by all device profiles;
    # every device proto receives a full copy of them.
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Devices that produce no samples are reported on stdout and omitted.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format.
    """
    profiles = {}
    data_generator_func = self._get_profile_data_generator()
    all_dev_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_dev_stats)
    for device_index, device_stats in enumerate(all_dev_stats):
      # Create profile
      pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue
      # Add device name comment. The description is appended only to this
      # proto's string table (not the shared one), so its index is the
      # proto table's current length.
      device_description = (
          'Device %d of %d: %s' %
          (device_index + 1, device_count, device_stats.device))
      device_description_str_index = len(pprof_proto.string_table)
      pprof_proto.string_table.append(device_description)
      pprof_proto.comment.append(device_description_str_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      # Nodes without a traceback (e.g. not found in the graph) are skipped.
      if not datum.traceback:
        continue
      samples.add(datum, self._get_location_ids(datum.traceback))

    # Must run before the string table is copied below, because it may add
    # new strings to the shared table.
    self._add_sample_types(pprof_profile)

    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_location_ids(self, traceback):
    """Returns location ids for `traceback`, from the innermost call outwards.

    Args:
      traceback: Non-empty sequence of stack frames whose last frame is the
        innermost call. Each frame is indexable as
        (file_path, line_number, function_name, ...).

    Returns:
      List of location ids, one per kept caller/callee frame pair.
    """
    # Frames up to and including the call into `apply_op` are the same for
    # all ops, so they are skipped. If no frame is `apply_op` there is nothing
    # to skip, and every frame is kept instead of dropping the whole stack.
    skipping = any(
        traceback[i][2] == 'apply_op' for i in range(1, len(traceback)))
    location_ids = []

    # We add locations from stack trace in bottom-up order: the call made at
    # traceback[i] invokes the function of traceback[i + 1].
    for caller_index in reversed(range(len(traceback) - 1)):
      caller_frame = traceback[caller_index]
      callee_frame = traceback[caller_index + 1]
      callee_function = callee_frame[2]

      if skipping:
        if callee_function == 'apply_op':
          skipping = False
        continue
      # The start line of the called function is not available; use -1.
      location_ids.append(self._locations.index_of(
          caller_frame[0], caller_frame[1],
          callee_function, callee_frame[0], -1))
    return location_ids

  def _add_sample_types(self, pprof_profile):
    """Declares the sample value types, in the order `Samples` stores them."""
    # Times come from NodeExecStats *_micros fields, hence 'microseconds'.
    for type_name, unit in (('count', 'count'),
                            ('all_time', 'microseconds'),
                            ('op_time', 'microseconds')):
      sample_type = pprof_profile.sample_type.add()
      sample_type.type = self._string_table.index_of(type_name)
      sample_type.unit = self._string_table.index_of(unit)

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that takes one device's step stats and generates
      `ProfileDatum` objects for its nodes. Nodes missing from the graph get
      an empty op type and an empty traceback.
    """
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        # Skip the graph's special source/sink nodes.
        if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_stats.node_name],
            node_to_traceback[node_stats.node_name])

    return profile_data_generator


def get_profiles(graph, run_metadata):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dictionary mapping from device name to pprof proto for that device.
  """
  return PprofProfiler(graph, run_metadata).profile()


def profile(graph, run_metadata, output_dir=None):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to.
      Profile files for each device will be stored in compressed
      serialized proto format, named `<device>_<YYYYmmddHHMMSS>.pb.gz`.
      If output_dir is None, profile protos will be printed to stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)
  """
  profiles = get_profiles(graph, run_metadata)
  output_file_template = None
  if output_dir:
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    time_suffix = time.strftime('%Y%m%d%H%M%S')
    output_file_template = os.path.join(
        output_dir, '%s_' + time_suffix + '.pb.gz')

  profile_files = []
  for device, pprof_proto in profiles.items():
    if output_file_template is None:
      print('No output directory specified, printing to stdout instead.')
      print(pprof_proto)
    else:
      # Make the device name usable as a file name, e.g.
      # '/job:a/device:CPU:0' -> 'job_a_device_CPU_0'.
      device_name = str(device).strip('/').translate(
          maketrans('/:', '__'))
      profile_file = output_file_template % device_name
      profile_files.append(profile_file)
      with gzip.open(profile_file, 'w') as output_file:
        print('Writing profile to %s...' % profile_file)
        output_file.write(pprof_proto.SerializeToString())
  return profile_files