1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/device_tracer.h"
17
18 #if GOOGLE_CUDA
19
20 #include <stdlib.h>
21 #include <memory>
22
23 #include "absl/base/casts.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "tensorflow/core/common_runtime/step_stats_collector.h"
29 #include "tensorflow/core/framework/step_stats.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/cupti_wrapper.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/mem.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/tracing.h"
40 #include "tensorflow/core/profiler/internal/cpu/host_tracer.h"
41 #include "tensorflow/core/profiler/lib/traceme.h"
42
43 namespace tensorflow {
44 namespace {
ToStatus(CUptiResult result)45 Status ToStatus(CUptiResult result) {
46 if (result == CUPTI_SUCCESS) {
47 return Status::OK();
48 }
49 const char* str = nullptr;
50 if (auto wrapper =
51 absl::make_unique<perftools::gputools::profiler::CuptiWrapper>()) {
52 wrapper->GetResultString(result, &str);
53 }
54 return errors::Unavailable("CUPTI error: ", str ? str : "<unknown>");
55 }
56
ToStatus(CUresult result)57 Status ToStatus(CUresult result) {
58 if (result == CUDA_SUCCESS) {
59 return Status::OK();
60 }
61 const char* str = nullptr;
62 cuGetErrorName(result, &str);
63 return errors::Unavailable("CUDA error: ", str ? str : "<unknown>");
64 }
65
LogIfError(const Status & status)66 void LogIfError(const Status& status) {
67 if (status.ok()) {
68 return;
69 }
70 LOG(ERROR) << status.error_message();
71 }
72
// One kernel launch observed through the CUPTI cuLaunchKernel callback.
// Instances are created with aggregate initialization, so the member order is
// significant; members not listed in the initializer start out null.
struct KernelRecord {
  // Kernel symbol name as reported by CUPTI for the launch.
  const char* kernel_name;
  // TODO(csigg): cuStreamGetCtx introduced in CUDA 9.2 would allow us to only
  // record the stream and infer the context during collection.
  CUcontext context;
  CUstream stream;
  // Events recorded on `stream` before and after the launch. Either may be
  // null (e.g. if event creation failed); such records are skipped during
  // collection.
  CUevent start_event;
  CUevent stop_event;
  // Interned copy of the launching thread's active annotation, or null if the
  // thread had none.
  const std::string* annotation;
};
83
// One memory copy observed through a CUPTI cuMemcpy* callback. Instances are
// created with aggregate initialization, so the member order is significant;
// members not listed in the initializer start out null.
struct MemcpyRecord {
  // Memory types of the copy source and destination.
  CUmemorytype src_type;
  CUmemorytype dst_type;
  // ByteCount argument of the copy call.
  size_t size_bytes;
  CUcontext context;
  // Stream of an async copy; null for synchronous copies.
  CUstream stream;
  // Events recorded on `stream` before and after the copy call. Either may be
  // null (e.g. if event creation failed); such records are skipped during
  // collection.
  CUevent start_event;
  CUevent stop_event;
  // Interned copy of the calling thread's active annotation, or null if the
  // thread had none.
  const std::string* annotation;
};
94
CreateAndRecordEvent(CUevent * event,CUstream stream)95 Status CreateAndRecordEvent(CUevent* event, CUstream stream) {
96 TF_RETURN_IF_ERROR(ToStatus(cuEventCreate(event, CU_EVENT_DEFAULT)));
97 return ToStatus(cuEventRecord(*event, stream));
98 }
99
// Most recent annotation of the current thread, or null if there is none.
// When non-null, it points at the string owned by the thread's active
// annotation handle, which resets it to null when destroyed. The annotation is
// guaranteed to remain live for the duration of the CUPTI API callback; the
// recorder copies the string rather than keeping this pointer.
static thread_local const char* tls_current_annotation = nullptr;
105
106 // Stores a series of kernel and memcpy records.
107 class CudaEventRecorder {
108 public:
109 // Registers the start of a kernel launch. The returned index should be passed
110 // to StopKernel() after the kernel launch has completed.
StartKernel(const char * kernel_name,CUcontext context,CUstream stream)111 size_t StartKernel(const char* kernel_name, CUcontext context,
112 CUstream stream) {
113 KernelRecord record = {kernel_name, context, stream};
114 LogIfError(CreateAndRecordEvent(&record.start_event, stream));
115 mutex_lock lock(mutex_);
116 if (tls_current_annotation) {
117 record.annotation = &*annotations_.emplace(tls_current_annotation).first;
118 }
119 kernel_records_.push_back(record);
120 return kernel_records_.size() - 1;
121 }
StopKernel(size_t index)122 void StopKernel(size_t index) {
123 mutex_lock lock(mutex_);
124 auto& record = kernel_records_[index];
125 LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream));
126 }
127
128 // Registers the start of a copy operation. The returned index should be
129 // passed to StopMemcpy() after the kernel launch has completed.
StartMemcpy(CUmemorytype src_type,CUmemorytype dst_type,size_t size_bytes,CUcontext context,CUstream stream)130 size_t StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type,
131 size_t size_bytes, CUcontext context, CUstream stream) {
132 MemcpyRecord record = {src_type, dst_type, size_bytes, context, stream};
133 LogIfError(CreateAndRecordEvent(&record.start_event, stream));
134 mutex_lock lock(mutex_);
135 if (tls_current_annotation) {
136 record.annotation = &*annotations_.emplace(tls_current_annotation).first;
137 }
138 memcpy_records_.push_back(record);
139 return memcpy_records_.size() - 1;
140 }
StopMemcpy(size_t index)141 void StopMemcpy(size_t index) {
142 mutex_lock lock(mutex_);
143 auto& record = memcpy_records_[index];
144 LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream));
145 }
146
ConsumeKernelRecords()147 std::vector<KernelRecord> ConsumeKernelRecords() {
148 mutex_lock lock(mutex_);
149 return std::move(kernel_records_);
150 }
ConsumeMemcpyRecords()151 std::vector<MemcpyRecord> ConsumeMemcpyRecords() {
152 mutex_lock lock(mutex_);
153 return std::move(memcpy_records_);
154 }
155
156 private:
157 mutex mutex_;
158 std::unordered_set<std::string> annotations_ GUARDED_BY(mutex_);
159 std::vector<KernelRecord> kernel_records_ GUARDED_BY(mutex_);
160 std::vector<MemcpyRecord> memcpy_records_ GUARDED_BY(mutex_);
161 };
162
163 // Instances register callbacks with CUPTI to notify the event recorder before
164 // and after kernel launches and memory copies.
165 class CuptiCallbackHook {
166 public:
CuptiCallbackHook()167 CuptiCallbackHook()
168 : cupti_wrapper_(new perftools::gputools::profiler::CuptiWrapper()),
169 subscriber_(nullptr) {}
170
Enable(CudaEventRecorder * recorder)171 Status Enable(CudaEventRecorder* recorder) {
172 TF_RETURN_IF_ERROR(ToStatus(
173 cupti_wrapper_->Subscribe(&subscriber_, &CuptiCallback, recorder)));
174 for (auto cbid : {CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel,
175 CUPTI_DRIVER_TRACE_CBID_cuMemcpy,
176 CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync,
177 CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2,
178 CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2,
179 CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2,
180 CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2,
181 CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2,
182 CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2}) {
183 TF_RETURN_IF_ERROR(ToStatus(cupti_wrapper_->EnableCallback(
184 /*enable=*/1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)));
185 }
186 return Status::OK();
187 }
188
~CuptiCallbackHook()189 ~CuptiCallbackHook() {
190 LogIfError(ToStatus(cupti_wrapper_->Unsubscribe(subscriber_)));
191 }
192
193 private:
CuptiCallback(void * userdata,CUpti_CallbackDomain domain,CUpti_CallbackId cbid,const void * cbdata)194 static void CUPTIAPI CuptiCallback(void* userdata,
195 CUpti_CallbackDomain domain,
196 CUpti_CallbackId cbid,
197 const void* cbdata) {
198 auto recorder = static_cast<CudaEventRecorder*>(userdata);
199 auto data = static_cast<const CUpti_CallbackData*>(cbdata);
200 DCHECK_EQ(domain, CUPTI_CB_DOMAIN_DRIVER_API);
201
202 if (data->callbackSite == CUPTI_API_ENTER) {
203 DriverApiEnterCallback(cbid, *data, recorder);
204 } else {
205 DriverApiExitCallback(cbid, *data, recorder);
206 }
207 }
208
GetMemoryType(CUdeviceptr ptr)209 static CUmemorytype GetMemoryType(CUdeviceptr ptr) {
210 CUmemorytype mem_type;
211 auto status =
212 cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr);
213 if (status == CUDA_ERROR_INVALID_VALUE) {
214 // Pointer not registered with CUDA, must be host memory.
215 return CU_MEMORYTYPE_HOST;
216 }
217 LogIfError(ToStatus(status));
218 return mem_type;
219 }
220
221 template <typename T>
StartMemcpy(CUmemorytype src_type,CUmemorytype dst_type,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)222 static void StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type,
223 const CUpti_CallbackData& cbdata,
224 CudaEventRecorder* recorder) {
225 auto params = static_cast<const T*>(cbdata.functionParams);
226 *cbdata.correlationData = recorder->StartMemcpy(
227 src_type, dst_type, params->ByteCount, cbdata.context, nullptr);
228 }
229 template <typename T>
StartMemcpyAsync(CUmemorytype dst_type,CUmemorytype src_type,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)230 static void StartMemcpyAsync(CUmemorytype dst_type, CUmemorytype src_type,
231 const CUpti_CallbackData& cbdata,
232 CudaEventRecorder* recorder) {
233 auto params = static_cast<const T*>(cbdata.functionParams);
234 *cbdata.correlationData = recorder->StartMemcpy(
235 src_type, dst_type, params->ByteCount, cbdata.context, params->hStream);
236 }
237
DriverApiEnterCallback(CUpti_CallbackId cbid,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)238 static void DriverApiEnterCallback(CUpti_CallbackId cbid,
239 const CUpti_CallbackData& cbdata,
240 CudaEventRecorder* recorder) {
241 switch (cbid) {
242 case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: {
243 DCHECK_NE(cbdata.symbolName, nullptr);
244 auto params =
245 static_cast<const cuLaunchKernel_params*>(cbdata.functionParams);
246 *cbdata.correlationData = recorder->StartKernel(
247 cbdata.symbolName, cbdata.context, params->hStream);
248 return;
249 }
250
251 case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: {
252 auto params =
253 static_cast<const cuMemcpy_params*>(cbdata.functionParams);
254 return StartMemcpy<cuMemcpy_params>(GetMemoryType(params->src),
255 GetMemoryType(params->dst), cbdata,
256 recorder);
257 }
258 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: {
259 auto params =
260 static_cast<const cuMemcpyAsync_params*>(cbdata.functionParams);
261 return StartMemcpyAsync<cuMemcpyAsync_params>(
262 GetMemoryType(params->src), GetMemoryType(params->dst), cbdata,
263 recorder);
264 }
265
266 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2:
267 return StartMemcpy<cuMemcpyHtoD_v2_params>(
268 CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
269
270 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2:
271 return StartMemcpyAsync<cuMemcpyHtoDAsync_v2_params>(
272 CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
273
274 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2:
275 return StartMemcpy<cuMemcpyDtoH_v2_params>(
276 CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder);
277 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2:
278 return StartMemcpyAsync<cuMemcpyDtoHAsync_v2_params>(
279 CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder);
280
281 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2:
282 return StartMemcpy<cuMemcpyDtoD_v2_params>(
283 CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
284 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2:
285 return StartMemcpyAsync<cuMemcpyDtoDAsync_v2_params>(
286 CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
287
288 default:
289 LOG(ERROR) << "Unexpected callback id: " << cbid;
290 }
291 }
292
DriverApiExitCallback(CUpti_CallbackId cbid,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)293 static void DriverApiExitCallback(CUpti_CallbackId cbid,
294 const CUpti_CallbackData& cbdata,
295 CudaEventRecorder* recorder) {
296 switch (cbid) {
297 case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel:
298 recorder->StopKernel(*cbdata.correlationData);
299 break;
300 case CUPTI_DRIVER_TRACE_CBID_cuMemcpy:
301 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync:
302 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2:
303 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2:
304 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2:
305 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2:
306 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2:
307 case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2:
308 recorder->StopMemcpy(*cbdata.correlationData);
309 break;
310 default:
311 LOG(ERROR) << "Unexpected callback id: " << cbid;
312 }
313 }
314
315 std::unique_ptr<perftools::gputools::profiler::CuptiWrapper> cupti_wrapper_;
316 CUpti_SubscriberHandle subscriber_;
317 };
318 } // namespace
319
320 class TraceCollectorImpl : public tracing::TraceCollector {
321 public:
322 class ActivityHandle : public Handle {
323 public:
ActivityHandle(std::string && name,int level)324 ActivityHandle(std::string&& name, int level)
325 : trace_me_(std::move(name), level) {}
326
327 private:
328 profiler::TraceMe trace_me_;
329 };
TraceCollectorImpl()330 TraceCollectorImpl() { tracing::SetTraceCollector(this); }
331
~TraceCollectorImpl()332 ~TraceCollectorImpl() override {
333 DCHECK(!active_trace_session_)
334 << "Unexpected active trace session detected. ";
335 }
336
337 // Note the method can be called after a call to Stop().
CreateAnnotationHandle(StringPiece name_part1,StringPiece name_part2) const338 virtual std::unique_ptr<Handle> CreateAnnotationHandle(
339 StringPiece name_part1, StringPiece name_part2) const {
340 struct Impl : public tracing::TraceCollector::Handle {
341 std::string annotation;
342 explicit Impl(std::string&& name_scope) : annotation(name_scope) {
343 VLOG(2) << "CreateAnnotationHandle " << annotation;
344 // Remember the most recent ScopedAnnotation for each thread.
345 tls_current_annotation = annotation.c_str();
346 }
347 ~Impl() override { tls_current_annotation = nullptr; }
348 };
349 return absl::make_unique<Impl>(ConcatenateNames(name_part1, name_part2));
350 }
351
CreateActivityHandle(StringPiece name_part1,StringPiece name_part2,bool is_expensive) const352 virtual std::unique_ptr<Handle> CreateActivityHandle(
353 StringPiece name_part1, StringPiece name_part2, bool is_expensive) const {
354 if (!IsEnabledForActivities(is_expensive)) {
355 return nullptr;
356 }
357 return absl::make_unique<ActivityHandle>(
358 ConcatenateNames(name_part1, name_part2), GetLevel(is_expensive));
359 }
360
IsEnabledForAnnotations() const361 bool IsEnabledForAnnotations() const override {
362 return active_trace_session_.load(std::memory_order_relaxed);
363 }
364
IsEnabledForActivities(bool is_expensive) const365 bool IsEnabledForActivities(bool is_expensive) const override {
366 return profiler::TraceMeRecorder::Active(GetLevel(is_expensive));
367 }
368
Start()369 void Start() {
370 DCHECK(!active_trace_session_)
371 << "Unexpected active trace session detected. ";
372 active_trace_session_ = true;
373 }
374
Stop()375 void Stop() {
376 DCHECK(active_trace_session_) << "No active trace session detected. ";
377 active_trace_session_ = false;
378 }
379
380 private:
GetLevel(bool is_expensive)381 static int GetLevel(bool is_expensive) {
382 return profiler::GetTFTraceMeLevel(is_expensive);
383 }
384
385 std::atomic<bool> active_trace_session_;
386 };
387
GlobalDefaultTraceCollector()388 TraceCollectorImpl* GlobalDefaultTraceCollector() {
389 static auto* instance = new TraceCollectorImpl();
390 return instance;
391 }
392
// DeviceTracer that times CUDA kernel launches and memory copies using CUDA
// events recorded from CUPTI driver-API callbacks, and also runs a host tracer.
class DeviceTracerImpl : public DeviceTracer {
 public:
  DeviceTracerImpl();
  ~DeviceTracerImpl() override;

  // DeviceTracer interface:
  Status Start() override;
  Status Stop() override;
  Status Collect(StepStatsCollector* collector) override;

 private:
  // Receives the records written by the CUPTI callbacks.
  // NOTE: declaration order matters. Members are destroyed in reverse order,
  // so cupti_hook_ (whose destructor unsubscribes the callbacks that write
  // into recorder_) is destroyed before recorder_.
  std::unique_ptr<CudaEventRecorder> recorder_;
  // Created in Start() and reset in Stop().
  std::unique_ptr<CuptiCallbackHook> cupti_hook_;

  mutex mu_;
  // True between a successful Start() and the following Stop().
  bool enabled_ GUARDED_BY(mu_);
  std::unique_ptr<profiler::cpu::HostTracer> host_tracer_ GUARDED_BY(mu_);
};
411
DeviceTracerImpl()412 DeviceTracerImpl::DeviceTracerImpl() : recorder_(new CudaEventRecorder()) {
413 VLOG(1) << "DeviceTracer created.";
414 host_tracer_ = profiler::cpu::HostTracer::Create(2);
415 enabled_ = false;
416 }
417
~DeviceTracerImpl()418 DeviceTracerImpl::~DeviceTracerImpl() {
419 // Unregister the CUPTI callbacks if needed to prevent them from accessing
420 // freed memory.
421 Stop().IgnoreError();
422 }
423
Start()424 Status DeviceTracerImpl::Start() {
425 VLOG(1) << "DeviceTracer::Start";
426 mutex_lock l(mu_);
427 if (enabled_) {
428 return errors::FailedPrecondition("DeviceTracer is already enabled.");
429 }
430 cupti_hook_.reset(new CuptiCallbackHook());
431 TF_RETURN_IF_ERROR(cupti_hook_->Enable(recorder_.get()));
432
433 // Register as a TraceEngine to receive ScopedAnnotations.
434 GlobalDefaultTraceCollector()->Start();
435
436 host_tracer_->Start().IgnoreError();
437 enabled_ = true;
438 return Status::OK();
439 }
440
Stop()441 Status DeviceTracerImpl::Stop() {
442 VLOG(1) << "DeviceTracer::Stop";
443 mutex_lock l(mu_);
444 if (!enabled_) {
445 return Status::OK();
446 }
447 cupti_hook_.reset();
448 GlobalDefaultTraceCollector()->Stop();
449
450 enabled_ = false;
451 host_tracer_->Stop().IgnoreError();
452 return Status::OK();
453 }
454
455 namespace {
456 class CudaEventCollector {
457 struct DeviceInfo {
458 int ordinal;
459 std::string name;
460 int num_contexts;
461 };
462
463 struct ContextInfo {
464 int index;
465 const DeviceInfo* dev_info;
466 int num_streams;
467 CUevent end_event;
468 };
469
470 struct StreamInfo {
471 std::string name;
472 int index; // 0 is reserved for null stream.
473 const ContextInfo* ctx_info;
474 };
475
476 // Include context in key to distinguish null streams.
477 using StreamKey = std::pair<CUcontext, CUstream>;
478
CudaEventCollector(CudaEventRecorder * recorder,StepStatsCollector * collector)479 CudaEventCollector(CudaEventRecorder* recorder, StepStatsCollector* collector)
480 : recorder_(recorder), collector_(collector) {
481 DCHECK(recorder != nullptr);
482 DCHECK(collector != nullptr);
483 }
484
485 // Populates device_infos_ from all devices.
InitializeDeviceInfos()486 Status InitializeDeviceInfos() {
487 int count;
488 TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetCount(&count)));
489 for (int ordinal = 0; ordinal < count; ++ordinal) {
490 CUdevice device;
491 TF_RETURN_IF_ERROR(ToStatus(cuDeviceGet(&device, ordinal)));
492 char name[100];
493 TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetName(name, sizeof(name), device)));
494 device_infos_[device] = {ordinal, name};
495 }
496 return Status::OK();
497 }
498
499 // Returns element from context_infos_, adding it if not yet present.
GetContextInfo(CUcontext context,ContextInfo ** ctx_info_ptr)500 Status GetContextInfo(CUcontext context, ContextInfo** ctx_info_ptr) {
501 auto it = context_infos_.find(context);
502
503 if (it == context_infos_.end()) {
504 TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context)));
505 CUdevice device;
506 TF_RETURN_IF_ERROR(ToStatus(cuCtxGetDevice(&device)));
507
508 auto& dev_info = device_infos_[device];
509 ContextInfo ctx_info = {dev_info.num_contexts++, &dev_info};
510 it = context_infos_.emplace(context, ctx_info).first;
511 }
512
513 *ctx_info_ptr = &it->second;
514 return Status::OK();
515 }
516
517 // Adds element to stream_infos_ if not yet present. If present, clear name
518 // if it doesn't match parameter.
AddStreamInfo(CUcontext context,CUstream stream,absl::string_view name)519 Status AddStreamInfo(CUcontext context, CUstream stream,
520 absl::string_view name) {
521 StreamKey key(context, stream);
522 auto it = stream_infos_.find(key);
523 if (it != stream_infos_.end()) {
524 if (it->second.name != name) {
525 it->second.name.clear(); // Stream with inconsistent names, clear it.
526 }
527 return Status::OK();
528 }
529
530 ContextInfo* ctx_info;
531 TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info));
532 int index = stream ? ++ctx_info->num_streams : 0;
533 StreamInfo stream_info = {static_cast<std::string>(name), index, ctx_info};
534 stream_infos_.emplace(key, stream_info);
535 return Status::OK();
536 }
537
538 // Returns string describing source and destination memory types.
GetMemcpyName(const MemcpyRecord & record)539 static std::string GetMemcpyName(const MemcpyRecord& record) {
540 auto get_memory_type = [](CUmemorytype mem_type) {
541 switch (mem_type) {
542 case CU_MEMORYTYPE_HOST:
543 return 'H';
544 case CU_MEMORYTYPE_DEVICE:
545 return 'D';
546 case CU_MEMORYTYPE_ARRAY:
547 return 'A';
548 case CU_MEMORYTYPE_UNIFIED:
549 return 'U';
550 default:
551 LOG(ERROR) << "Unknown memory type: " << mem_type;
552 return '?';
553 }
554 };
555 return absl::StrFormat("Memcpy%cto%c", get_memory_type(record.src_type),
556 get_memory_type(record.dst_type));
557 }
558
559 // Returns time in microseconds between events recorded on the GPU.
GetElasedTimeUs(CUevent start,CUevent stop)560 static uint64_t GetElasedTimeUs(CUevent start, CUevent stop) {
561 float elapsed_ms = 0.0f;
562 LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop)));
563 return static_cast<uint64>(
564 std::llroundf(1000 * std::max(elapsed_ms, 0.0f)));
565 }
566
567 // Synchronizes all contexts.
Synchronize() const568 Status Synchronize() const {
569 for (const auto& pair : context_infos_) {
570 TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first)));
571 TF_RETURN_IF_ERROR(ToStatus(cuCtxSynchronize()));
572 }
573 return Status::OK();
574 }
575
576 // Save stats to collector;
SaveStats(std::unique_ptr<NodeExecStats> stats,const StreamInfo & stream_info) const577 Status SaveStats(std::unique_ptr<NodeExecStats> stats,
578 const StreamInfo& stream_info) const {
579 auto ctx_info = stream_info.ctx_info;
580 auto dev_info = ctx_info->dev_info;
581 // TODO(csigg): tfprof_node.cc, run_metadata_test.py, and timeline_test.py
582 // currently require this particular formatting.
583 collector_->Save(
584 absl::StrFormat("/device:GPU:%d/stream:all", dev_info->ordinal),
585 new NodeExecStats(*stats));
586 auto name = absl::StrFormat("/gpu:%d (%s)/context#%d/", dev_info->ordinal,
587 dev_info->name, ctx_info->index);
588 if (stream_info.index) {
589 absl::StrAppend(&name, "stream#", std::to_string(stream_info.index));
590 } else {
591 absl::StrAppend(&name, "null stream");
592 }
593 if (!stream_info.name.empty()) {
594 absl::StrAppend(&name, ":", stream_info.name);
595 }
596 collector_->Save(name, stats.release());
597 return Status::OK();
598 }
599
SaveRecord(const KernelRecord & record) const600 Status SaveRecord(const KernelRecord& record) const {
601 if (!record.start_event || !record.stop_event) {
602 return Status::OK();
603 }
604 const auto& stream_info =
605 stream_infos_.at(StreamKey(record.context, record.stream));
606 auto start_us =
607 GetElasedTimeUs(record.start_event, stream_info.ctx_info->end_event);
608 auto elapsed_us = GetElasedTimeUs(record.start_event, record.stop_event);
609
610 auto stats = absl::make_unique<NodeExecStats>();
611 std::string node_name = record.kernel_name;
612 if (record.annotation) {
613 node_name = absl::StrCat(*record.annotation, "::", node_name);
614 }
615 stats->set_node_name(node_name);
616 // TODO(csigg): Report grid size?
617 std::string node_label;
618 stats->set_timeline_label(node_label);
619 stats->set_all_start_micros(end_walltime_us_ - start_us);
620 stats->set_op_end_rel_micros(elapsed_us);
621 stats->set_all_end_rel_micros(elapsed_us);
622 return SaveStats(std::move(stats), stream_info);
623 }
624
SaveRecord(const MemcpyRecord & record) const625 Status SaveRecord(const MemcpyRecord& record) const {
626 if (!record.start_event || !record.stop_event) {
627 return Status::OK();
628 }
629 const auto& stream_info =
630 stream_infos_.at(StreamKey(record.context, record.stream));
631 auto start_us =
632 GetElasedTimeUs(record.start_event, stream_info.ctx_info->end_event);
633 auto elapsed_us = GetElasedTimeUs(record.start_event, record.stop_event);
634
635 auto stats = absl::make_unique<NodeExecStats>();
636 std::string node_name = GetMemcpyName(record);
637 if (record.annotation) {
638 node_name = absl::StrCat(*record.annotation, "::", node_name);
639 }
640 stats->set_node_name(node_name);
641 // TODO(csigg): Show label in Chrome trace viewer.
642 std::string node_label = absl::StrFormat("%d bytes", record.size_bytes);
643 stats->set_timeline_label(node_label);
644 stats->set_all_start_micros(end_walltime_us_ - start_us);
645 stats->set_op_end_rel_micros(elapsed_us);
646 stats->set_all_end_rel_micros(elapsed_us);
647 return SaveStats(std::move(stats), stream_info);
648 }
649
Collect()650 Status Collect() {
651 TF_RETURN_IF_ERROR(InitializeDeviceInfos());
652
653 auto kernel_records = recorder_->ConsumeKernelRecords();
654 auto memcpy_records = recorder_->ConsumeMemcpyRecords();
655 LOG(INFO) << "Collecting " << kernel_records.size() << " kernel records, "
656 << memcpy_records.size() << " memcpy records.";
657
658 // Gather all profiled streams and contexts.
659 for (const auto& record : kernel_records) {
660 TF_RETURN_IF_ERROR(
661 AddStreamInfo(record.context, record.stream, "Kernel"));
662 }
663 for (const auto& record : memcpy_records) {
664 TF_RETURN_IF_ERROR(
665 AddStreamInfo(record.context, record.stream, GetMemcpyName(record)));
666 }
667
668 // Synchronize all contexts, record end events, synchronize again.
669 TF_RETURN_IF_ERROR(Synchronize());
670 for (auto& pair : context_infos_) {
671 TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first)));
672 TF_RETURN_IF_ERROR(CreateAndRecordEvent(&pair.second.end_event, nullptr));
673 }
674 TF_RETURN_IF_ERROR(Synchronize());
675 end_walltime_us_ = Env::Default()->NowMicros();
676
677 for (const auto& record : kernel_records) {
678 TF_RETURN_IF_ERROR(SaveRecord(record));
679 }
680 for (const auto& record : memcpy_records) {
681 TF_RETURN_IF_ERROR(SaveRecord(record));
682 }
683
684 return Status::OK();
685 }
686
687 public:
688 // Consumes the records in recorder and saves them to the collector.
Collect(CudaEventRecorder * recorder,StepStatsCollector * collector)689 static Status Collect(CudaEventRecorder* recorder,
690 StepStatsCollector* collector) {
691 CUcontext context;
692 TF_RETURN_IF_ERROR(ToStatus(cuCtxGetCurrent(&context)));
693 auto status = CudaEventCollector(recorder, collector).Collect();
694 TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context)));
695 return status;
696 }
697
698 private:
699 CudaEventRecorder* recorder_;
700 StepStatsCollector* collector_;
701
702 absl::node_hash_map<CUdevice, DeviceInfo> device_infos_;
703 absl::node_hash_map<CUcontext, ContextInfo> context_infos_;
704 absl::flat_hash_map<StreamKey, StreamInfo, hash<StreamKey>> stream_infos_;
705 int64 end_walltime_us_;
706 };
707 } // namespace
708
Collect(StepStatsCollector * collector)709 Status DeviceTracerImpl::Collect(StepStatsCollector* collector) {
710 mutex_lock l(mu_);
711 if (enabled_) {
712 return errors::FailedPrecondition("DeviceTracer is still enabled.");
713 }
714
715 TF_RETURN_IF_ERROR(CudaEventCollector::Collect(recorder_.get(), collector));
716 host_tracer_->CollectDataToCollector(collector).IgnoreError();
717 return Status::OK();
718 }
719
CreateDeviceTracer()720 std::unique_ptr<DeviceTracer> CreateDeviceTracer() {
721 auto status = cuInit(0);
722 if (status != CUDA_SUCCESS) {
723 LogIfError(ToStatus(status));
724 return nullptr;
725 }
726 return absl::make_unique<DeviceTracerImpl>();
727 }
728 } // namespace tensorflow
729 #else // GOOGLE_CUDA
730
731 namespace tensorflow {
732
CreateDeviceTracer()733 std::unique_ptr<DeviceTracer> CreateDeviceTracer() { return nullptr; }
734
735 } // namespace tensorflow
736
737 #endif // GOOGLE_CUDA
738