• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Timeline visualization for TensorFlow using Chrome Trace Format."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import json
24import re
25
26# The timeline target is usually imported as part of BUILD target
27# "platform_test", which includes also includes the "platform"
28# dependency.  This is why the logging import here is okay.
29from tensorflow.python.platform import build_info
30from tensorflow.python.platform import tf_logging as logging
31
32
class AllocationMaximum(collections.namedtuple(
    'AllocationMaximum', ('timestamp', 'num_bytes', 'tensors'))):
  """Stores the maximum allocation for a given allocator within the timeline.

  Parameters:
    timestamp: `tensorflow::Env::NowMicros()` when this maximum was reached.
    num_bytes: the total memory used at this time.
    tensors: the set of tensors allocated at this time.
  """
  pass
43
44
class StepStatsAnalysis(collections.namedtuple(
    'StepStatsAnalysis', ('chrome_trace', 'allocator_maximums'))):
  """Holds the results of analyzing a StepStats proto.

  Parameters:
    chrome_trace: A dict containing the chrome trace analysis.
    allocator_maximums: A dict mapping allocator names to AllocationMaximum.
  """
54
55
56class _ChromeTraceFormatter(object):
57  """A helper class for generating traces in Chrome Trace Format."""
58
59  def __init__(self, show_memory=False):
60    """Constructs a new Chrome Trace formatter."""
61    self._show_memory = show_memory
62    self._events = []
63    self._metadata = []
64
65  def _create_event(self, ph, category, name, pid, tid, timestamp):
66    """Creates a new Chrome Trace event.
67
68    For details of the file format, see:
69    https://github.com/catapult-project/catapult/blob/master/tracing/README.md
70
71    Args:
72      ph:  The type of event - usually a single character.
73      category: The event category as a string.
74      name:  The event name as a string.
75      pid:  Identifier of the process generating this event as an integer.
76      tid:  Identifier of the thread generating this event as an integer.
77      timestamp:  The timestamp of this event as a long integer.
78
79    Returns:
80      A JSON compatible event object.
81    """
82    event = {}
83    event['ph'] = ph
84    event['cat'] = category
85    event['name'] = name
86    event['pid'] = pid
87    event['tid'] = tid
88    event['ts'] = timestamp
89    return event
90
91  def emit_pid(self, name, pid):
92    """Adds a process metadata event to the trace.
93
94    Args:
95      name:  The process name as a string.
96      pid:  Identifier of the process as an integer.
97    """
98    event = {}
99    event['name'] = 'process_name'
100    event['ph'] = 'M'
101    event['pid'] = pid
102    event['args'] = {'name': name}
103    self._metadata.append(event)
104
105  def emit_tid(self, name, pid, tid):
106    """Adds a thread metadata event to the trace.
107
108    Args:
109      name:  The thread name as a string.
110      pid:  Identifier of the process as an integer.
111      tid:  Identifier of the thread as an integer.
112    """
113    event = {}
114    event['name'] = 'thread_name'
115    event['ph'] = 'M'
116    event['pid'] = pid
117    event['tid'] = tid
118    event['args'] = {'name': name}
119    self._metadata.append(event)
120
121  def emit_region(self, timestamp, duration, pid, tid, category, name, args):
122    """Adds a region event to the trace.
123
124    Args:
125      timestamp:  The start timestamp of this region as a long integer.
126      duration:  The duration of this region as a long integer.
127      pid:  Identifier of the process generating this event as an integer.
128      tid:  Identifier of the thread generating this event as an integer.
129      category: The event category as a string.
130      name:  The event name as a string.
131      args:  A JSON-compatible dictionary of event arguments.
132    """
133    event = self._create_event('X', category, name, pid, tid, timestamp)
134    event['dur'] = duration
135    event['args'] = args
136    self._events.append(event)
137
138  def emit_obj_create(self, category, name, timestamp, pid, tid, object_id):
139    """Adds an object creation event to the trace.
140
141    Args:
142      category: The event category as a string.
143      name:  The event name as a string.
144      timestamp:  The timestamp of this event as a long integer.
145      pid:  Identifier of the process generating this event as an integer.
146      tid:  Identifier of the thread generating this event as an integer.
147      object_id: Identifier of the object as an integer.
148    """
149    event = self._create_event('N', category, name, pid, tid, timestamp)
150    event['id'] = object_id
151    self._events.append(event)
152
153  def emit_obj_delete(self, category, name, timestamp, pid, tid, object_id):
154    """Adds an object deletion event to the trace.
155
156    Args:
157      category: The event category as a string.
158      name:  The event name as a string.
159      timestamp:  The timestamp of this event as a long integer.
160      pid:  Identifier of the process generating this event as an integer.
161      tid:  Identifier of the thread generating this event as an integer.
162      object_id: Identifier of the object as an integer.
163    """
164    event = self._create_event('D', category, name, pid, tid, timestamp)
165    event['id'] = object_id
166    self._events.append(event)
167
168  def emit_obj_snapshot(self, category, name, timestamp, pid, tid, object_id,
169                        snapshot):
170    """Adds an object snapshot event to the trace.
171
172    Args:
173      category: The event category as a string.
174      name:  The event name as a string.
175      timestamp:  The timestamp of this event as a long integer.
176      pid:  Identifier of the process generating this event as an integer.
177      tid:  Identifier of the thread generating this event as an integer.
178      object_id: Identifier of the object as an integer.
179      snapshot:  A JSON-compatible representation of the object.
180    """
181    event = self._create_event('O', category, name, pid, tid, timestamp)
182    event['id'] = object_id
183    event['args'] = {'snapshot': snapshot}
184    self._events.append(event)
185
186  def emit_flow_start(self, name, timestamp, pid, tid, flow_id):
187    """Adds a flow start event to the trace.
188
189    When matched with a flow end event (with the same 'flow_id') this will
190    cause the trace viewer to draw an arrow between the start and end events.
191
192    Args:
193      name:  The event name as a string.
194      timestamp:  The timestamp of this event as a long integer.
195      pid:  Identifier of the process generating this event as an integer.
196      tid:  Identifier of the thread generating this event as an integer.
197      flow_id: Identifier of the flow as an integer.
198    """
199    event = self._create_event('s', 'DataFlow', name, pid, tid, timestamp)
200    event['id'] = flow_id
201    self._events.append(event)
202
203  def emit_flow_end(self, name, timestamp, pid, tid, flow_id):
204    """Adds a flow end event to the trace.
205
206    When matched with a flow start event (with the same 'flow_id') this will
207    cause the trace viewer to draw an arrow between the start and end events.
208
209    Args:
210      name:  The event name as a string.
211      timestamp:  The timestamp of this event as a long integer.
212      pid:  Identifier of the process generating this event as an integer.
213      tid:  Identifier of the thread generating this event as an integer.
214      flow_id: Identifier of the flow as an integer.
215    """
216    event = self._create_event('t', 'DataFlow', name, pid, tid, timestamp)
217    event['id'] = flow_id
218    self._events.append(event)
219
220  def emit_counter(self, category, name, pid, timestamp, counter, value):
221    """Emits a record for a single counter.
222
223    Args:
224      category: The event category as a string.
225      name:  The event name as a string.
226      pid:  Identifier of the process generating this event as an integer.
227      timestamp:  The timestamp of this event as a long integer.
228      counter: Name of the counter as a string.
229      value:  Value of the counter as an integer.
230    """
231    event = self._create_event('C', category, name, pid, 0, timestamp)
232    event['args'] = {counter: value}
233    self._events.append(event)
234
235  def emit_counters(self, category, name, pid, timestamp, counters):
236    """Emits a counter record for the dictionary 'counters'.
237
238    Args:
239      category: The event category as a string.
240      name:  The event name as a string.
241      pid:  Identifier of the process generating this event as an integer.
242      timestamp:  The timestamp of this event as a long integer.
243      counters: Dictionary of counter values.
244    """
245    event = self._create_event('C', category, name, pid, 0, timestamp)
246    event['args'] = counters.copy()
247    self._events.append(event)
248
249  def format_to_string(self, pretty=False):
250    """Formats the chrome trace to a string.
251
252    Args:
253      pretty: (Optional.)  If True, produce human-readable JSON output.
254
255    Returns:
256      A JSON-formatted string in Chrome Trace format.
257    """
258    trace = {}
259    trace['traceEvents'] = self._metadata + self._events
260    if pretty:
261      return json.dumps(trace, indent=4, separators=(',', ': '))
262    else:
263      return json.dumps(trace, separators=(',', ':'))
264
265
266class _TensorTracker(object):
267  """An internal class to track the lifetime of a Tensor."""
268
269  def __init__(self, name, object_id, timestamp, pid, allocator, num_bytes):
270    """Creates an object to track tensor references.
271
272    This class is not thread safe and is intended only for internal use by
273    the 'Timeline' class in this file.
274
275    Args:
276      name:  The name of the Tensor as a string.
277      object_id:  Chrome Trace object identifier assigned for this Tensor.
278      timestamp:  The creation timestamp of this event as a long integer.
279      pid:  Process identifier of the associated device, as an integer.
280      allocator:  Name of the allocator used to create the Tensor.
281      num_bytes:  Number of bytes allocated (long integer).
282
283    Returns:
284      A 'TensorTracker' object.
285    """
286    self._name = name
287    self._pid = pid
288    self._object_id = object_id
289    self._create_time = timestamp
290    self._allocator = allocator
291    self._num_bytes = num_bytes
292    self._ref_times = []
293    self._unref_times = []
294
295  @property
296  def name(self):
297    """Name of this tensor."""
298    return self._name
299
300  @property
301  def pid(self):
302    """ID of the process which created this tensor (an integer)."""
303    return self._pid
304
305  @property
306  def create_time(self):
307    """Timestamp when this tensor was created (long integer)."""
308    return self._create_time
309
310  @property
311  def object_id(self):
312    """Returns the object identifier of this tensor (integer)."""
313    return self._object_id
314
315  @property
316  def num_bytes(self):
317    """Size of this tensor in bytes (long integer)."""
318    return self._num_bytes
319
320  @property
321  def allocator(self):
322    """Name of the allocator used to create this tensor (string)."""
323    return self._allocator
324
325  @property
326  def last_unref(self):
327    """Last unreference timestamp of this tensor (long integer)."""
328    return max(self._unref_times)
329
330  def add_ref(self, timestamp):
331    """Adds a reference to this tensor with the specified timestamp.
332
333    Args:
334      timestamp:  Timestamp of object reference as an integer.
335    """
336    self._ref_times.append(timestamp)
337
338  def add_unref(self, timestamp):
339    """Adds an unref to this tensor with the specified timestamp.
340
341    Args:
342      timestamp:  Timestamp of object unreference as an integer.
343    """
344    self._unref_times.append(timestamp)
345
346
class Timeline(object):
  """A class for visualizing execution timelines of TensorFlow steps."""

  def __init__(self, step_stats, graph=None):
    """Constructs a new Timeline.

    A 'Timeline' is used for visualizing the execution of a TensorFlow
    computation.  It shows the timings and concurrency of execution at
    the granularity of TensorFlow Ops.
    This class is not thread safe.

    Args:
      step_stats: The 'StepStats' proto recording execution times.
      graph: (Optional) The 'Graph' that was executed.
    """

    # The raw proto is kept untouched; _preprocess_op_time decides whether
    # _step_stats aliases it or is a deep copy with adjusted op times.
    self._origin_step_stats = step_stats
    self._step_stats = None
    self._graph = graph
    self._chrome_trace = _ChromeTraceFormatter()
    self._next_pid = 0
    self._device_pids = {}  # device name -> pid for compute activity.
    self._tensor_pids = {}  # device name -> pid for tensors.
    self._tensors = {}  # tensor_name -> TensorTracker
    self._next_flow_id = 0
    self._flow_starts = {}  # tensor_name -> (timestamp, pid, tid)
    self._alloc_times = {}  # tensor_name -> ( time, allocator, size )
    self._allocator_maximums = {}  # allocator name => maximum bytes long
