• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Trace allows the profiler to trace Python events."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.python.profiler.internal import _pywrap_traceme
24from tensorflow.python.util.tf_export import tf_export
25
# Module-level flag consulted by `Trace.__init__` and `trace_wrapper` to decide
# whether a trace event should be created. It is modified by
# PythonHooks::Start/Stop() in C++; keeping the flag on the Python side reduces
# the number of calls through pybind11.
enabled = False
29
30
@tf_export('profiler.experimental.Trace', v1=[])
class Trace(object):
  """Context manager that records a trace event in the profiler.

  The event's clock starts when the `Trace` object is constructed (which, in
  the usual `with Trace(...)` form, is the moment the context is entered). The
  event is stopped and saved to the profiler when the context exits. The
  result can be inspected in the trace viewer of the TensorBoard Profile tab.

  Trace events are only created while the profiler is enabled; otherwise the
  object is an inert placeholder. See https://tensorflow.org/guide/profiler
  for details on using the profiler.

  Example usage:
  ```python
  tf.profiler.experimental.start('logdir')
  for step in range(num_steps):
    # Creates a trace event for each training step with the step number.
    with tf.profiler.experimental.Trace("Train", step_num=step):
      train_fn()
  tf.profiler.experimental.stop()
  ```
  """

  def __init__(self, name, **kwargs):
    """Creates a trace event in the profiler.

    When the module-level `enabled` flag is false, no underlying trace event
    is created and every other method of this object becomes a no-op.

    Args:
      name: The name of the trace event.
      **kwargs: Keyword arguments added to the trace event. Both the key and
        value are of types that can be converted to strings, which will be
        interpreted by the profiler according to the traceme name.

    Example usage:

    ```python

      tf.profiler.experimental.start('logdir')
      for step in range(num_steps):
        # Creates a trace event for each training step with the
        # step number.
        with tf.profiler.experimental.Trace("Train", step_num=step):
          train_fn()
      tf.profiler.experimental.stop()

    ```
    The example above uses the keyword argument "step_num" to specify the
    training step being traced.
    """
    # Constructing _pywrap_traceme.TraceMe is what starts the clock, so it is
    # only done while profiling is active.
    self._traceme = _pywrap_traceme.TraceMe(name, **kwargs) if enabled else None

  def __enter__(self):
    # The clock is already running (started in __init__); starting it here
    # instead would cost an additional Python->C++ call.
    return self

  def set_metadata(self, **kwargs):
    """Sets metadata in this trace event.

    This method makes it possible to attach metadata to a trace event after
    the event has been created. It does nothing when no trace event is active
    or when no keyword arguments are given.

    Args:
      **kwargs: metadata in key-value pairs.

    Example usage:

    ```python

      def call(function):
        with tf.profiler.experimental.Trace("call",
             function_name=function.name) as tm:
          binary, in_cache = jit_compile(function)
          tm.set_metadata(in_cache=in_cache)
          execute(binary)

    ```
    In this example, we want to trace how much time is spent on calling a
    function, which includes both compilation and execution. The compilation
    either fetches a cached copy of the binary or actually generates the
    binary, as indicated by the boolean "in_cache" returned by jit_compile().
    set_metadata() is needed to record in_cache because its value is not yet
    known when the trace is created, and the trace cannot be created after
    jit_compile() because the entire duration of call() must be measured.
    """
    if not self._traceme or not kwargs:
      return
    self._traceme.SetMetadata(**kwargs)

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Stop the underlying event, if one was created at construction time.
    # Nothing is returned, so exceptions raised in the body propagate.
    active_event = self._traceme
    if not active_event:
      return
    active_event.Stop()
128
129
def trace_wrapper(trace_name, **trace_kwargs):
  """Decorator alternative to `with Trace(): ...`.  It's faster.

  While the module-level `enabled` flag is false, the decorated function is
  called directly and no `Trace` object is constructed at all.

  Args:
    trace_name: The name of the trace event.
    **trace_kwargs: Keyword arguments added to the trace event. Both the key and
      value are of types that can be converted to strings, which will be
      interpreted by the profiler according to the traceme name.

  Returns:
    A decorator that can wrap a function and apply `Trace` scope if needed.

  Example usage:
    ```python

    @trace_wrapper('trace_name')
    def func(x, y, z):
      pass  # code to execute and apply `Trace` if needed.

    # Equivalent to
    # with Trace('trace_name'):
    #   func(1, 2, 3)
    func(1, 2, 3)
    ```
  """

  def decorate(func):

    @functools.wraps(func)
    def traced_call(*args, **kwargs):
      # The flag is checked on every call because the profiler may be started
      # or stopped after the function has been decorated.
      if not enabled:
        return func(*args, **kwargs)
      with Trace(trace_name, **trace_kwargs):
        return func(*args, **kwargs)

    return traced_call

  return decorate
168