• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.

See https://github.com/google/pprof/blob/master/proto/profile.proto for the
pprof profile format.
The following needs to be set for the profiler to work:
  * trace_level needs to be set to FULL_TRACE
  * run_metadata object should be passed in to the session.run call

Sample usage:
  options = tf.compat.v1.RunOptions(
      trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
  run_metadata = tf.compat.v1.RunMetadata()

  with tf.compat.v1.Session() as sess:
    ...
    sess.run(computation, run_metadata=run_metadata, options=options)
  pprof_profiler.profile(sess.graph, run_metadata, output_dir)


  The code above outputs a separate pprof profile file, named
  output_dir/<device>_<timestamp>.pb.gz, for each device. These files can be
  passed to pprof for formatting.
  For example:
     pprof -png --nodecount=100 --sample_index=1 output_dir/<profile>.pb.gz
"""
38from collections import defaultdict
39from collections import namedtuple
40import gzip
41import os
42import string
43import sys
44import time
45
46from proto import profile_pb2
47
48
# Select the translation-table factory for the running interpreter: Python 2
# exposes it as `string.maketrans`, Python 3 as `str.maketrans`.
maketrans = str.maketrans if sys.version_info >= (3,) else string.maketrans
53
54
# Bundles what is needed to emit one sample: the node execution stats, the
# type of the op that produced them, and the op's traceback.
ProfileDatum = namedtuple(
    'ProfileDatum', ('node_exec_stats', 'op_type', 'traceback'))
57
58
class StringTable(object):
  """Interns strings destined for the `string_table` of a pprof proto."""

  def __init__(self):
    # pprof mandates that slot 0 of the string table is the empty string.
    self._string_table = ['']
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Looks up `value_str`, interning it first if it is not yet present.

    `None` is treated as the empty string. A string that has not been seen
    before is appended to the end of the table.

    Args:
      value_str: (string) Value to lookup/add in/to the string table.

    Returns:
      Index of value_str in the string table.
    """
    key = '' if value_str is None else value_str
    try:
      return self._string_to_index[key]
    except KeyError:
      pass
    position = len(self._string_table)
    self._string_table.append(key)
    self._string_to_index[key] = position
    return position

  def next_index(self):
    """Reports the index the next newly added string would receive.

    Returns:
      Index of the next string if it was added.
    """
    return len(self._string_table)

  def string_table(self):
    """Returns the list of strings to store in pprof's string_table."""
    return self._string_table
98
99
class Functions(object):
  """Keeps track of `Function` protos for pprof profile.

  Each distinct (file_path, function_name, start_line) triple gets a 1-based
  id the first time it is requested; later requests return the same id.
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # Maps tuples in the form (file_path, function_name, start_line_number)
    # to `Function` protos.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns index of the function, adding the function if needed.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function index.
    """
    function_key = (file_path, function_name, function_start_line)
    known_function = self._function_key_to_function.get(function_key)
    if known_function is not None:
      return known_function.id

    # Function indexes should start from 1
    function_index = len(self._function_key_to_function) + 1
    function = profile_pb2.Function()
    function.id = function_index
    # The name is interned before the file path; this order determines the
    # indexes the two strings receive in the string table.
    function.name = self._string_table.index_of(function_name)
    function.filename = self._string_table.index_of(file_path)
    function.start_line = function_start_line
    self._function_key_to_function[function_key] = function
    return function_index

  def function_protos(self):
    """Returns list of `profile_pb2.Function` protos."""
    # Materialize the values so callers get a real list on Python 3 as well,
    # where `dict.values()` is only a view.
    return list(self._function_key_to_function.values())
142
143
class Locations(object):
  """Keeps track of `Location` protos for pprof profile.

  `Locations` store information about function call locations. Each distinct
  (file_path, called_function_name, line_number) triple gets a 1-based id the
  first time it is requested; later requests return the same id.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object.
    """
    self._functions = functions
    # Maps tuples in the form (file_path, called_function_name, line_number)
    # to `Location` protos.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns index of the location, adding the location if needed.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location.
    """
    # `called_file_path` and `called_function_start_line` are not part of the
    # key; they are only used to register the called function.
    location_key = (file_path, called_function_name, line_number)
    known_location = self._location_key_to_location.get(location_key)
    if known_location is not None:
      return known_location.id

    # Location indexes should start from 1
    location_index = len(self._location_key_to_location) + 1
    location = profile_pb2.Location()
    location.id = location_index

    line = location.line.add()
    line.function_id = self._functions.index_of(
        called_file_path, called_function_name, called_function_start_line)
    line.line = line_number

    # Register the location only once it is fully built, so a failure above
    # cannot leave a half-initialized entry in the table.
    self._location_key_to_location[location_key] = location
    return location_index

  def location_protos(self):
    """Returns list of `profile_pb2.Location` protos."""
    # Materialize the values so callers get a real list on Python 3 as well,
    # where `dict.values()` is only a view.
    return list(self._location_key_to_location.values())
199
200
class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time

  There is one sample per node name; data points for the same node are
  accumulated into that sample.
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # TODO(annarev): figure out if location is unique for each node name.
    # If not, also key this dictionary based on location ids.
    self._node_name_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    The first data point for a node creates its sample, labels it with the
    node name and op type, and records `location_ids` as its call stack.
    Later data points for the same node only accumulate into the sample's
    values; the stack recorded at creation is kept as is.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: List of numeric location ids for this
        sample.
    """
    node_stats = datum.node_exec_stats
    node_name = node_stats.node_name
    sample = self._node_name_to_sample.get(node_name)
    if sample is None:
      sample = profile_pb2.Sample()
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])
      # Attach the stack exactly once, when the sample is created. Appending
      # it again on each repeat would duplicate the frames in the sample.
      sample.location_id.extend(location_ids)

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._node_name_to_sample[node_name] = sample

    sample.value[0] += 1
    sample.value[1] += node_stats.all_end_rel_micros
    sample.value[2] += (
        node_stats.op_end_rel_micros - node_stats.op_start_rel_micros)

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    # Materialize the values so callers get a real list on Python 3 as well,
    # where `dict.values()` is only a view.
    return list(self._node_name_to_sample.values())