375
376  def _alloc_pid(self):
377    """Allocate a process Id."""
378    pid = self._next_pid
379    self._next_pid += 1
380    return pid
381
382  def _alloc_flow_id(self):
383    """Allocate a flow Id."""
384    flow_id = self._next_flow_id
385    self._next_flow_id += 1
386    return flow_id
387
388  def _parse_op_label(self, label):
389    """Parses the fields in a node timeline label."""
390    # Expects labels of the form: name = op(arg, arg, ...).
391    match = re.match(r'(.*) = (.*)\((.*)\)', label)
392    if match is None:
393      return 'unknown', 'unknown', []
394    nn, op, inputs = match.groups()
395    if not inputs:
396      inputs = []
397    else:
398      inputs = inputs.split(', ')
399    return nn, op, inputs
400
401  def _parse_kernel_label(self, label, node_name):
402    """Parses the fields in a node timeline label."""
403    # Expects labels of the form: retval (arg) detail @@annotation
404    start = label.find('@@')
405    end = label.find('#')
406    if start >= 0 and end >= 0 and start + 2 < end:
407      node_name = label[start + 2:end]
408    # Node names should always have the form 'name:op'.
409    fields = node_name.split(':') + ['unknown']
410    name, op = fields[:2]
411    return name, op
412
  def _assign_lanes(self):
    """Assigns non-overlapping lanes for the activities on each device."""
    for device_stats in self._step_stats.dev_stats:
      # TODO(pbar): Genuine thread IDs in NodeExecStats might be helpful.
      # Each entry in 'lanes' holds the end time of that lane's most recent
      # activity; a lane is free for an op that starts after that time.
      lanes = [0]
      for ns in device_stats.node_stats:
        l = -1
        # Greedily reuse the first lane that is free at this op's start.
        for (i, lts) in enumerate(lanes):
          if ns.all_start_micros > lts:
            l = i
            lanes[l] = ns.all_start_micros + ns.all_end_rel_micros
            break
        if l < 0:
          # No existing lane is free: open a new one.
          l = len(lanes)
          lanes.append(ns.all_start_micros + ns.all_end_rel_micros)
        ns.thread_id = l
