1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Profiler for TensorFlow models that outputs data in pprof format. 16 17See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof 18profile format. 19The following needs to be set for profiler to work: 20 * trace_level needs to be set to FULL_TRACE 21 * run_metadata object should be passed in to session.run call 22 23Sample usage: 24 options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 25 run_metadata = tf.RunMetadata() 26 27 with tf.Session as sess: 28 ... 29 sess.run(computation, run_metadata=run_metadata, options=options) 30 pprof_profiler.profile(sess.graph, run_metadata, output_dir) 31 32 33 The code above would output a pprof profile to separate output_dir/.*.pb.gz 34 file for each device. These files can be passed to pprof for formatting. 35 For e.g.: 36 pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz 37""" 38from __future__ import absolute_import 39from __future__ import division 40from __future__ import print_function 41 42from collections import defaultdict 43from collections import namedtuple 44import gzip 45import os 46import string 47import sys 48import time 49 50from proto import profile_pb2 51 52 53if sys.version_info < (3,): 54 maketrans = string.maketrans 55else: 56 maketrans = str.maketrans 57 58 59ProfileDatum = namedtuple('ProfileDatum', [ 60 'node_exec_stats', 'op_type', 'traceback']) 61 62 63class StringTable(object): 64 """Keeps track of strings to add to string_table in pprof proto.""" 65 66 def __init__(self): 67 # Pprof requires first entry in string_table to be ''. 68 self._string_table = [''] 69 self._string_to_index = {'': 0} 70 71 def index_of(self, value_str): 72 """Get index of value_str in the string table. 73 74 If value_str is not in the string table, we will add it at the end 75 and then return the new index. 76 Args: 77 value_str: (string) Value to lookup/add in/to the string table. 78 79 Returns: 80 Index of value_str in the string table. 81 """ 82 if value_str is None: 83 value_str = '' 84 if value_str in self._string_to_index: 85 return self._string_to_index[value_str] 86 index = len(self._string_table) 87 self._string_table.append(value_str) 88 self._string_to_index[value_str] = index 89 return index 90 91 def next_index(self): 92 """Gets index that would be assigned to the next added string. 93 94 Returns: 95 Index of the next string if it was added. 96 """ 97 return len(self._string_table) 98 99 def string_table(self): 100 """Returns a list of strings to store in pprof's string_table.""" 101 return self._string_table 102 103 104class Functions(object): 105 """Keeps track of `Function` protos for pprof profile.""" 106 107 def __init__(self, string_table): 108 """Constructor. 109 110 Args: 111 string_table: A `StringTable` object. 112 """ 113 self._string_table = string_table 114 # Maps tuples in the form (file_path, function_name, start_line_number) 115 # to `Function` protos. 116 self._function_key_to_function = {} 117 118 def index_of(self, file_path, function_name, function_start_line): 119 """Returns index of the function, adding the function if needed. 120 121 Args: 122 file_path: (string) Path to file where the function is defined. 123 function_name: (string) Function name. 124 function_start_line: (integer) Start line number of function definition. 125 126 Returns: 127 Function index. 128 """ 129 function_key = (file_path, function_name, function_start_line) 130 if function_key in self._function_key_to_function: 131 return self._function_key_to_function[function_key].id 132 else: 133 # Function indexes should start from 1 134 function_index = len(self._function_key_to_function) + 1 135 function = profile_pb2.Function() 136 function.id = function_index 137 function.name = self._string_table.index_of(function_name) 138 function.filename = self._string_table.index_of(file_path) 139 function.start_line = function_start_line 140 self._function_key_to_function[function_key] = function 141 return function_index 142 143 def function_protos(self): 144 """Returns list of `profile_pb2.Function` protos.""" 145 return self._function_key_to_function.values() 146 147 148class Locations(object): 149 """Keeps track of `Location` protos for pprof profile. 150 151 `Locations` store information about function call locations. 152 """ 153 154 def __init__(self, functions): 155 """Constructor. 156 157 Args: 158 functions: A `Functions` object. 159 """ 160 self._functions = functions 161 # Maps tuples in the form (file_path, called_function_name, line_number) 162 # to `Location` protos. 163 self._location_key_to_location = {} 164 165 def index_of( 166 self, file_path, line_number, called_function_name, called_file_path, 167 called_function_start_line): 168 """Returns index of the location, adding the location if needed. 169 170 Args: 171 file_path: (string) Path to file that makes the call. 172 line_number: (integer) Call line number. 173 called_function_name: (string) Function name of the function called at 174 `file_path` and `line_number`. 175 called_file_path: (string) Path to file where the called function is 176 defined. 177 called_function_start_line: (integer) Start line number of called 178 function definition in `called_file_path` file. 179 180 Returns: 181 Index of location. 182 """ 183 location_key = (file_path, called_function_name, line_number) 184 if location_key in self._location_key_to_location: 185 location = self._location_key_to_location[location_key] 186 return location.id 187 else: 188 # Location indexes should start from 1 189 location_index = len(self._location_key_to_location) + 1 190 location = profile_pb2.Location() 191 location.id = location_index 192 self._location_key_to_location[location_key] = location 193 194 line = location.line.add() 195 line.function_id = self._functions.index_of( 196 called_file_path, called_function_name, called_function_start_line) 197 line.line = line_number 198 return location_index 199 200 def location_protos(self): 201 """Returns list of `profile_pb2.Location` protos.""" 202 return self._location_key_to_location.values() 203 204 205class Samples(object): 206 """Keeps track of `Sample` protos for pprof profile. 207 208 Samples store the following statistics in order: 209 count, all_time, op_time 210 """ 211 212 def __init__(self, string_table): 213 """Constructor. 214 215 Args: 216 string_table: A `StringTable` object. 217 """ 218 self._string_table = string_table 219 # TODO(annarev): figure out if location is unique for each node name. 220 # If not, also key this dictionary based on location ids. 221 self._node_name_to_sample = {} 222 223 def add(self, datum, location_ids): 224 """Adds a sample data point. 225 226 Args: 227 datum: `ProfileDatum` to add a sample for. 228 location_ids: List of numberic location ids for this 229 sample. 230 """ 231 node_name = datum.node_exec_stats.node_name 232 if node_name in self._node_name_to_sample: 233 sample = self._node_name_to_sample[node_name] 234 sample.location_id.extend(location_ids) 235 else: 236 sample = profile_pb2.Sample() 237 # Sample stores 3 values: count, all_time, op_time 238 sample.value.extend([0, 0, 0]) 239 240 label = sample.label.add() 241 label.key = self._string_table.index_of('node_name') 242 label.str = self._string_table.index_of(node_name) 243 label = sample.label.add() 244 label.key = self._string_table.index_of('op_type') 245 label.str = self._string_table.index_of(datum.op_type) 246 self._node_name_to_sample[node_name] = sample 247 sample.value[0] += 1 248 sample.value[1] += datum.node_exec_stats.all_end_rel_micros 249 sample.value[2] += ( 250 datum.node_exec_stats.op_end_rel_micros - 251 datum.node_exec_stats.op_start_rel_micros) 252 253 def get_sample_protos(self): 254 """Returns list of `Sample` protos for pprof profile.""" 255 return self._node_name_to_sample.values() 256 257 258class PprofProfiler(object): 259 """Creates profiles in pprof format.""" 260 261 def __init__(self, graph, run_metadata): 262 """Constructor. 263 264 Args: 265 graph: A `Graph` instance. 266 run_metadata: A list of `RunMetadata` objects. 267 """ 268 self._graph = graph 269 self._run_metadata = run_metadata 270 self._string_table = StringTable() 271 self._functions = Functions(self._string_table) 272 self._locations = Locations(self._functions) 273 274 def profile(self): 275 """Generates pprof profiles. 276 277 Returns: 278 Dictionary mapping from device name to proto in `profile_pb2.Profile` 279 format. 280 """ 281 profiles = {} 282 data_generator_func = self._get_profile_data_generator() 283 for device_index, device_stats in enumerate( 284 self._run_metadata.step_stats.dev_stats): 285 # Create profile 286 pprof_proto = self._get_pprof_proto(data_generator_func(device_stats)) 287 if not pprof_proto.sample: 288 print( 289 'Not enough data to create profile for device %s. Did you pass ' 290 'RunMetadata to session.run call?' % device_stats.device) 291 continue 292 # Add device name comment 293 device_count = len(self._run_metadata.step_stats.dev_stats) 294 device_description = ( 295 'Device %d of %d: %s' % 296 (device_index + 1, device_count, device_stats.device)) 297 device_description_str_index = self._string_table.next_index() 298 pprof_proto.string_table.append(device_description) 299 pprof_proto.comment.append(device_description_str_index) 300 profiles[device_stats.device] = pprof_proto 301 return profiles 302 303 def _get_pprof_proto(self, profile_datum_generator): 304 """Returns profile data in pprof proto format. 305 306 Args: 307 profile_datum_generator: Generator outputting `ProfileDatum` objects. 308 309 Returns: 310 A proto in pprof format. 311 """ 312 pprof_profile = profile_pb2.Profile() 313 samples = Samples(self._string_table) 314 315 for datum in profile_datum_generator: 316 if not datum.traceback: 317 continue 318 319 stack_frame = datum.traceback[-1] 320 after_apply_op = False 321 location_ids = [] 322 323 # We add locations from stack trace in bottom-up order. 324 for stack_frame_index in reversed(range(len(datum.traceback) - 1)): 325 prev_stack_frame = stack_frame 326 stack_frame = datum.traceback[stack_frame_index] 327 328 # Call at current frame calls function at previous frame. 329 prev_file_path = prev_stack_frame[0] 330 prev_function = prev_stack_frame[2] 331 prev_function_start_line = prev_stack_frame[4] 332 curr_file_path = stack_frame[0] 333 curr_line_number = stack_frame[1] 334 335 # Skip all calls up to apply_op since they are the same for all ops. 336 if not after_apply_op: 337 if prev_function == 'apply_op': 338 after_apply_op = True 339 continue 340 location_index = self._locations.index_of( 341 curr_file_path, curr_line_number, 342 prev_function, prev_file_path, prev_function_start_line) 343 location_ids.append(location_index) 344 samples.add(datum, location_ids) 345 346 sample_type_description = 'count' 347 sample_type = pprof_profile.sample_type.add() 348 sample_type.type = self._string_table.index_of(sample_type_description) 349 sample_type.unit = self._string_table.index_of('count') 350 sample_type_description = 'all_time' 351 sample_type = pprof_profile.sample_type.add() 352 sample_type.type = self._string_table.index_of(sample_type_description) 353 sample_type.unit = self._string_table.index_of('nanoseconds') 354 sample_type_description = 'op_time' 355 sample_type = pprof_profile.sample_type.add() 356 sample_type.type = self._string_table.index_of(sample_type_description) 357 sample_type.unit = self._string_table.index_of('nanoseconds') 358 359 pprof_profile.string_table.extend(self._string_table.string_table()) 360 pprof_profile.sample.extend(samples.get_sample_protos()) 361 pprof_profile.function.extend(self._functions.function_protos()) 362 pprof_profile.location.extend(self._locations.location_protos()) 363 return pprof_profile 364 365 def _get_profile_data_generator(self): 366 """Get function that generates `ProfileDatum` objects. 367 368 Returns: 369 A function that generates `ProfileDatum` objects. 370 """ 371 node_to_traceback = defaultdict(list) 372 node_to_op_type = defaultdict(str) 373 for op in self._graph.get_operations(): 374 node_to_traceback[op.name] = op.traceback_with_start_lines 375 node_to_op_type[op.name] = op.type 376 377 def profile_data_generator(device_step_stats): 378 for node_stats in device_step_stats.node_stats: 379 if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK': 380 continue 381 yield ProfileDatum( 382 node_stats, 383 node_to_op_type[node_stats.node_name], 384 node_to_traceback[node_stats.node_name]) 385 386 return profile_data_generator 387 388 389def get_profiles(graph, run_metadata): 390 """Generate profiles in pprof format. 391 392 See https://github.com/google/pprof/blob/master/proto/profile.proto 393 for pprof proto format. 394 395 Args: 396 graph: A `Graph` object. 397 run_metadata: A `RunMetadata` proto. 398 399 Returns: 400 A dictionary mapping from device name to pprof proto for that device. 401 """ 402 return PprofProfiler(graph, run_metadata).profile() 403 404 405def profile(graph, run_metadata, output_dir=None): 406 """Generate profiles in pprof format. 407 408 See https://github.com/google/pprof/blob/master/proto/profile.proto 409 for pprof proto format. 410 411 Args: 412 graph: A `Graph` object. 413 run_metadata: A `RunMetadata` proto. 414 output_dir: (string) Directory to output pprof profile to. 415 Profile files for each device will be stored in compressed 416 serialized proto format. If output_dir is None, profile protos 417 will be printed to stdout instead. 418 419 Returns: 420 List of output files created by this profile call. 421 (Note: this list will be empty if output_dir is None) 422 """ 423 profiles = get_profiles(graph, run_metadata) 424 output_file_template = None 425 if output_dir: 426 if not os.path.isdir(output_dir): 427 os.makedirs(output_dir) 428 time_suffix = time.strftime('%Y%m%d%H%M%S') 429 output_file_template = os.path.join( 430 output_dir, '%s_' + time_suffix + '.pb.gz') 431 432 profile_files = [] 433 for device, pprof_proto in profiles.items(): 434 if output_file_template is None: 435 print('No output directory specified, printing to stdout instead.') 436 print(pprof_proto) 437 else: 438 device_name = str(device).strip('/').translate( 439 maketrans('/:', '__')) 440 profile_file = output_file_template % device_name 441 profile_files.append(profile_file) 442 with gzip.open(profile_file, 'w') as output_file: 443 print('Writing profile to %s...' % profile_file) 444 output_file.write(pprof_proto.SerializeToString()) 445 return profile_files 446