252
253
class PprofProfiler(object):
  """Creates profiles in pprof format.

  One profiler instance shares a single string table, function table and
  location table across all the per-device profiles it generates.
  """

  # (type, unit) pairs describing the three values stored in every sample, in
  # order. The time values are read from the `*_micros` fields of the node
  # execution stats, hence the 'microseconds' unit.
  _SAMPLE_TYPES = (
      ('count', 'count'),
      ('all_time', 'microseconds'),
      ('op_time', 'microseconds'),
  )

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance.
      run_metadata: A `RunMetadata` proto; its `step_stats.dev_stats` entries
        are the per-device stats that get profiled.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Devices for which no sample could be produced are reported on stdout and
    left out of the result.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format.
    """
    profiles = {}
    data_generator_func = self._get_profile_data_generator()
    all_device_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_device_stats)
    for device_index, device_stats in enumerate(all_device_stats):
      # Create profile
      pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue
      # Add device name comment
      device_description = (
          'Device %d of %d: %s' %
          (device_index + 1, device_count, device_stats.device))
      # The description is appended only to this proto's copy of the string
      # table, so its index is the current length of the shared table.
      device_description_str_index = self._string_table.next_index()
      pprof_proto.string_table.append(device_description)
      pprof_proto.comment.append(device_description_str_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_location_ids(self, traceback):
    """Converts a non-empty traceback into a list of location ids.

    Frames are visited from the last traceback entry towards the first. All
    calls up to and including the call to `apply_op` are skipped since they
    are the same for all ops.

    Args:
      traceback: Non-empty sequence of stack frames. Each frame is indexed as
        [0] file path, [1] line number, [2] function name.

    Returns:
      List of location ids, in the order the frames were visited.
    """
    location_ids = []
    after_apply_op = False
    stack_frame = traceback[-1]

    # We add locations from stack trace in bottom-up order.
    for stack_frame_index in reversed(range(len(traceback) - 1)):
      prev_stack_frame = stack_frame
      stack_frame = traceback[stack_frame_index]

      # Call at current frame calls function at previous frame.
      prev_file_path = prev_stack_frame[0]
      prev_function = prev_stack_frame[2]
      # The start line of the called function is not read from the frames,
      # so -1 is passed.
      prev_function_start_line = -1
      curr_file_path = stack_frame[0]
      curr_line_number = stack_frame[1]

      # Skip all calls up to apply_op since they are the same for all ops.
      if not after_apply_op:
        if prev_function == 'apply_op':
          after_apply_op = True
        continue
      location_ids.append(self._locations.index_of(
          curr_file_path, curr_line_number,
          prev_function, prev_file_path, prev_function_start_line))
    return location_ids

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Data points without a traceback are ignored.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      if not datum.traceback:
        continue
      samples.add(datum, self._get_location_ids(datum.traceback))

    for type_name, unit in self._SAMPLE_TYPES:
      sample_type = pprof_profile.sample_type.add()
      sample_type.type = self._string_table.index_of(type_name)
      sample_type.unit = self._string_table.index_of(unit)

    # Copy the string table last, after every string used above is interned.
    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that takes the stats of one device and generates a
      `ProfileDatum` for each of its nodes other than '_SOURCE' and '_SINK'.
    """
    # Nodes that have no matching graph operation fall back to an empty
    # traceback and an empty op type.
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        if node_stats.node_name in ('_SOURCE', '_SINK'):
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_stats.node_name],
            node_to_traceback[node_stats.node_name])

    return profile_data_generator
383
384
def get_profiles(graph, run_metadata):
  """Builds one pprof profile per device.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dictionary mapping from device name to pprof proto for that device.
  """
  profiler = PprofProfiler(graph, run_metadata)
  return profiler.profile()
399
400
def profile(graph, run_metadata, output_dir=None):
  """Generates pprof profiles and writes them to disk or to stdout.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to.
      Profile files for each device will be stored in compressed
      serialized proto format. The directory is created if it does not
      exist. If output_dir is None, profile protos will be printed to
      stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)
  """
  device_to_proto = get_profiles(graph, run_metadata)

  path_template = None
  if output_dir:
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    timestamp = time.strftime('%Y%m%d%H%M%S')
    path_template = os.path.join(output_dir, '%s_' + timestamp + '.pb.gz')

  written_paths = []
  for device, pprof_proto in device_to_proto.items():
    if path_template is None:
      print('No output directory specified, printing to stdout instead.')
      print(pprof_proto)
      continue

    # Drop surrounding '/' and replace '/' and ':' with '_' so that the
    # device name can be used as part of a file name.
    separators_to_underscores = maketrans('/:', '__')
    safe_device = str(device).strip('/').translate(separators_to_underscores)
    target_path = path_template % safe_device
    written_paths.append(target_path)
    with gzip.open(target_path, 'w') as gz_file:
      print('Writing profile to %s...' % target_path)
      gz_file.write(pprof_proto.SerializeToString())
  return written_paths
442