/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
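//
// A minimal usage sketch (hypothetical caller code; exactly how an executor
// is obtained depends on the platform registry):
//
//   StreamExecutor *exec = ...;  // e.g. from a Platform, by device ordinal
//   void *device_ptr = exec->Allocate(/*size=*/1024);
//   DeviceMemoryBase device_mem(device_ptr, 1024);
//   exec->SynchronousMemcpyH2D(host_src, 1024, &device_mem);
//   exec->Deallocate(&device_mem);
//
// Every call above simply forwards to the platform-specific
// StreamExecutorInterface held in implementation_.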

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace perftools {
namespace gputools {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return port::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking a
  // nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

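// RAII tracing helper: on construction it assigns a correlation id and fires
// the Begin callback on every registered TraceListener; on destruction it
// fires the matching Complete callback with the call's result. See the
// SCOPED_TRACE macro below for how the Begin/Complete pairs are derived.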
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...)                                      \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin,               \
                                 &LOC ## Complete, ## __VA_ARGS__);
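
// For example, SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream)
// creates a tracer that calls TraceListener::BlockHostUntilDoneBegin(id, stream)
// immediately and TraceListener::BlockHostUntilDoneComplete(id, &result) when
// the enclosing scope exits.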

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  }
}
StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory still allocated at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock{mu_};
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status{port::error::INVALID_ARGUMENT, error_msg};
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock{mu_};
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
      data_type, dropout, seed, state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
                                                        data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock{mu_};
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock{mu_};
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock{mu_};
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock{mu_};
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

void *StreamExecutor::Allocate(uint64 size) {
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, mem, bytes);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status{port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str())};
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status{
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str())};
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

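// Stream allocation/deallocation maintains live_stream_count_ so that the
// destructor can warn when streams outlive the executor; the count is bumped
// optimistically and rolled back if the platform-specific allocation fails.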
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

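// Alloc records implement the FLAGS_check_device_leaks bookkeeping: Allocate()
// records each live device pointer here and Deallocate() erases it, so any
// entries remaining at destruction time are reported as leaks.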
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock{mu_};
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock{mu_};
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("0x%p", opaque);
    } else {
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{mu_};
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace gputools
}  // namespace perftools