1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_PLATFORM_TRACING_H_
17 #define TENSORFLOW_CORE_PLATFORM_TRACING_H_
18
19 // Tracing interface
20
21 #include <array>
22 #include <atomic>
23 #include <map>
24 #include <memory>
25
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/platform.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace tensorflow {
34 namespace tracing {
35
// Identifiers for every TensorFlow CPU profiler event category. Any change to
// this list must be mirrored in the name lookup performed by
// GetEventCategoryName().
enum class EventCategory : unsigned int {
  kScheduleClosure = 0,
  kRunClosure = 1,
  kCompute = 2,
  // Sentinel equal to the number of real categories; must remain last.
  kNumCategories = 3
};
GetNumEventCategories()44 constexpr unsigned GetNumEventCategories() {
45 return static_cast<unsigned>(EventCategory::kNumCategories);
46 }
47 const char* GetEventCategoryName(EventCategory);
48
49 // Interface for CPU profiler events.
50 class EventCollector {
51 public:
~EventCollector()52 virtual ~EventCollector() {}
53 virtual void RecordEvent(uint64 arg) const = 0;
54 virtual void StartRegion(uint64 arg) const = 0;
55 virtual void StopRegion() const = 0;
56
57 // Annotates the current thread with a name.
58 static void SetCurrentThreadName(const char* name);
59 // Returns whether event collection is enabled.
60 static bool IsEnabled();
61
62 private:
63 friend void SetEventCollector(EventCategory, const EventCollector*);
64 friend const EventCollector* GetEventCollector(EventCategory);
65
66 static std::array<const EventCollector*, GetNumEventCategories()> instances_;
67 };
68 // Set the callback for RecordEvent and ScopedRegion of category.
69 // Not thread safe. Only call while EventCollector::IsEnabled returns false.
70 void SetEventCollector(EventCategory category, const EventCollector* collector);
71
72 // Returns the callback for RecordEvent and ScopedRegion of category if
73 // EventCollector::IsEnabled(), otherwise returns null.
GetEventCollector(EventCategory category)74 inline const EventCollector* GetEventCollector(EventCategory category) {
75 if (EventCollector::IsEnabled()) {
76 return EventCollector::instances_[static_cast<unsigned>(category)];
77 }
78 return nullptr;
79 }
80
81 // Returns a unique id to pass to RecordEvent/ScopedRegion. Never returns zero.
82 uint64 GetUniqueArg();
83
84 // Returns an id for name to pass to RecordEvent/ScopedRegion.
85 uint64 GetArgForName(StringPiece name);
86
87 // Records an atomic event through the currently registered EventCollector.
RecordEvent(EventCategory category,uint64 arg)88 inline void RecordEvent(EventCategory category, uint64 arg) {
89 if (auto collector = GetEventCollector(category)) {
90 collector->RecordEvent(arg);
91 }
92 }
93
94 // Records an event for the duration of the instance lifetime through the
95 // currently registered EventCollector.
96 class ScopedRegion {
97 ScopedRegion(ScopedRegion&) = delete; // Not copy-constructible.
98 ScopedRegion& operator=(ScopedRegion&) = delete; // Not assignable.
99
100 public:
ScopedRegion(ScopedRegion && other)101 ScopedRegion(ScopedRegion&& other) noexcept // Move-constructible.
102 : collector_(other.collector_) {
103 other.collector_ = nullptr;
104 }
105
ScopedRegion(EventCategory category,uint64 arg)106 ScopedRegion(EventCategory category, uint64 arg)
107 : collector_(GetEventCollector(category)) {
108 if (collector_) {
109 collector_->StartRegion(arg);
110 }
111 }
112
113 // Same as ScopedRegion(category, GetUniqueArg()), but faster if
114 // EventCollector::IsEnaled() returns false.
ScopedRegion(EventCategory category)115 ScopedRegion(EventCategory category)
116 : collector_(GetEventCollector(category)) {
117 if (collector_) {
118 collector_->StartRegion(GetUniqueArg());
119 }
120 }
121
122 // Same as ScopedRegion(category, GetArgForName(name)), but faster if
123 // EventCollector::IsEnaled() returns false.
ScopedRegion(EventCategory category,StringPiece name)124 ScopedRegion(EventCategory category, StringPiece name)
125 : collector_(GetEventCollector(category)) {
126 if (collector_) {
127 collector_->StartRegion(GetArgForName(name));
128 }
129 }
130
~ScopedRegion()131 ~ScopedRegion() {
132 if (collector_) {
133 collector_->StopRegion();
134 }
135 }
136
IsEnabled()137 bool IsEnabled() const { return collector_ != nullptr; }
138
139 private:
140 const EventCollector* collector_;
141 };
142
143 // Interface for accelerator profiler annotations.
144 class TraceCollector {
145 public:
146 class Handle {
147 public:
~Handle()148 virtual ~Handle() {}
149 };
150
~TraceCollector()151 virtual ~TraceCollector() {}
152 virtual std::unique_ptr<Handle> CreateAnnotationHandle(
153 StringPiece name_part1, StringPiece name_part2) const = 0;
154 virtual std::unique_ptr<Handle> CreateActivityHandle(
155 StringPiece name_part1, StringPiece name_part2,
156 bool is_expensive) const = 0;
157
158 // Returns true if this annotation tracing is enabled for any op.
159 virtual bool IsEnabledForAnnotations() const = 0;
160
161 // Returns true if this activity handle tracking is enabled for an op of the
162 // given expensiveness.
163 virtual bool IsEnabledForActivities(bool is_expensive) const = 0;
164
165 protected:
166 static string ConcatenateNames(StringPiece first, StringPiece second);
167
168 private:
169 friend void SetTraceCollector(const TraceCollector*);
170 friend const TraceCollector* GetTraceCollector();
171 };
172 // Set the callback for ScopedAnnotation and ScopedActivity.
173 void SetTraceCollector(const TraceCollector* collector);
174 // Returns the callback for ScopedAnnotation and ScopedActivity.
175 const TraceCollector* GetTraceCollector();
176
177 // Adds an annotation to all activities for the duration of the instance
178 // lifetime through the currently registered TraceCollector.
179 //
180 // Usage: {
181 // ScopedAnnotation annotation("my kernels");
182 // Kernel1<<<x,y>>>;
183 // LaunchKernel2(); // Launches a CUDA kernel.
184 // }
185 // This will add 'my kernels' to both kernels in the profiler UI
186 class ScopedAnnotation {
187 public:
ScopedAnnotation(StringPiece name)188 explicit ScopedAnnotation(StringPiece name)
189 : ScopedAnnotation(name, StringPiece()) {}
190
191 // If tracing is enabled, add a name scope of
192 // "<name_part1>:<name_part2>". This can be cheaper than the
193 // single-argument constructor because the concatenation of the
194 // label string is only done if tracing is enabled.
ScopedAnnotation(StringPiece name_part1,StringPiece name_part2)195 ScopedAnnotation(StringPiece name_part1, StringPiece name_part2)
196 : handle_([&] {
197 auto trace_collector = GetTraceCollector();
198 return trace_collector ? trace_collector->CreateAnnotationHandle(
199 name_part1, name_part2)
200 : nullptr;
201 }()) {}
202
IsEnabled()203 bool IsEnabled() const { return static_cast<bool>(handle_); }
204
205 private:
206 std::unique_ptr<TraceCollector::Handle> handle_;
207 };
208
209 // Adds an activity through the currently registered TraceCollector.
210 // The activity starts when an object of this class is created and stops when
211 // the object is destroyed.
212 class ScopedActivity {
213 public:
214 explicit ScopedActivity(StringPiece name, bool is_expensive = true)
ScopedActivity(name,StringPiece (),is_expensive)215 : ScopedActivity(name, StringPiece(), is_expensive) {}
216
217 // If tracing is enabled, set up an activity with a label of
218 // "<name_part1>:<name_part2>". This can be cheaper than the
219 // single-argument constructor because the concatenation of the
220 // label string is only done if tracing is enabled.
221 ScopedActivity(StringPiece name_part1, StringPiece name_part2,
222 bool is_expensive = true)
223 : handle_([&] {
224 auto trace_collector = GetTraceCollector();
225 return trace_collector ? trace_collector->CreateActivityHandle(
226 name_part1, name_part2, is_expensive)
227 : nullptr;
228 }()) {}
229
IsEnabled()230 bool IsEnabled() const { return static_cast<bool>(handle_); }
231
232 private:
233 std::unique_ptr<TraceCollector::Handle> handle_;
234 };
235
236 // Return the pathname of the directory where we are writing log files.
237 const char* GetLogDir();
238
239 } // namespace tracing
240 } // namespace tensorflow
241
242 #if defined(PLATFORM_GOOGLE)
243 #include "tensorflow/core/platform/google/tracing_impl.h"
244 #else
245 #include "tensorflow/core/platform/default/tracing_impl.h"
246 #endif
247
248 #endif // TENSORFLOW_CORE_PLATFORM_TRACING_H_
249