• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/device_tracer.h"
17 
18 #if GOOGLE_CUDA
19 
20 #include <stdlib.h>
21 #include <memory>
22 
23 #include "absl/base/casts.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "tensorflow/core/common_runtime/step_stats_collector.h"
29 #include "tensorflow/core/framework/step_stats.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/cupti_wrapper.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/mem.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/tracing.h"
40 #include "tensorflow/core/profiler/internal/cpu/host_tracer.h"
41 #include "tensorflow/core/profiler/lib/traceme.h"
42 
43 namespace tensorflow {
44 namespace {
ToStatus(CUptiResult result)45 Status ToStatus(CUptiResult result) {
46   if (result == CUPTI_SUCCESS) {
47     return Status::OK();
48   }
49   const char* str = nullptr;
50   if (auto wrapper =
51           absl::make_unique<perftools::gputools::profiler::CuptiWrapper>()) {
52     wrapper->GetResultString(result, &str);
53   }
54   return errors::Unavailable("CUPTI error: ", str ? str : "<unknown>");
55 }
56 
ToStatus(CUresult result)57 Status ToStatus(CUresult result) {
58   if (result == CUDA_SUCCESS) {
59     return Status::OK();
60   }
61   const char* str = nullptr;
62   cuGetErrorName(result, &str);
63   return errors::Unavailable("CUDA error: ", str ? str : "<unknown>");
64 }
65 
LogIfError(const Status & status)66 void LogIfError(const Status& status) {
67   if (status.ok()) {
68     return;
69   }
70   LOG(ERROR) << status.error_message();
71 }
72 
// One kernel launch observed through the CUPTI cuLaunchKernel callback.
// Created by CudaEventRecorder and consumed by CudaEventCollector.
// Initialized positionally ({kernel_name, context, stream}), so the field order
// matters; the remaining fields start out null.
struct KernelRecord {
  // Kernel symbol name as reported by CUPTI; only the pointer is stored.
  // NOTE(review): assumes the string stays valid until collection — confirm
  // against the CUPTI documentation for CUpti_CallbackData::symbolName.
  const char* kernel_name;
  // TODO(csigg): cuStreamGetCtx introduced in CUDA 9.2 would allow us to only
  // record the stream and infer the context during collection.
  CUcontext context;
  // Stream the kernel was launched on; a null stream is tracked per context.
  CUstream stream;
  // Events recorded on `stream` before and after the launch call. If either is
  // null, the record is skipped during collection.
  CUevent start_event;
  CUevent stop_event;
  // Interned copy of the launching thread's current annotation, or null.
  const std::string* annotation;
};
83 
// One memory copy observed through a CUPTI cuMemcpy* callback.
// Created by CudaEventRecorder and consumed by CudaEventCollector.
// Initialized positionally ({src_type, dst_type, size_bytes, context, stream}),
// so the field order matters; the remaining fields start out null.
struct MemcpyRecord {
  // Memory types of the copy's source and destination.
  CUmemorytype src_type;
  CUmemorytype dst_type;
  // Number of bytes copied (the API call's ByteCount).
  size_t size_bytes;
  CUcontext context;
  // Stream of an async copy; null for synchronous copies.
  CUstream stream;
  // Events recorded on `stream` before and after the copy call. If either is
  // null, the record is skipped during collection.
  CUevent start_event;
  CUevent stop_event;
  // Interned copy of the calling thread's current annotation, or null.
  const std::string* annotation;
};
94 
CreateAndRecordEvent(CUevent * event,CUstream stream)95 Status CreateAndRecordEvent(CUevent* event, CUstream stream) {
96   TF_RETURN_IF_ERROR(ToStatus(cuEventCreate(event, CU_EVENT_DEFAULT)));
97   return ToStatus(cuEventRecord(*event, stream));
98 }
99 
// Per-thread pointer to the annotation string of the most recently created
// annotation handle on this thread, or null when there is none. The string is
// owned by that handle and is guaranteed to remain live for the duration of
// the CUPTI API callback that reads it.
thread_local const char* tls_current_annotation = nullptr;
105 
106 // Stores a series of kernel and memcpy records.
107 class CudaEventRecorder {
108  public:
109   // Registers the start of a kernel launch. The returned index should be passed
110   // to StopKernel() after the kernel launch has completed.
StartKernel(const char * kernel_name,CUcontext context,CUstream stream)111   size_t StartKernel(const char* kernel_name, CUcontext context,
112                      CUstream stream) {
113     KernelRecord record = {kernel_name, context, stream};
114     LogIfError(CreateAndRecordEvent(&record.start_event, stream));
115     mutex_lock lock(mutex_);
116     if (tls_current_annotation) {
117       record.annotation = &*annotations_.emplace(tls_current_annotation).first;
118     }
119     kernel_records_.push_back(record);
120     return kernel_records_.size() - 1;
121   }
StopKernel(size_t index)122   void StopKernel(size_t index) {
123     mutex_lock lock(mutex_);
124     auto& record = kernel_records_[index];
125     LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream));
126   }
127 
128   // Registers the start of a copy operation. The returned index should be
129   // passed to StopMemcpy() after the kernel launch has completed.
StartMemcpy(CUmemorytype src_type,CUmemorytype dst_type,size_t size_bytes,CUcontext context,CUstream stream)130   size_t StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type,
131                      size_t size_bytes, CUcontext context, CUstream stream) {
132     MemcpyRecord record = {src_type, dst_type, size_bytes, context, stream};
133     LogIfError(CreateAndRecordEvent(&record.start_event, stream));
134     mutex_lock lock(mutex_);
135     if (tls_current_annotation) {
136       record.annotation = &*annotations_.emplace(tls_current_annotation).first;
137     }
138     memcpy_records_.push_back(record);
139     return memcpy_records_.size() - 1;
140   }
StopMemcpy(size_t index)141   void StopMemcpy(size_t index) {
142     mutex_lock lock(mutex_);
143     auto& record = memcpy_records_[index];
144     LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream));
145   }
146 
ConsumeKernelRecords()147   std::vector<KernelRecord> ConsumeKernelRecords() {
148     mutex_lock lock(mutex_);
149     return std::move(kernel_records_);
150   }
ConsumeMemcpyRecords()151   std::vector<MemcpyRecord> ConsumeMemcpyRecords() {
152     mutex_lock lock(mutex_);
153     return std::move(memcpy_records_);
154   }
155 
156  private:
157   mutex mutex_;
158   std::unordered_set<std::string> annotations_ GUARDED_BY(mutex_);
159   std::vector<KernelRecord> kernel_records_ GUARDED_BY(mutex_);
160   std::vector<MemcpyRecord> memcpy_records_ GUARDED_BY(mutex_);
161 };
162 
163 // Instances register callbacks with CUPTI to notify the event recorder before
164 // and after kernel launches and memory copies.
165 class CuptiCallbackHook {
166  public:
CuptiCallbackHook()167   CuptiCallbackHook()
168       : cupti_wrapper_(new perftools::gputools::profiler::CuptiWrapper()),
169         subscriber_(nullptr) {}
170 
Enable(CudaEventRecorder * recorder)171   Status Enable(CudaEventRecorder* recorder) {
172     TF_RETURN_IF_ERROR(ToStatus(
173         cupti_wrapper_->Subscribe(&subscriber_, &CuptiCallback, recorder)));
174     for (auto cbid : {CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel,
175                       CUPTI_DRIVER_TRACE_CBID_cuMemcpy,
176                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync,
177                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2,
178                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2,
179                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2,
180                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2,
181                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2,
182                       CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2}) {
183       TF_RETURN_IF_ERROR(ToStatus(cupti_wrapper_->EnableCallback(
184           /*enable=*/1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)));
185     }
186     return Status::OK();
187   }
188 
~CuptiCallbackHook()189   ~CuptiCallbackHook() {
190     LogIfError(ToStatus(cupti_wrapper_->Unsubscribe(subscriber_)));
191   }
192 
193  private:
CuptiCallback(void * userdata,CUpti_CallbackDomain domain,CUpti_CallbackId cbid,const void * cbdata)194   static void CUPTIAPI CuptiCallback(void* userdata,
195                                      CUpti_CallbackDomain domain,
196                                      CUpti_CallbackId cbid,
197                                      const void* cbdata) {
198     auto recorder = static_cast<CudaEventRecorder*>(userdata);
199     auto data = static_cast<const CUpti_CallbackData*>(cbdata);
200     DCHECK_EQ(domain, CUPTI_CB_DOMAIN_DRIVER_API);
201 
202     if (data->callbackSite == CUPTI_API_ENTER) {
203       DriverApiEnterCallback(cbid, *data, recorder);
204     } else {
205       DriverApiExitCallback(cbid, *data, recorder);
206     }
207   }
208 
GetMemoryType(CUdeviceptr ptr)209   static CUmemorytype GetMemoryType(CUdeviceptr ptr) {
210     CUmemorytype mem_type;
211     auto status =
212         cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr);
213     if (status == CUDA_ERROR_INVALID_VALUE) {
214       // Pointer not registered with CUDA, must be host memory.
215       return CU_MEMORYTYPE_HOST;
216     }
217     LogIfError(ToStatus(status));
218     return mem_type;
219   }
220 
221   template <typename T>
StartMemcpy(CUmemorytype src_type,CUmemorytype dst_type,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)222   static void StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type,
223                           const CUpti_CallbackData& cbdata,
224                           CudaEventRecorder* recorder) {
225     auto params = static_cast<const T*>(cbdata.functionParams);
226     *cbdata.correlationData = recorder->StartMemcpy(
227         src_type, dst_type, params->ByteCount, cbdata.context, nullptr);
228   }
229   template <typename T>
StartMemcpyAsync(CUmemorytype dst_type,CUmemorytype src_type,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)230   static void StartMemcpyAsync(CUmemorytype dst_type, CUmemorytype src_type,
231                                const CUpti_CallbackData& cbdata,
232                                CudaEventRecorder* recorder) {
233     auto params = static_cast<const T*>(cbdata.functionParams);
234     *cbdata.correlationData = recorder->StartMemcpy(
235         src_type, dst_type, params->ByteCount, cbdata.context, params->hStream);
236   }
237 
DriverApiEnterCallback(CUpti_CallbackId cbid,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)238   static void DriverApiEnterCallback(CUpti_CallbackId cbid,
239                                      const CUpti_CallbackData& cbdata,
240                                      CudaEventRecorder* recorder) {
241     switch (cbid) {
242       case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: {
243         DCHECK_NE(cbdata.symbolName, nullptr);
244         auto params =
245             static_cast<const cuLaunchKernel_params*>(cbdata.functionParams);
246         *cbdata.correlationData = recorder->StartKernel(
247             cbdata.symbolName, cbdata.context, params->hStream);
248         return;
249       }
250 
251       case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: {
252         auto params =
253             static_cast<const cuMemcpy_params*>(cbdata.functionParams);
254         return StartMemcpy<cuMemcpy_params>(GetMemoryType(params->src),
255                                             GetMemoryType(params->dst), cbdata,
256                                             recorder);
257       }
258       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: {
259         auto params =
260             static_cast<const cuMemcpyAsync_params*>(cbdata.functionParams);
261         return StartMemcpyAsync<cuMemcpyAsync_params>(
262             GetMemoryType(params->src), GetMemoryType(params->dst), cbdata,
263             recorder);
264       }
265 
266       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2:
267         return StartMemcpy<cuMemcpyHtoD_v2_params>(
268             CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
269 
270       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2:
271         return StartMemcpyAsync<cuMemcpyHtoDAsync_v2_params>(
272             CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
273 
274       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2:
275         return StartMemcpy<cuMemcpyDtoH_v2_params>(
276             CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder);
277       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2:
278         return StartMemcpyAsync<cuMemcpyDtoHAsync_v2_params>(
279             CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder);
280 
281       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2:
282         return StartMemcpy<cuMemcpyDtoD_v2_params>(
283             CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
284       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2:
285         return StartMemcpyAsync<cuMemcpyDtoDAsync_v2_params>(
286             CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder);
287 
288       default:
289         LOG(ERROR) << "Unexpected callback id: " << cbid;
290     }
291   }
292 
DriverApiExitCallback(CUpti_CallbackId cbid,const CUpti_CallbackData & cbdata,CudaEventRecorder * recorder)293   static void DriverApiExitCallback(CUpti_CallbackId cbid,
294                                     const CUpti_CallbackData& cbdata,
295                                     CudaEventRecorder* recorder) {
296     switch (cbid) {
297       case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel:
298         recorder->StopKernel(*cbdata.correlationData);
299         break;
300       case CUPTI_DRIVER_TRACE_CBID_cuMemcpy:
301       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync:
302       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2:
303       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2:
304       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2:
305       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2:
306       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2:
307       case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2:
308         recorder->StopMemcpy(*cbdata.correlationData);
309         break;
310       default:
311         LOG(ERROR) << "Unexpected callback id: " << cbid;
312     }
313   }
314 
315   std::unique_ptr<perftools::gputools::profiler::CuptiWrapper> cupti_wrapper_;
316   CUpti_SubscriberHandle subscriber_;
317 };
318 }  // namespace
319 
320 class TraceCollectorImpl : public tracing::TraceCollector {
321  public:
322   class ActivityHandle : public Handle {
323    public:
ActivityHandle(std::string && name,int level)324     ActivityHandle(std::string&& name, int level)
325         : trace_me_(std::move(name), level) {}
326 
327    private:
328     profiler::TraceMe trace_me_;
329   };
TraceCollectorImpl()330   TraceCollectorImpl() { tracing::SetTraceCollector(this); }
331 
~TraceCollectorImpl()332   ~TraceCollectorImpl() override {
333     DCHECK(!active_trace_session_)
334         << "Unexpected active trace session detected. ";
335   }
336 
337   // Note the method can be called after a call to Stop().
CreateAnnotationHandle(StringPiece name_part1,StringPiece name_part2) const338   virtual std::unique_ptr<Handle> CreateAnnotationHandle(
339       StringPiece name_part1, StringPiece name_part2) const {
340     struct Impl : public tracing::TraceCollector::Handle {
341       std::string annotation;
342       explicit Impl(std::string&& name_scope) : annotation(name_scope) {
343         VLOG(2) << "CreateAnnotationHandle " << annotation;
344         // Remember the most recent ScopedAnnotation for each thread.
345         tls_current_annotation = annotation.c_str();
346       }
347       ~Impl() override { tls_current_annotation = nullptr; }
348     };
349     return absl::make_unique<Impl>(ConcatenateNames(name_part1, name_part2));
350   }
351 
CreateActivityHandle(StringPiece name_part1,StringPiece name_part2,bool is_expensive) const352   virtual std::unique_ptr<Handle> CreateActivityHandle(
353       StringPiece name_part1, StringPiece name_part2, bool is_expensive) const {
354     if (!IsEnabledForActivities(is_expensive)) {
355       return nullptr;
356     }
357     return absl::make_unique<ActivityHandle>(
358         ConcatenateNames(name_part1, name_part2), GetLevel(is_expensive));
359   }
360 
IsEnabledForAnnotations() const361   bool IsEnabledForAnnotations() const override {
362     return active_trace_session_.load(std::memory_order_relaxed);
363   }
364 
IsEnabledForActivities(bool is_expensive) const365   bool IsEnabledForActivities(bool is_expensive) const override {
366     return profiler::TraceMeRecorder::Active(GetLevel(is_expensive));
367   }
368 
Start()369   void Start() {
370     DCHECK(!active_trace_session_)
371         << "Unexpected active trace session detected. ";
372     active_trace_session_ = true;
373   }
374 
Stop()375   void Stop() {
376     DCHECK(active_trace_session_) << "No active trace session detected. ";
377     active_trace_session_ = false;
378   }
379 
380  private:
GetLevel(bool is_expensive)381   static int GetLevel(bool is_expensive) {
382     return profiler::GetTFTraceMeLevel(is_expensive);
383   }
384 
385   std::atomic<bool> active_trace_session_;
386 };
387 
GlobalDefaultTraceCollector()388 TraceCollectorImpl* GlobalDefaultTraceCollector() {
389   static auto* instance = new TraceCollectorImpl();
390   return instance;
391 }
392 
393 class DeviceTracerImpl : public DeviceTracer {
394  public:
395   DeviceTracerImpl();
396   ~DeviceTracerImpl() override;
397 
398   // DeviceTracer interface:
399   Status Start() override;
400   Status Stop() override;
401   Status Collect(StepStatsCollector* collector) override;
402 
403  private:
404   std::unique_ptr<CudaEventRecorder> recorder_;
405   std::unique_ptr<CuptiCallbackHook> cupti_hook_;
406 
407   mutex mu_;
408   bool enabled_ GUARDED_BY(mu_);
409   std::unique_ptr<profiler::cpu::HostTracer> host_tracer_ GUARDED_BY(mu_);
410 };
411 
DeviceTracerImpl()412 DeviceTracerImpl::DeviceTracerImpl() : recorder_(new CudaEventRecorder()) {
413   VLOG(1) << "DeviceTracer created.";
414   host_tracer_ = profiler::cpu::HostTracer::Create(2);
415   enabled_ = false;
416 }
417 
~DeviceTracerImpl()418 DeviceTracerImpl::~DeviceTracerImpl() {
419   // Unregister the CUPTI callbacks if needed to prevent them from accessing
420   // freed memory.
421   Stop().IgnoreError();
422 }
423 
Start()424 Status DeviceTracerImpl::Start() {
425   VLOG(1) << "DeviceTracer::Start";
426   mutex_lock l(mu_);
427   if (enabled_) {
428     return errors::FailedPrecondition("DeviceTracer is already enabled.");
429   }
430   cupti_hook_.reset(new CuptiCallbackHook());
431   TF_RETURN_IF_ERROR(cupti_hook_->Enable(recorder_.get()));
432 
433   // Register as a TraceEngine to receive ScopedAnnotations.
434   GlobalDefaultTraceCollector()->Start();
435 
436   host_tracer_->Start().IgnoreError();
437   enabled_ = true;
438   return Status::OK();
439 }
440 
Stop()441 Status DeviceTracerImpl::Stop() {
442   VLOG(1) << "DeviceTracer::Stop";
443   mutex_lock l(mu_);
444   if (!enabled_) {
445     return Status::OK();
446   }
447   cupti_hook_.reset();
448   GlobalDefaultTraceCollector()->Stop();
449 
450   enabled_ = false;
451   host_tracer_->Stop().IgnoreError();
452   return Status::OK();
453 }
454 
455 namespace {
456 class CudaEventCollector {
457   struct DeviceInfo {
458     int ordinal;
459     std::string name;
460     int num_contexts;
461   };
462 
463   struct ContextInfo {
464     int index;
465     const DeviceInfo* dev_info;
466     int num_streams;
467     CUevent end_event;
468   };
469 
470   struct StreamInfo {
471     std::string name;
472     int index;  // 0 is reserved for null stream.
473     const ContextInfo* ctx_info;
474   };
475 
476   // Include context in key to distinguish null streams.
477   using StreamKey = std::pair<CUcontext, CUstream>;
478 
CudaEventCollector(CudaEventRecorder * recorder,StepStatsCollector * collector)479   CudaEventCollector(CudaEventRecorder* recorder, StepStatsCollector* collector)
480       : recorder_(recorder), collector_(collector) {
481     DCHECK(recorder != nullptr);
482     DCHECK(collector != nullptr);
483   }
484 
485   // Populates device_infos_ from all devices.
InitializeDeviceInfos()486   Status InitializeDeviceInfos() {
487     int count;
488     TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetCount(&count)));
489     for (int ordinal = 0; ordinal < count; ++ordinal) {
490       CUdevice device;
491       TF_RETURN_IF_ERROR(ToStatus(cuDeviceGet(&device, ordinal)));
492       char name[100];
493       TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetName(name, sizeof(name), device)));
494       device_infos_[device] = {ordinal, name};
495     }
496     return Status::OK();
497   }
498 
499   // Returns element from context_infos_, adding it if not yet present.
GetContextInfo(CUcontext context,ContextInfo ** ctx_info_ptr)500   Status GetContextInfo(CUcontext context, ContextInfo** ctx_info_ptr) {
501     auto it = context_infos_.find(context);
502 
503     if (it == context_infos_.end()) {
504       TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context)));
505       CUdevice device;
506       TF_RETURN_IF_ERROR(ToStatus(cuCtxGetDevice(&device)));
507 
508       auto& dev_info = device_infos_[device];
509       ContextInfo ctx_info = {dev_info.num_contexts++, &dev_info};
510       it = context_infos_.emplace(context, ctx_info).first;
511     }
512 
513     *ctx_info_ptr = &it->second;
514     return Status::OK();
515   }
516 
517   // Adds element to stream_infos_ if not yet present. If present, clear name
518   // if it doesn't match parameter.
AddStreamInfo(CUcontext context,CUstream stream,absl::string_view name)519   Status AddStreamInfo(CUcontext context, CUstream stream,
520                        absl::string_view name) {
521     StreamKey key(context, stream);
522     auto it = stream_infos_.find(key);
523     if (it != stream_infos_.end()) {
524       if (it->second.name != name) {
525         it->second.name.clear();  // Stream with inconsistent names, clear it.
526       }
527       return Status::OK();
528     }
529 
530     ContextInfo* ctx_info;
531     TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info));
532     int index = stream ? ++ctx_info->num_streams : 0;
533     StreamInfo stream_info = {static_cast<std::string>(name), index, ctx_info};
534     stream_infos_.emplace(key, stream_info);
535     return Status::OK();
536   }
537 
538   // Returns string describing source and destination memory types.
GetMemcpyName(const MemcpyRecord & record)539   static std::string GetMemcpyName(const MemcpyRecord& record) {
540     auto get_memory_type = [](CUmemorytype mem_type) {
541       switch (mem_type) {
542         case CU_MEMORYTYPE_HOST:
543           return 'H';
544         case CU_MEMORYTYPE_DEVICE:
545           return 'D';
546         case CU_MEMORYTYPE_ARRAY:
547           return 'A';
548         case CU_MEMORYTYPE_UNIFIED:
549           return 'U';
550         default:
551           LOG(ERROR) << "Unknown memory type: " << mem_type;
552           return '?';
553       }
554     };
555     return absl::StrFormat("Memcpy%cto%c", get_memory_type(record.src_type),
556                            get_memory_type(record.dst_type));
557   }
558 
559   // Returns time in microseconds between events recorded on the GPU.
GetElasedTimeUs(CUevent start,CUevent stop)560   static uint64_t GetElasedTimeUs(CUevent start, CUevent stop) {
561     float elapsed_ms = 0.0f;
562     LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop)));
563     return static_cast<uint64>(
564         std::llroundf(1000 * std::max(elapsed_ms, 0.0f)));
565   }
566 
567   // Synchronizes all contexts.
Synchronize() const568   Status Synchronize() const {
569     for (const auto& pair : context_infos_) {
570       TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first)));
571       TF_RETURN_IF_ERROR(ToStatus(cuCtxSynchronize()));
572     }
573     return Status::OK();
574   }
575 
576   // Save stats to collector;
SaveStats(std::unique_ptr<NodeExecStats> stats,const StreamInfo & stream_info) const577   Status SaveStats(std::unique_ptr<NodeExecStats> stats,
578                    const StreamInfo& stream_info) const {
579     auto ctx_info = stream_info.ctx_info;
580     auto dev_info = ctx_info->dev_info;
581     // TODO(csigg): tfprof_node.cc, run_metadata_test.py, and timeline_test.py
582     // currently require this particular formatting.
583     collector_->Save(
584         absl::StrFormat("/device:GPU:%d/stream:all", dev_info->ordinal),
585         new NodeExecStats(*stats));
586     auto name = absl::StrFormat("/gpu:%d (%s)/context#%d/", dev_info->ordinal,
587                                 dev_info->name, ctx_info->index);
588     if (stream_info.index) {
589       absl::StrAppend(&name, "stream#", std::to_string(stream_info.index));
590     } else {
591       absl::StrAppend(&name, "null stream");
592     }
593     if (!stream_info.name.empty()) {
594       absl::StrAppend(&name, ":", stream_info.name);
595     }
596     collector_->Save(name, stats.release());
597     return Status::OK();
598   }
599 
SaveRecord(const KernelRecord & record) const600   Status SaveRecord(const KernelRecord& record) const {
601     if (!record.start_event || !record.stop_event) {
602       return Status::OK();
603     }
604     const auto& stream_info =
605         stream_infos_.at(StreamKey(record.context, record.stream));
606     auto start_us =
607         GetElasedTimeUs(record.start_event, stream_info.ctx_info->end_event);
608     auto elapsed_us = GetElasedTimeUs(record.start_event, record.stop_event);
609 
610     auto stats = absl::make_unique<NodeExecStats>();
611     std::string node_name = record.kernel_name;
612     if (record.annotation) {
613       node_name = absl::StrCat(*record.annotation, "::", node_name);
614     }
615     stats->set_node_name(node_name);
616     // TODO(csigg): Report grid size?
617     std::string node_label;
618     stats->set_timeline_label(node_label);
619     stats->set_all_start_micros(end_walltime_us_ - start_us);
620     stats->set_op_end_rel_micros(elapsed_us);
621     stats->set_all_end_rel_micros(elapsed_us);
622     return SaveStats(std::move(stats), stream_info);
623   }
624 
SaveRecord(const MemcpyRecord & record) const625   Status SaveRecord(const MemcpyRecord& record) const {
626     if (!record.start_event || !record.stop_event) {
627       return Status::OK();
628     }
629     const auto& stream_info =
630         stream_infos_.at(StreamKey(record.context, record.stream));
631     auto start_us =
632         GetElasedTimeUs(record.start_event, stream_info.ctx_info->end_event);
633     auto elapsed_us = GetElasedTimeUs(record.start_event, record.stop_event);
634 
635     auto stats = absl::make_unique<NodeExecStats>();
636     std::string node_name = GetMemcpyName(record);
637     if (record.annotation) {
638       node_name = absl::StrCat(*record.annotation, "::", node_name);
639     }
640     stats->set_node_name(node_name);
641     // TODO(csigg): Show label in Chrome trace viewer.
642     std::string node_label = absl::StrFormat("%d bytes", record.size_bytes);
643     stats->set_timeline_label(node_label);
644     stats->set_all_start_micros(end_walltime_us_ - start_us);
645     stats->set_op_end_rel_micros(elapsed_us);
646     stats->set_all_end_rel_micros(elapsed_us);
647     return SaveStats(std::move(stats), stream_info);
648   }
649 
Collect()650   Status Collect() {
651     TF_RETURN_IF_ERROR(InitializeDeviceInfos());
652 
653     auto kernel_records = recorder_->ConsumeKernelRecords();
654     auto memcpy_records = recorder_->ConsumeMemcpyRecords();
655     LOG(INFO) << "Collecting " << kernel_records.size() << " kernel records, "
656               << memcpy_records.size() << " memcpy records.";
657 
658     // Gather all profiled streams and contexts.
659     for (const auto& record : kernel_records) {
660       TF_RETURN_IF_ERROR(
661           AddStreamInfo(record.context, record.stream, "Kernel"));
662     }
663     for (const auto& record : memcpy_records) {
664       TF_RETURN_IF_ERROR(
665           AddStreamInfo(record.context, record.stream, GetMemcpyName(record)));
666     }
667 
668     // Synchronize all contexts, record end events, synchronize again.
669     TF_RETURN_IF_ERROR(Synchronize());
670     for (auto& pair : context_infos_) {
671       TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first)));
672       TF_RETURN_IF_ERROR(CreateAndRecordEvent(&pair.second.end_event, nullptr));
673     }
674     TF_RETURN_IF_ERROR(Synchronize());
675     end_walltime_us_ = Env::Default()->NowMicros();
676 
677     for (const auto& record : kernel_records) {
678       TF_RETURN_IF_ERROR(SaveRecord(record));
679     }
680     for (const auto& record : memcpy_records) {
681       TF_RETURN_IF_ERROR(SaveRecord(record));
682     }
683 
684     return Status::OK();
685   }
686 
687  public:
688   // Consumes the records in recorder and saves them to the collector.
Collect(CudaEventRecorder * recorder,StepStatsCollector * collector)689   static Status Collect(CudaEventRecorder* recorder,
690                         StepStatsCollector* collector) {
691     CUcontext context;
692     TF_RETURN_IF_ERROR(ToStatus(cuCtxGetCurrent(&context)));
693     auto status = CudaEventCollector(recorder, collector).Collect();
694     TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context)));
695     return status;
696   }
697 
698  private:
699   CudaEventRecorder* recorder_;
700   StepStatsCollector* collector_;
701 
702   absl::node_hash_map<CUdevice, DeviceInfo> device_infos_;
703   absl::node_hash_map<CUcontext, ContextInfo> context_infos_;
704   absl::flat_hash_map<StreamKey, StreamInfo, hash<StreamKey>> stream_infos_;
705   int64 end_walltime_us_;
706 };
707 }  // namespace
708 
Collect(StepStatsCollector * collector)709 Status DeviceTracerImpl::Collect(StepStatsCollector* collector) {
710   mutex_lock l(mu_);
711   if (enabled_) {
712     return errors::FailedPrecondition("DeviceTracer is still enabled.");
713   }
714 
715   TF_RETURN_IF_ERROR(CudaEventCollector::Collect(recorder_.get(), collector));
716   host_tracer_->CollectDataToCollector(collector).IgnoreError();
717   return Status::OK();
718 }
719 
CreateDeviceTracer()720 std::unique_ptr<DeviceTracer> CreateDeviceTracer() {
721   auto status = cuInit(0);
722   if (status != CUDA_SUCCESS) {
723     LogIfError(ToStatus(status));
724     return nullptr;
725   }
726   return absl::make_unique<DeviceTracerImpl>();
727 }
728 }  // namespace tensorflow
729 #else  // GOOGLE_CUDA
730 
731 namespace tensorflow {
732 
CreateDeviceTracer()733 std::unique_ptr<DeviceTracer> CreateDeviceTracer() { return nullptr; }
734 
735 }  // namespace tensorflow
736 
737 #endif  // GOOGLE_CUDA
738