429
  def _emit_op(self, nodestats, pid, is_gputrace):
    """Generates a Chrome Trace event to show Op execution.

    Args:
      nodestats: The 'NodeExecStats' proto recording op execution.
      pid: The pid assigned for the device where this op ran.
      is_gputrace: If True then this op came from the GPUTracer.
    """
    node_name = nodestats.node_name
    start = nodestats.all_start_micros
    duration = nodestats.all_end_rel_micros
    tid = nodestats.thread_id
    inputs = []
    if is_gputrace:
      # Kernel traces carry the originating node name in an @@annotation.
      node_name, op = self._parse_kernel_label(nodestats.timeline_label,
                                               node_name)
    elif node_name == 'RecvTensor':
      # RPC tracing does not use the standard timeline_label format.
      op = 'RecvTensor'
    else:
      _, op, inputs = self._parse_op_label(nodestats.timeline_label)
    args = {'name': node_name, 'op': op}
    if build_info.build_info['is_rocm_build']:
      # On ROCm builds the kernel name precedes the '@@' annotation.
      args['kernel'] = nodestats.timeline_label.split('@@')[0]
    for i, iname in enumerate(inputs):
      args['input%d' % i] = iname
    self._chrome_trace.emit_region(start, duration, pid, tid, 'Op', op, args)
