• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.

See https://github.com/google/pprof/blob/master/proto/profile.proto for the
pprof profile format.

The following must be set for the profiler to work:
  * trace_level must be set to FULL_TRACE in the RunOptions.
  * A RunMetadata object must be passed to the session.run call.

Sample usage:
  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  with tf.Session() as sess:
    ...
    sess.run(computation, run_metadata=run_metadata, options=options)
  pprof_profiler.profile(sess.graph, run_metadata, output_dir)

The code above writes a separate gzip-compressed pprof profile for each
device to output_dir/<device>_<timestamp>.pb.gz. These files can be passed
to pprof for formatting, for example:
  pprof -png --nodecount=100 --sample_index=1 output_dir/<profile_file>.pb.gz
"""
38from __future__ import absolute_import
39from __future__ import division
40from __future__ import print_function
41
42from collections import defaultdict
43from collections import namedtuple
44import gzip
45import os
46import string
47import sys
48import time
49
50from proto import profile_pb2
51
52
# `maketrans` lives in the `string` module on Python 2 and on `str` on
# Python 3; pick the right one once, at import time.
maketrans = string.maketrans if sys.version_info < (3,) else str.maketrans
57
58
# One executed node: its execution stats (an entry of a device's
# `node_stats`), the op type string of the matching graph operation, and that
# operation's creation traceback.
ProfileDatum = namedtuple(
    'ProfileDatum', ('node_exec_stats', 'op_type', 'traceback'))
61
62
class StringTable(object):
  """Interning table backing the `string_table` field of a pprof proto.

  Each distinct string is stored once and referred to by its position.
  Position 0 always holds the empty string, as pprof requires.
  """

  def __init__(self):
    self._string_table = ['']  # pprof mandates '' as the first entry.
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Returns the position of `value_str`, interning it first if unseen.

    Args:
      value_str: (string) String to look up. `None` is treated as ''.

    Returns:
      Integer index of `value_str` in the string table.
    """
    key = '' if value_str is None else value_str
    found = self._string_to_index.get(key)
    if found is not None:
      return found
    new_index = self.next_index()
    self._string_to_index[key] = new_index
    self._string_table.append(key)
    return new_index

  def next_index(self):
    """Returns the index that the next newly interned string would get."""
    return len(self._string_table)

  def string_table(self):
    """Returns the backing list of interned strings, in index order."""
    return self._string_table
102
103
class Functions(object):
  """Registry of pprof `Function` protos, deduplicated by definition site."""

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object used to intern function names and
        file paths.
    """
    self._string_table = string_table
    # (file_path, function_name, start_line_number) -> `Function` proto.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns the id of the matching `Function`, creating it on first use.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function id. Ids are assigned sequentially starting at 1.
    """
    key = (file_path, function_name, function_start_line)
    existing = self._function_key_to_function.get(key)
    if existing is not None:
      return existing.id

    new_function = profile_pb2.Function()
    new_function.id = len(self._function_key_to_function) + 1
    # Name is interned before filename; this fixes string-table ordering.
    new_function.name = self._string_table.index_of(function_name)
    new_function.filename = self._string_table.index_of(file_path)
    new_function.start_line = function_start_line
    self._function_key_to_function[key] = new_function
    return new_function.id

  def function_protos(self):
    """Returns the registered `profile_pb2.Function` protos."""
    return self._function_key_to_function.values()