457
458  def _emit_tensor_snapshot(self, tensor, timestamp, pid, tid, value):
459    """Generate Chrome Trace snapshot event for a computed Tensor.
460
461    Args:
462      tensor: A 'TensorTracker' object.
463      timestamp:  The timestamp of this snapshot as a long integer.
464      pid: The pid assigned for showing the device where this op ran.
465      tid: The tid of the thread computing the tensor snapshot.
466      value: A JSON-compliant snapshot of the object.
467    """
468    desc = str(value.tensor_description).replace('"', '')
469    snapshot = {'tensor_description': desc}
470    self._chrome_trace.emit_obj_snapshot('Tensor', tensor.name, timestamp, pid,
471                                         tid, tensor.object_id, snapshot)
472
473  def _produce_tensor(self, name, timestamp, tensors_pid, allocator, num_bytes):
474    object_id = len(self._tensors)
475    tensor = _TensorTracker(name, object_id, timestamp, tensors_pid, allocator,
476                            num_bytes)
477    self._tensors[name] = tensor
478    return tensor
479
480  def _is_gputrace_device(self, device_name):
481    """Returns true if this device is part of the GPUTracer logging."""
482    return '/stream:' in device_name or '/memcpy' in device_name
483
484  def _allocate_pids(self):
485    """Allocate fake process ids for each device in the StepStats."""
486    self._allocators_pid = self._alloc_pid()
487    self._chrome_trace.emit_pid('Allocators', self._allocators_pid)
488
489    # Add processes in the Chrome trace to show compute and data activity.
490    for dev_stats in self._step_stats.dev_stats:
491      device_pid = self._alloc_pid()
492      self._device_pids[dev_stats.device] = device_pid
493      tensors_pid = self._alloc_pid()
494      self._tensor_pids[dev_stats.device] = tensors_pid
495      self._chrome_trace.emit_pid(dev_stats.device + ' Compute', device_pid)
496      self._chrome_trace.emit_pid(dev_stats.device + ' Tensors', tensors_pid)
497
  def _analyze_tensors(self, show_memory):
    """Analyze tensor references to track dataflow."""
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._device_pids[dev_stats.device]
      tensors_pid = self._tensor_pids[dev_stats.device]
      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        node_name = node_stats.node_name
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        for index, output in enumerate(node_stats.output):
          # Output slot 0 is named after the node; later slots get ':<index>'.
          if index:
            output_name = '%s:%d' % (node_name, index)
          else:
            output_name = node_name

          allocation = output.tensor_description.allocation_description
          num_bytes = allocation.requested_bytes
          allocator_name = allocation.allocator_name
          tensor = self._produce_tensor(output_name, start_time, tensors_pid,
                                        allocator_name, num_bytes)
          tensor.add_ref(start_time)
          tensor.add_unref(end_time)
          # Dataflow arrows for this tensor begin where its producer ends.
          self._flow_starts[output_name] = (end_time, device_pid, tid)

          if show_memory:
            self._chrome_trace.emit_obj_create('Tensor', output_name,
                                               start_time, tensors_pid, tid,
                                               tensor.object_id)
            # Snapshot just before the op ends so it nests inside the region.
            self._emit_tensor_snapshot(tensor, end_time - 1, tensors_pid, tid,
                                       output)
529
  def _show_compute(self, show_dataflow):
    """Visualize the computation activity."""
    for dev_stats in self._step_stats.dev_stats:
      device_name = dev_stats.device
      device_pid = self._device_pids[device_name]
      is_gputrace = self._is_gputrace_device(device_name)

      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        self._emit_op(node_stats, device_pid, is_gputrace)

        # GPU tracer and RPC events carry no parseable input list.
        if is_gputrace or node_stats.node_name == 'RecvTensor':
          continue

        _, _, inputs = self._parse_op_label(node_stats.timeline_label)
        for input_name in inputs:
          if input_name not in self._tensors:
            # This can happen when partitioning has inserted a Send/Recv.
            # We remove the numeric suffix so that the dataflow appears to
            # come from the original node.  Ideally, the StepStats would
            # contain logging for the Send and Recv nodes.
            index = input_name.rfind('/_')
            if index > 0:
              input_name = input_name[:index]

          if input_name in self._tensors:
            tensor = self._tensors[input_name]
            tensor.add_ref(start_time)
            tensor.add_unref(end_time - 1)

            if show_dataflow:
              # We use a different flow ID for every graph edge.
              create_time, create_pid, create_tid = self._flow_starts[
                  input_name]
              # Don't add flows when producer and consumer ops are on the same
              # pid/tid since the horizontal arrows clutter the visualization.
              if create_pid != device_pid or create_tid != tid:
                flow_id = self._alloc_flow_id()
                self._chrome_trace.emit_flow_start(input_name, create_time,
                                                   create_pid, create_tid,
                                                   flow_id)
                self._chrome_trace.emit_flow_end(input_name, start_time,
                                                 device_pid, tid, flow_id)
          else:
            logging.vlog(1, 'Can\'t find tensor %s - removed by CSE?',
                         input_name)