146
147
class Locations(object):
  """Keeps track of `Location` protos for pprof profile.

  A `Location` describes a call site: a line in the calling file together
  with the `Function` that is called there.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object used to resolve called functions.
    """
    self._functions = functions
    # Maps (file_path, line_number, called_function_name, called_file_path,
    # called_function_start_line) to `Location` protos. The callee's file and
    # start line are part of the key. Otherwise two distinct functions that
    # share a name and are called from the same line (e.g. a polymorphic
    # method call) would be merged into whichever was seen first.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns index of the location, adding the location if needed.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location. Indexes are assigned sequentially starting at 1.
    """
    location_key = (file_path, line_number, called_function_name,
                    called_file_path, called_function_start_line)
    location = self._location_key_to_location.get(location_key)
    if location is not None:
      return location.id

    location = profile_pb2.Location()
    location.id = len(self._location_key_to_location) + 1
    self._location_key_to_location[location_key] = location

    line = location.line.add()
    line.function_id = self._functions.index_of(
        called_file_path, called_function_name, called_function_start_line)
    line.line = line_number
    return location.id

  def location_protos(self):
    """Returns the registered `profile_pb2.Location` protos."""
    return self._location_key_to_location.values()
197      line.line = line_number
198      return location_index
199
200  def location_protos(self):
201    """Returns list of `profile_pb2.Location` protos."""
202    return self._location_key_to_location.values()
203
204
class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time
  The two time values are accumulated from the `*_rel_micros` fields of the
  node execution stats.
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # Maps (node_name, tuple_of_location_ids) to `Sample` protos. Keying on
    # the call stack as well as the node name guarantees that each sample's
    # `location_id` list holds exactly one stack.
    self._key_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    Repeated data points for the same node and call stack are merged into one
    sample whose values are summed.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: Iterable of numeric location ids for this sample.
    """
    # Materialise once so the ids can serve both as part of the key and as
    # the sample's stack (a one-shot iterator would otherwise be consumed).
    location_ids = tuple(location_ids)
    node_stats = datum.node_exec_stats
    node_name = node_stats.node_name
    key = (node_name, location_ids)

    sample = self._key_to_sample.get(key)
    if sample is None:
      sample = profile_pb2.Sample()
      # The stack is recorded exactly once, when the sample is created.
      sample.location_id.extend(location_ids)
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._key_to_sample[key] = sample

    sample.value[0] += 1
    sample.value[1] += node_stats.all_end_rel_micros
    sample.value[2] += (
        node_stats.op_end_rel_micros - node_stats.op_start_rel_micros)

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    return self._key_to_sample.values()
250        datum.node_exec_stats.op_end_rel_micros -
251        datum.node_exec_stats.op_start_rel_micros)
252
253  def get_sample_protos(self):
254    """Returns list of `Sample` protos for pprof profile."""
255    return self._node_name_to_sample.values()
256
257
class PprofProfiler(object):
  """Creates profiles in pprof format."""

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance; its operations supply op types and creation
        tracebacks.
      run_metadata: A single `RunMetadata` proto. Per-device stats are read
        from `run_metadata.step_stats.dev_stats`.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    # One string table and one Function/Location registry are shared by the
    # profiles of all devices, so every device proto uses the same ids.
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Devices that produce no samples are skipped with a printed warning.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format.
    """
    profiles = {}
    data_generator_func = self._get_profile_data_generator()
    all_device_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_device_stats)
    for device_index, device_stats in enumerate(all_device_stats):
      pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue
      # Attach a device description as a profile comment. At this point
      # `pprof_proto.string_table` is a copy of the shared table, so the
      # appended string lands at `next_index()`. The shared table itself is
      # not modified, so other devices' profiles do not contain this string.
      device_description = (
          'Device %d of %d: %s' %
          (device_index + 1, device_count, device_stats.device))
      device_description_str_index = self._string_table.next_index()
      pprof_proto.string_table.append(device_description)
      pprof_proto.comment.append(device_description_str_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      if not datum.traceback:
        continue
      samples.add(datum, self._get_location_ids(datum.traceback))

    # Order matches the values stored by `Samples`. The time values are taken
    # from `*_rel_micros` stats fields, so their unit is microseconds.
    for type_name, unit in (('count', 'count'),
                            ('all_time', 'microseconds'),
                            ('op_time', 'microseconds')):
      sample_type = pprof_profile.sample_type.add()
      sample_type.type = self._string_table.index_of(type_name)
      sample_type.unit = self._string_table.index_of(unit)

    # Copy the string table last, after every string above has been interned.
    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_location_ids(self, traceback):
    """Returns location ids for a traceback, innermost call site first.

    Frame i calls the function of frame i + 1. Walking from the innermost
    frame outwards, all call sites are skipped until the callee is
    `apply_op`, and that call site is skipped too, since these frames are
    the same for all ops. If no frame calls `apply_op`, the list is empty.

    Args:
      traceback: Sequence of frames as returned by
        `op.traceback_with_start_lines`. Each frame is indexed as
        (file_path, line_number, function_name, _, function_start_line).

    Returns:
      List of location ids, ordered from innermost to outermost.
    """
    location_ids = []
    after_apply_op = False
    for caller_index in reversed(range(len(traceback) - 1)):
      caller = traceback[caller_index]
      callee = traceback[caller_index + 1]
      callee_function = callee[2]

      if not after_apply_op:
        if callee_function == 'apply_op':
          after_apply_op = True
        continue
      location_ids.append(self._locations.index_of(
          caller[0], caller[1], callee_function, callee[0], callee[4]))
    return location_ids

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Node names not found in the graph yield op type '' and an empty
    traceback. `_get_pprof_proto` skips data with an empty traceback.

    Returns:
      A function that takes one device's step stats and generates
      `ProfileDatum` objects for its nodes (excluding _SOURCE and _SINK).
    """
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback_with_start_lines
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        if node_stats.node_name in ('_SOURCE', '_SINK'):
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_stats.node_name],
            node_to_traceback[node_stats.node_name])

    return profile_data_generator
387
388
def get_profiles(graph, run_metadata):
  """Builds one pprof profile proto per device.

  The pprof proto schema is described at
  https://github.com/google/pprof/blob/master/proto/profile.proto

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dict keyed by device name whose values are the pprof protos for those
    devices.
  """
  profiler = PprofProfiler(graph, run_metadata)
  return profiler.profile()
403
404
def profile(graph, run_metadata, output_dir=None):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to; created if it
      does not exist. Profile files for each device will be stored in
      gzip-compressed serialized proto format, named
      `<device>_<YYYYmmddHHMMSS>.pb.gz` with '/' and ':' in the device name
      replaced by '_'. If output_dir is None (or empty), profile protos
      will be printed to stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)

  Raises:
    OSError: If output_dir cannot be created or a profile file cannot be
      written.
  """
  profiles = get_profiles(graph, run_metadata)

  if not output_dir:
    if profiles:
      print('No output directory specified, printing to stdout instead.')
    for pprof_proto in profiles.values():
      print(pprof_proto)
    return []

  # Create the directory, tolerating it already existing (or being created
  # concurrently). `os.makedirs` has no `exist_ok` on Python 2, which this
  # module still supports.
  try:
    os.makedirs(output_dir)
  except OSError:
    if not os.path.isdir(output_dir):
      raise
  time_suffix = time.strftime('%Y%m%d%H%M%S')

  profile_files = []
  for device, pprof_proto in profiles.items():
    # e.g. '/job:localhost/device:CPU:0' -> 'job_localhost_device_CPU_0'
    device_name = str(device).strip('/').translate(maketrans('/:', '__'))
    # Format only the file name, so that '%' characters in output_dir are
    # not interpreted as format directives.
    profile_file = os.path.join(
        output_dir, '%s_%s.pb.gz' % (device_name, time_suffix))
    profile_files.append(profile_file)
    with gzip.open(profile_file, 'wb') as output_file:
      print('Writing profile to %s...' % profile_file)
      output_file.write(pprof_proto.SerializeToString())
  return profile_files
446