578
579  def _show_memory_counters(self):
580    """Produce a counter series for each memory allocator."""
581    # Iterate over all tensor trackers to build a list of allocations and
582    # frees for each allocator. Then sort the lists and emit a cumulative
583    # counter series for each allocator.
584    allocations = {}
585    for name in self._tensors:
586      tensor = self._tensors[name]
587      self._chrome_trace.emit_obj_delete('Tensor', name, tensor.last_unref,
588                                         tensor.pid, 0, tensor.object_id)
589      allocator = tensor.allocator
590      if allocator not in allocations:
591        allocations[allocator] = []
592      num_bytes = tensor.num_bytes
593      allocations[allocator].append((tensor.create_time, num_bytes, name))
594      allocations[allocator].append((tensor.last_unref, -num_bytes, name))
595
596    alloc_maxes = {}
597
598    # Generate a counter series showing total allocations for each allocator.
599    for allocator in allocations:
600      alloc_list = allocations[allocator]
601      alloc_list.sort()
602      total_bytes = 0
603      alloc_tensor_set = set()
604      alloc_maxes[allocator] = AllocationMaximum(
605          timestamp=0, num_bytes=0, tensors=set())
606      for time, num_bytes, name in sorted(
607          alloc_list, key=lambda allocation: allocation[0]):
608        total_bytes += num_bytes
609        if num_bytes < 0:
610          alloc_tensor_set.discard(name)
611        else:
612          alloc_tensor_set.add(name)
613
614        if total_bytes > alloc_maxes[allocator].num_bytes:
615          alloc_maxes[allocator] = AllocationMaximum(
616              timestamp=time,
617              num_bytes=total_bytes,
618              tensors=copy.deepcopy(alloc_tensor_set))
619
620        self._chrome_trace.emit_counter('Memory', allocator,
621                                        self._allocators_pid, time, allocator,
622                                        total_bytes)
623    self._allocator_maximums = alloc_maxes
624
  def _preprocess_op_time(self, op_time):
    """Update the start and end time of ops in step stats.

    Args:
      op_time: How the execution time of op is shown in timeline. Possible
        values are "schedule", "gpu" and "all". "schedule" will show op from
        the time it is scheduled to the end of the scheduling. Notice by the
        end of its scheduling its async kernels may not start yet. It is
        shown using the default value from step_stats. "gpu" will show op
        with the execution time of its kernels on GPU. "all" will show op
        from the start of its scheduling to the end of its last kernel.
    """
    if op_time == 'schedule':
      # No adjustment needed: use the proto as recorded.
      self._step_stats = self._origin_step_stats
      return
    # Mutations below must not touch the caller's proto, so work on a copy.
    self._step_stats = copy.deepcopy(self._origin_step_stats)
    # Separate job task and gpu tracer stream
    stream_all_stats = []
    job_stats = []
    for stats in self._step_stats.dev_stats:
      if '/stream:all' in stats.device:
        stream_all_stats.append(stats)
      elif '/job' in stats.device:
        job_stats.append(stats)

    # Record the start time of the first kernel and the end time of
    # the last gpu kernel for all ops.
    op_gpu_start = {}
    op_gpu_end = {}
    for stats in stream_all_stats:
      for kernel in stats.node_stats:
        name, _ = self._parse_kernel_label(kernel.timeline_label,
                                           kernel.node_name)
        start = kernel.all_start_micros
        end = kernel.all_start_micros + kernel.all_end_rel_micros
        if name in op_gpu_start:
          op_gpu_start[name] = min(op_gpu_start[name], start)
          op_gpu_end[name] = max(op_gpu_end[name], end)
        else:
          op_gpu_start[name] = start
          op_gpu_end[name] = end

    # Update the start and end time of each op according to the op_time
    for stats in job_stats:
      for op in stats.node_stats:
        if op.node_name in op_gpu_start:
          # Extend the op to cover its last kernel; for "gpu" also move its
          # start to the first kernel launch.
          end = max(op_gpu_end[op.node_name],
                    op.all_start_micros + op.all_end_rel_micros)
          if op_time == 'gpu':
            op.all_start_micros = op_gpu_start[op.node_name]
          op.all_end_rel_micros = end - op.all_start_micros
676
  def analyze_step_stats(self,
                         show_dataflow=True,
                         show_memory=True,
                         op_time='schedule'):
    """Analyze the step stats and format it into Chrome Trace Format.

    Args:
      show_dataflow: (Optional.) If True, add flow events to the trace
        connecting producers and consumers of tensors.
      show_memory: (Optional.) If True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
      op_time: (Optional.) How the execution time of op is shown in timeline.
        Possible values are "schedule", "gpu" and "all". "schedule" will show op
        from the time it is scheduled to the end of the scheduling. Notice by
        the end of its scheduling its async kernels may not start yet. It is
        shown using the default value from step_stats. "gpu" will show op with
        the execution time of its kernels on GPU. "all" will show op from the
        start of its scheduling to the end of its last kernel.

    Returns:
      A 'StepStatsAnalysis' object.
    """
    # NOTE: These calls are order dependent: op times must be fixed up before
    # anything reads them, pids must exist before events reference them, and
    # lanes must be assigned before events read node_stats.thread_id.
    self._preprocess_op_time(op_time)
    self._allocate_pids()
    self._assign_lanes()
    self._analyze_tensors(show_memory)
    self._show_compute(show_dataflow)
    if show_memory:
      self._show_memory_counters()
    return StepStatsAnalysis(
        chrome_trace=self._chrome_trace,
        allocator_maximums=self._allocator_maximums)
709
710  def generate_chrome_trace_format(self,
711                                   show_dataflow=True,
712                                   show_memory=False,
713                                   op_time='schedule'):
714    """Produces a trace in Chrome Trace Format.
715
716    Args:
717      show_dataflow: (Optional.) If True, add flow events to the trace
718        connecting producers and consumers of tensors.
719      show_memory: (Optional.) If True, add object snapshot events to the trace
720        showing the sizes and lifetimes of tensors.
721      op_time: (Optional.) How the execution time of op is shown in timeline.
722        Possible values are "schedule", "gpu" and "all".
723        "schedule" will show op from the time it is scheduled to the end of
724          the scheduling.
725          Notice by the end of its scheduling its async kernels may not start
726          yet. It is shown using the default value from step_stats.
727        "gpu" will show op with the execution time of its kernels on GPU.
728        "all" will show op from the start of its scheduling to the end of
729          its last kernel.
730
731    Returns:
732      A JSON formatted string in Chrome Trace format.
733    """
734    step_stats_analysis = self.analyze_step_stats(
735        show_dataflow=show_dataflow, show_memory=show_memory, op_time=op_time)
736
737    return step_stats_analysis.chrome_trace.format_to_string(pretty=True)
738