• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements the StreamExecutor interface by passing through to its
17 // implementation_ value (in pointer-to-implementation style), which
18 // implements StreamExecutorInterface.
19 
20 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
21 
22 #include <atomic>
23 #include <utility>
24 
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/stream_executor/blas.h"
28 #include "tensorflow/stream_executor/fft.h"
29 #include "tensorflow/stream_executor/lib/env.h"
30 #include "tensorflow/stream_executor/lib/error.h"
31 #include "tensorflow/stream_executor/lib/notification.h"
32 #include "tensorflow/stream_executor/lib/stacktrace.h"
33 #include "tensorflow/stream_executor/lib/str_util.h"
34 #include "tensorflow/stream_executor/lib/stringprintf.h"
35 #include "tensorflow/stream_executor/lib/threadpool.h"
36 #include "tensorflow/stream_executor/platform/port.h"
37 #include "tensorflow/stream_executor/rng.h"
38 #include "tensorflow/stream_executor/stream_executor_internal.h"
39 
namespace {
// Debug switch. When true, ~StreamExecutor() logs every allocation still
// present in mem_allocs_ at destruction time. Off by default.
bool FLAGS_check_device_leaks{false};
}  // namespace
43 
44 namespace stream_executor {
45 namespace {
46 
StackTraceIfVLOG10()47 string StackTraceIfVLOG10() {
48   if (VLOG_IS_ON(10)) {
49     return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
50   } else {
51     return "";
52   }
53 }
54 
55 // Make sure the executor is done with its work; we know (because this isn't
56 // publicly visible) that all enqueued work is quick.
BlockOnThreadExecutor(port::ThreadPool * executor)57 void BlockOnThreadExecutor(port::ThreadPool *executor) {
58   port::Notification n;
59   executor->Schedule([&n]() { n.Notify(); });
60   n.WaitForNotification();
61 }
62 
StreamExecutorImplementationFromPlatformKind(PlatformKind platform_kind,const PluginConfig & plugin_config)63 internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
64     PlatformKind platform_kind, const PluginConfig &plugin_config) {
65   // Note: we use this factory-assignment-in-switch pattern instead of just
66   // invoking the callable in case linkage is messed up -- instead of invoking a
67   // nullptr std::function (due to failed registration) we give a nice
68   // LOG(FATAL) message.
69   internal::StreamExecutorFactory factory;
70   switch (platform_kind) {
71     case PlatformKind::kCuda:
72       factory = *internal::MakeCUDAExecutorImplementation();
73       break;
74     case PlatformKind::kROCm:
75       factory = *internal::MakeROCMExecutorImplementation();
76       break;
77     case PlatformKind::kOpenCL:
78       factory = *internal::MakeOpenCLExecutorImplementation();
79       break;
80     case PlatformKind::kHost:
81       factory = internal::MakeHostExecutorImplementation;
82       break;
83     default:
84       factory = nullptr;
85   }
86   if (factory == nullptr) {
87     LOG(FATAL)
88         << "cannot create StreamExecutor implementation for platform kind: "
89         << PlatformKindString(platform_kind);
90   }
91   return factory(plugin_config);
92 }
93 
// Shared source of trace correlation IDs for ScopedTracer. It is only ever
// incremented with memory_order_relaxed, so it provides atomicity, not
// ordering.
std::atomic_int_fast64_t correlation_id_generator{0};
95 
96 }  // namespace
97 
98 template <typename BeginCallT, typename CompleteCallT,
99           typename ReturnT, typename... BeginArgsT>
100 class ScopedTracer {
101  public:
ScopedTracer(StreamExecutor * stream_exec,BeginCallT begin_call,CompleteCallT complete_call,const ReturnT * result,BeginArgsT...begin_args)102   ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
103                CompleteCallT complete_call, const ReturnT *result,
104                BeginArgsT... begin_args)
105       : stream_exec_(stream_exec),
106         complete_call_(complete_call),
107         result_(result) {
108     if (stream_exec_->tracing_enabled_) {
109       correlation_id_ =
110           correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
111       Trace(begin_call, begin_args...);
112     }
113   }
114 
~ScopedTracer()115   ~ScopedTracer() {
116     if (stream_exec_->tracing_enabled_) {
117       Trace(complete_call_, result_);
118     }
119   }
120 
121  private:
122   template <typename CallbackT, typename... TraceArgsT>
Trace(CallbackT callback,TraceArgsT...args)123   void Trace(CallbackT callback, TraceArgsT... args) {
124     {
125       // Instance tracers held in a block to limit the lock lifetime.
126       tf_shared_lock lock{stream_exec_->mu_};
127       for (TraceListener *listener : stream_exec_->listeners_) {
128         (listener->*callback)(correlation_id_,
129                               std::forward<TraceArgsT>(args)...);
130       }
131     }
132   }
133 
134   StreamExecutor *stream_exec_;
135   CompleteCallT complete_call_;
136   const ReturnT* result_;
137   int64 correlation_id_;
138 };
139 
140 template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
141           typename... BeginArgsT>
142 ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor * stream_exec,BeginCallT begin_call,CompleteCallT complete_call,ReturnT * result,BeginArgsT...begin_args)143 MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
144                  CompleteCallT complete_call, ReturnT *result,
145                  BeginArgsT... begin_args) {
146   return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
147       stream_exec, begin_call, complete_call, result,
148       std::forward<BeginArgsT>(begin_args)...);
149 }
150 
// Declares a local `tracer` bound to `this` that reports LOC##Begin now and
// LOC##Complete when the enclosing scope exits. The variadic arguments are
// the result pointer followed by the begin-call arguments. Relies on the GNU
// `, ## __VA_ARGS__` comma-elision extension. Only one use per scope is
// possible, because the variable name is fixed.
#define SCOPED_TRACE(LOC, ...)                         \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin,  \
                                 &LOC ## Complete, ## __VA_ARGS__);
154 
155 /* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};
156 
StreamExecutor(PlatformKind platform_kind,const PluginConfig & plugin_config)157 StreamExecutor::StreamExecutor(PlatformKind platform_kind,
158                                const PluginConfig &plugin_config)
159     : platform_(nullptr),
160       implementation_(StreamExecutorImplementationFromPlatformKind(
161           platform_kind, plugin_config)),
162       platform_kind_(platform_kind),
163       device_ordinal_(-1),
164       background_threads_(new port::ThreadPool(
165           port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
166       live_stream_count_(0),
167       tracing_enabled_(false) {
168   CheckPlatformKindIsValid(platform_kind);
169 }
170 
171 // Get per-device memory limit in bytes. Returns 0 if
172 // TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
GetMemoryLimitBytes()173 static int64 GetMemoryLimitBytes() {
174   int64 value;
175   SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
176                                               0, &value));
177   return value * (1ll << 20);
178 }
179 
StreamExecutor(const Platform * platform,std::unique_ptr<internal::StreamExecutorInterface> implementation)180 StreamExecutor::StreamExecutor(
181     const Platform *platform,
182     std::unique_ptr<internal::StreamExecutorInterface> implementation)
183     : platform_(platform),
184       implementation_(std::move(implementation)),
185       device_ordinal_(-1),
186       background_threads_(new port::ThreadPool(
187           port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
188       live_stream_count_(0),
189       tracing_enabled_(false),
190       mem_alloc_bytes_(0),
191       memory_limit_bytes_(GetMemoryLimitBytes()) {
192   if (port::Lowercase(platform_->Name()) == "cuda") {
193     platform_kind_ = PlatformKind::kCuda;
194   } else if (port::Lowercase(platform_->Name()) == "rocm") {
195     platform_kind_ = PlatformKind::kROCm;
196   } else if (port::Lowercase(platform_->Name()) == "opencl") {
197     platform_kind_ = PlatformKind::kOpenCL;
198   } else if (port::Lowercase(platform_->Name()) == "host") {
199     platform_kind_ = PlatformKind::kHost;
200   } else {
201     platform_kind_ = PlatformKind::kInvalid;
202   }
203 }
204 
~StreamExecutor()205 StreamExecutor::~StreamExecutor() {
206   BlockOnThreadExecutor(background_threads_.get());
207 
208   if (live_stream_count_.load() != 0) {
209     LOG(WARNING) << "Not all streams were deallocated at executor destruction "
210                  << "time. This may lead to unexpected/bad behavior - "
211                  << "especially if any stream is still active!";
212   }
213 
214   if (FLAGS_check_device_leaks) {
215     for (auto it : mem_allocs_) {
216       LOG(INFO) << "Memory alloced at executor exit: addr: "
217                 << port::Printf("%p", it.first)
218                 << ", bytes: " << it.second.bytes << ", trace: \n"
219                 << it.second.stack_trace;
220     }
221   }
222 }
223 
Init(int device_ordinal,DeviceOptions device_options)224 port::Status StreamExecutor::Init(int device_ordinal,
225                                   DeviceOptions device_options) {
226   device_ordinal_ = device_ordinal;
227   return implementation_->Init(device_ordinal, std::move(device_options));
228 }
229 
Init()230 port::Status StreamExecutor::Init() {
231   return Init(0, DeviceOptions::Default());
232 }
233 
GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)234 bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
235                                KernelBase *kernel) {
236   return implementation_->GetKernel(spec, kernel);
237 }
238 
UnloadKernel(const KernelBase * kernel)239 void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
240   implementation_->UnloadKernel(kernel);
241 }
242 
LoadModule(const MultiModuleLoaderSpec & spec,ModuleHandle * module_handle)243 bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
244                                 ModuleHandle *module_handle) {
245   return implementation_->LoadModule(spec, module_handle);
246 }
247 
UnloadModule(ModuleHandle module_handle)248 bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
249   return implementation_->UnloadModule(module_handle);
250 }
251 
Deallocate(DeviceMemoryBase * mem)252 void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
253   VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
254           << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();
255 
256   if (mem->opaque() != nullptr) {
257     EraseAllocRecord(mem->opaque());
258   }
259   implementation_->Deallocate(mem);
260   mem->Reset(nullptr, 0);
261 }
262 
GetMemAllocs(std::map<void *,AllocRecord> * records_out)263 void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
264   tf_shared_lock lock(mu_);
265   *records_out = mem_allocs_;
266 }
267 
CanEnablePeerAccessTo(StreamExecutor * other)268 bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
269   return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
270 }
271 
EnablePeerAccessTo(StreamExecutor * other)272 port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
273   return implementation_->EnablePeerAccessTo(other->implementation_.get());
274 }
275 
GetDeviceSharedMemoryConfig()276 SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
277   return implementation_->GetDeviceSharedMemoryConfig();
278 }
279 
SetDeviceSharedMemoryConfig(SharedMemoryConfig config)280 port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
281     SharedMemoryConfig config) {
282   if (config != SharedMemoryConfig::kDefault &&
283       config != SharedMemoryConfig::kFourByte &&
284       config != SharedMemoryConfig::kEightByte) {
285     string error_msg = port::Printf(
286         "Invalid shared memory config specified: %d", static_cast<int>(config));
287     LOG(ERROR) << error_msg;
288     return port::Status(port::error::INVALID_ARGUMENT, error_msg);
289   }
290   return implementation_->SetDeviceSharedMemoryConfig(config);
291 }
292 
GetDeviceDescription() const293 const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
294   mutex_lock lock(mu_);
295   if (device_description_ != nullptr) {
296     return *device_description_;
297   }
298 
299   device_description_.reset(PopulateDeviceDescription());
300   return *device_description_;
301 }
302 
GetDeviceLoad() const303 int64 StreamExecutor::GetDeviceLoad() const {
304   return implementation_->GetDeviceLoad();
305 }
306 
PlatformDeviceCount() const307 int StreamExecutor::PlatformDeviceCount() const {
308   return implementation_->PlatformDeviceCount();
309 }
310 
SupportsBlas() const311 bool StreamExecutor::SupportsBlas() const {
312   return implementation_->SupportsBlas();
313 }
314 
SupportsRng() const315 bool StreamExecutor::SupportsRng() const {
316   return implementation_->SupportsRng();
317 }
318 
SupportsDnn() const319 bool StreamExecutor::SupportsDnn() const {
320   return implementation_->SupportsDnn();
321 }
322 
GetConvolveAlgorithms(bool with_winograd_nonfused,std::vector<dnn::AlgorithmDesc> * out_algorithms)323 bool StreamExecutor::GetConvolveAlgorithms(
324     bool with_winograd_nonfused,
325     std::vector<dnn::AlgorithmDesc> *out_algorithms) {
326   dnn::DnnSupport *dnn_support = AsDnn();
327   if (!dnn_support) {
328     return false;
329   }
330   int cc_major, cc_minor;
331   GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
332   return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
333                                             cc_minor, out_algorithms);
334 }
335 
GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> * out_algorithms)336 bool StreamExecutor::GetRnnAlgorithms(
337     std::vector<dnn::AlgorithmDesc> *out_algorithms) {
338   dnn::DnnSupport *dnn_support = AsDnn();
339   if (!dnn_support) {
340     return false;
341   }
342   return dnn_support->GetRnnAlgorithms(out_algorithms);
343 }
344 
GetConvolveBackwardDataAlgorithms(bool with_winograd_nonfused,std::vector<dnn::AlgorithmDesc> * out_algorithms)345 bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
346     bool with_winograd_nonfused,
347     std::vector<dnn::AlgorithmDesc> *out_algorithms) {
348   dnn::DnnSupport *dnn_support = AsDnn();
349   if (!dnn_support) {
350     return false;
351   }
352   int cc_major, cc_minor;
353   GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
354   return dnn_support->GetConvolveBackwardDataAlgorithms(
355       with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
356 }
357 
GetConvolveBackwardFilterAlgorithms(bool with_winograd_nonfused,std::vector<dnn::AlgorithmDesc> * out_algorithms)358 bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
359     bool with_winograd_nonfused,
360     std::vector<dnn::AlgorithmDesc> *out_algorithms) {
361   dnn::DnnSupport *dnn_support = AsDnn();
362   if (!dnn_support) {
363     return false;
364   }
365   int cc_major, cc_minor;
366   GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
367   return dnn_support->GetConvolveBackwardFilterAlgorithms(
368       with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
369 }
370 
GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> * out_algorithms)371 bool StreamExecutor::GetBlasGemmAlgorithms(
372     std::vector<blas::AlgorithmType> *out_algorithms) {
373   blas::BlasSupport *blas_support = AsBlas();
374   if (!blas_support) {
375     return false;
376   }
377   return blas_support->GetBlasGemmAlgorithms(out_algorithms);
378 }
379 
380 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers,int hidden_size,int input_size,int batch_size,dnn::RnnInputMode input_mode,dnn::RnnDirectionMode direction_mode,dnn::RnnMode rnn_mode,dnn::DataType data_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64 seed,ScratchAllocator * state_allocator)381 StreamExecutor::createRnnDescriptor(
382     int num_layers, int hidden_size, int input_size, int batch_size,
383     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
384     dnn::RnnMode rnn_mode, dnn::DataType data_type,
385     const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
386     ScratchAllocator *state_allocator) {
387   dnn::DnnSupport *dnn_support = AsDnn();
388   if (!dnn_support) {
389     return port::Status(port::error::UNKNOWN,
390                         "Fail to find the dnn implementation.");
391   }
392   return dnn_support->createRnnDescriptor(
393       num_layers, hidden_size, input_size, batch_size, input_mode,
394       direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
395       state_allocator);
396 }
397 
398 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,dnn::DataType data_type)399 StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
400                                                   int batch_size, int data_size,
401                                                   dnn::DataType data_type) {
402   dnn::DnnSupport *dnn_support = AsDnn();
403   if (!dnn_support) {
404     return port::Status(port::error::UNKNOWN,
405                         "Fail to find the dnn implementation.");
406   }
407   return dnn_support->createRnnSequenceTensorDescriptor(
408       max_seq_length, batch_size, data_size, data_type);
409 }
410 
411 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,const absl::Span<const int> & seq_lengths,bool time_major,dnn::DataType data_type)412 StreamExecutor::createRnnSequenceTensorDescriptor(
413     int max_seq_length, int batch_size, int data_size,
414     const absl::Span<const int> &seq_lengths, bool time_major,
415     dnn::DataType data_type) {
416   dnn::DnnSupport *dnn_support = AsDnn();
417   if (!dnn_support) {
418     return port::Status(port::error::UNKNOWN,
419                         "Fail to find the dnn implementation.");
420   }
421   return dnn_support->createRnnSequenceTensorDescriptor(
422       max_seq_length, batch_size, data_size, seq_lengths, time_major,
423       data_type);
424 }
425 
426 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer,int batch_size,int data_size,dnn::DataType data_type)427 StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
428                                                int data_size,
429                                                dnn::DataType data_type) {
430   dnn::DnnSupport *dnn_support = AsDnn();
431   if (!dnn_support) {
432     return port::Status(port::error::UNKNOWN,
433                         "Fail to find the dnn implementation.");
434   }
435   return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
436                                                      data_size, data_type);
437 }
438 
AsDnn()439 dnn::DnnSupport *StreamExecutor::AsDnn() {
440   mutex_lock lock(mu_);
441   if (dnn_ != nullptr) {
442     return dnn_.get();
443   }
444 
445   dnn_.reset(implementation_->CreateDnn());
446   return dnn_.get();
447 }
448 
AsBlas()449 blas::BlasSupport *StreamExecutor::AsBlas() {
450   mutex_lock lock(mu_);
451   if (blas_ != nullptr) {
452     return blas_.get();
453   }
454 
455   blas_.reset(implementation_->CreateBlas());
456   return blas_.get();
457 }
458 
AsFft()459 fft::FftSupport *StreamExecutor::AsFft() {
460   mutex_lock lock(mu_);
461   if (fft_ != nullptr) {
462     return fft_.get();
463   }
464 
465   fft_.reset(implementation_->CreateFft());
466   return fft_.get();
467 }
468 
AsRng()469 rng::RngSupport *StreamExecutor::AsRng() {
470   mutex_lock lock(mu_);
471   if (rng_ != nullptr) {
472     return rng_.get();
473   }
474 
475   rng_.reset(implementation_->CreateRng());
476   return rng_.get();
477 }
478 
Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & kernel,const KernelArgsArrayBase & args)479 bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
480                             const BlockDim &block_dims,
481                             const KernelBase &kernel,
482                             const KernelArgsArrayBase &args) {
483   SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
484               kernel, args);
485 
486   return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
487 }
488 
BlockHostUntilDone(Stream * stream)489 port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
490   port::Status result;
491   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
492 
493   result = implementation_->BlockHostUntilDone(stream);
494   return result;
495 }
496 
GetStatus(Stream * stream)497 port::Status StreamExecutor::GetStatus(Stream *stream) {
498   return implementation_->GetStatus(stream);
499 }
500 
Allocate(uint64 size)501 void *StreamExecutor::Allocate(uint64 size) {
502   if (memory_limit_bytes_ > 0 &&
503       mem_alloc_bytes_ + size > memory_limit_bytes_) {
504     LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
505                  << device_ordinal_
506                  << " within provided limit. [used=" << mem_alloc_bytes_
507                  << ", limit=" << memory_limit_bytes_ << "]";
508     return nullptr;
509   }
510   void *buf = implementation_->Allocate(size);
511   VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
512           << buf << StackTraceIfVLOG10();
513   CreateAllocRecord(buf, size);
514 
515   return buf;
516 }
517 
GetUntypedSymbol(const string & symbol_name,ModuleHandle module_handle)518 port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
519     const string &symbol_name, ModuleHandle module_handle) {
520   // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
521   // be nullptr/0 for consistency with DeviceMemory semantics.
522   void *opaque = nullptr;
523   size_t bytes = 0;
524   if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
525     return DeviceMemoryBase(opaque, bytes);
526   }
527 
528   if (static_cast<bool>(module_handle)) {
529     return port::Status(
530         port::error::NOT_FOUND,
531         absl::StrCat("Check if module containing symbol ", symbol_name,
532                      " is loaded (module_handle = ",
533                      reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
534   } else {
535     return port::Status(
536         port::error::NOT_FOUND,
537         absl::StrCat("Check if kernel using the symbol is loaded: ",
538                      symbol_name));
539   }
540 }
541 
GetSymbol(const string & symbol_name,ModuleHandle module_handle,void ** mem,size_t * bytes)542 bool StreamExecutor::GetSymbol(const string &symbol_name,
543                                ModuleHandle module_handle, void **mem,
544                                size_t *bytes) {
545   return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
546 }
547 
UnifiedMemoryAllocate(uint64 bytes)548 void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
549   void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
550   VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
551           << ") returns " << buffer << StackTraceIfVLOG10();
552   return buffer;
553 }
554 
UnifiedMemoryDeallocate(void * location)555 void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
556   VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
557           << location << ")" << StackTraceIfVLOG10();
558 
559   return implementation_->UnifiedMemoryDeallocate(location);
560 }
561 
HostMemoryAllocate(uint64 size)562 void *StreamExecutor::HostMemoryAllocate(uint64 size) {
563   void *buffer = implementation_->HostMemoryAllocate(size);
564   VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
565           << ") returns " << buffer << StackTraceIfVLOG10();
566   return buffer;
567 }
568 
HostMemoryDeallocate(void * location)569 void StreamExecutor::HostMemoryDeallocate(void *location) {
570   VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
571           << ")" << StackTraceIfVLOG10();
572 
573   return implementation_->HostMemoryDeallocate(location);
574 }
575 
HostMemoryRegister(void * location,uint64 size)576 bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
577   VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
578           << ", size=" << size << ")" << StackTraceIfVLOG10();
579   if (location == nullptr || size == 0) {
580     LOG(WARNING) << "attempting to register null or zero-sized memory: "
581                  << location << "; size " << size;
582   }
583   return implementation_->HostMemoryRegister(location, size);
584 }
585 
HostMemoryUnregister(void * location)586 bool StreamExecutor::HostMemoryUnregister(void *location) {
587   VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
588           << ")" << StackTraceIfVLOG10();
589   return implementation_->HostMemoryUnregister(location);
590 }
591 
SynchronizeAllActivity()592 bool StreamExecutor::SynchronizeAllActivity() {
593   VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
594           << StackTraceIfVLOG10();
595   bool ok = implementation_->SynchronizeAllActivity();
596 
597   // This should all be quick and infallible work, so we can perform the
598   // synchronization even in the case of failure.
599   BlockOnThreadExecutor(background_threads_.get());
600 
601   return ok;
602 }
603 
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)604 bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
605                                         uint64 size) {
606   VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
607           << ", size=" << size << ")" << StackTraceIfVLOG10();
608 
609   return implementation_->SynchronousMemZero(location, size);
610 }
611 
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)612 bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
613                                        uint64 size) {
614   VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
615           << ", value=" << value << ", size=" << size << ")"
616           << StackTraceIfVLOG10();
617 
618   return implementation_->SynchronousMemSet(location, value, size);
619 }
620 
SynchronousMemcpy(DeviceMemoryBase * device_dst,const void * host_src,uint64 size)621 bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
622                                        const void *host_src, uint64 size) {
623   VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
624           << device_dst->opaque() << ", host_src=" << host_src
625           << ", size=" << size << ") H2D" << StackTraceIfVLOG10();
626 
627   // Tracing overloaded methods is very difficult due to issues with type
628   // inference on template args. Since use of these overloaded methods is
629   // discouraged anyway, this isn't a huge deal.
630   port::Status status =
631       implementation_->SynchronousMemcpy(device_dst, host_src, size);
632   if (!status.ok()) {
633     LOG(ERROR) << "synchronous memcpy: " << status;
634   }
635   return status.ok();
636 }
637 
SynchronousMemcpy(void * host_dst,const DeviceMemoryBase & device_src,uint64 size)638 bool StreamExecutor::SynchronousMemcpy(void *host_dst,
639                                        const DeviceMemoryBase &device_src,
640                                        uint64 size) {
641   VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
642           << ", device_src=" << device_src.opaque() << ", size=" << size
643           << ") D2H" << StackTraceIfVLOG10();
644 
645   port::Status status =
646       implementation_->SynchronousMemcpy(host_dst, device_src, size);
647   if (!status.ok()) {
648     LOG(ERROR) << "synchronous memcpy: " << status;
649   }
650   return status.ok();
651 }
652 
SynchronousMemcpy(DeviceMemoryBase * device_dst,const DeviceMemoryBase & device_src,uint64 size)653 bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
654                                        const DeviceMemoryBase &device_src,
655                                        uint64 size) {
656   VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
657           << device_dst->opaque() << ", device_src=" << device_src.opaque()
658           << ", size=" << size << ") D2D" << StackTraceIfVLOG10();
659 
660   port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
661       device_dst, device_src, size);
662   if (!status.ok()) {
663     LOG(ERROR) << "synchronous memcpy: " << status;
664   }
665   return status.ok();
666 }
667 
SynchronousMemcpyD2H(const DeviceMemoryBase & device_src,int64 size,void * host_dst)668 port::Status StreamExecutor::SynchronousMemcpyD2H(
669     const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
670   VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
671           << device_src.opaque() << ", size=" << size
672           << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();
673 
674   port::Status result;
675   SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
676                host_dst);
677 
678   result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
679   if (!result.ok()) {
680     result = port::Status(port::error::INTERNAL,
681                           port::Printf("failed to synchronously memcpy "
682                                        "device-to-host: device %p to host %p "
683                                        "size %lld: %s",
684                                        device_src.opaque(), host_dst, size,
685                                        result.ToString().c_str()));
686   }
687 
688   return result;
689 }
690 
SynchronousMemcpyH2D(const void * host_src,int64 size,DeviceMemoryBase * device_dst)691 port::Status StreamExecutor::SynchronousMemcpyH2D(
692     const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
693   VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
694           << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
695           << StackTraceIfVLOG10();
696 
697   port::Status result;
698   SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
699                device_dst);
700 
701   result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
702   if (!result.ok()) {
703     result = port::Status(
704         port::error::INTERNAL,
705         port::Printf("failed to synchronously memcpy host-to-device: host "
706                      "%p to device %p size %lld: %s",
707                      host_src, device_dst->opaque(), size,
708                      result.ToString().c_str()));
709   }
710 
711   return result;
712 }
713 
Memcpy(Stream * stream,void * host_dst,const DeviceMemoryBase & device_src,uint64 size)714 bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
715                             const DeviceMemoryBase &device_src, uint64 size) {
716   return implementation_->Memcpy(stream, host_dst, device_src, size);
717 }
718 
Memcpy(Stream * stream,DeviceMemoryBase * device_dst,const void * host_src,uint64 size)719 bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
720                             const void *host_src, uint64 size) {
721   return implementation_->Memcpy(stream, device_dst, host_src, size);
722 }
723 
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * device_dst,const DeviceMemoryBase & device_src,uint64 size)724 bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
725                                           DeviceMemoryBase *device_dst,
726                                           const DeviceMemoryBase &device_src,
727                                           uint64 size) {
728   return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
729                                                size);
730 }
731 
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)732 bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
733                              uint64 size) {
734   return implementation_->MemZero(stream, location, size);
735 }
736 
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)737 bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
738                               uint32 pattern, uint64 size) {
739   CHECK_EQ(0, size % 4)
740       << "need 32-bit multiple size to fill with 32-bit pattern";
741   return implementation_->Memset32(stream, location, pattern, size);
742 }
743 
HostCallback(Stream * stream,std::function<void ()> callback)744 bool StreamExecutor::HostCallback(Stream *stream,
745                                   std::function<void()> callback) {
746   return implementation_->HostCallback(stream, std::move(callback));
747 }
748 
HostCallback(Stream * stream,std::function<port::Status ()> callback)749 bool StreamExecutor::HostCallback(Stream *stream,
750                                   std::function<port::Status()> callback) {
751   return implementation_->HostCallback(stream, std::move(callback));
752 }
753 
AllocateEvent(Event * event)754 port::Status StreamExecutor::AllocateEvent(Event *event) {
755   return implementation_->AllocateEvent(event);
756 }
757 
DeallocateEvent(Event * event)758 port::Status StreamExecutor::DeallocateEvent(Event *event) {
759   return implementation_->DeallocateEvent(event);
760 }
761 
RecordEvent(Stream * stream,Event * event)762 port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
763   return implementation_->RecordEvent(stream, event);
764 }
765 
WaitForEvent(Stream * stream,Event * event)766 port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
767   return implementation_->WaitForEvent(stream, event);
768 }
769 
PollForEventStatus(Event * event)770 Event::Status StreamExecutor::PollForEventStatus(Event *event) {
771   return implementation_->PollForEventStatus(event);
772 }
773 
AllocateStream(Stream * stream)774 bool StreamExecutor::AllocateStream(Stream *stream) {
775   live_stream_count_.fetch_add(1, std::memory_order_relaxed);
776   if (!implementation_->AllocateStream(stream)) {
777     auto count = live_stream_count_.fetch_sub(1);
778     CHECK_GE(count, 0) << "live stream count should not dip below zero";
779     LOG(INFO) << "failed to allocate stream; live stream count: " << count;
780     return false;
781   }
782 
783   return true;
784 }
785 
DeallocateStream(Stream * stream)786 void StreamExecutor::DeallocateStream(Stream *stream) {
787   implementation_->DeallocateStream(stream);
788   CHECK_GE(live_stream_count_.fetch_sub(1), 0)
789       << "live stream count should not dip below zero";
790 }
791 
CreateStreamDependency(Stream * dependent,Stream * other)792 bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
793   return implementation_->CreateStreamDependency(dependent, other);
794 }
795 
AllocateTimer(Timer * timer)796 bool StreamExecutor::AllocateTimer(Timer *timer) {
797   return implementation_->AllocateTimer(timer);
798 }
799 
DeallocateTimer(Timer * timer)800 void StreamExecutor::DeallocateTimer(Timer *timer) {
801   return implementation_->DeallocateTimer(timer);
802 }
803 
StartTimer(Stream * stream,Timer * timer)804 bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
805   return implementation_->StartTimer(stream, timer);
806 }
807 
StopTimer(Stream * stream,Timer * timer)808 bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
809   return implementation_->StopTimer(stream, timer);
810 }
811 
PopulateDeviceDescription() const812 DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
813   return implementation_->PopulateDeviceDescription();
814 }
815 
DeviceMemoryUsage(int64 * free,int64 * total) const816 bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
817   return implementation_->DeviceMemoryUsage(free, total);
818 }
819 
EnqueueOnBackgroundThread(std::function<void ()> task)820 void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
821   background_threads_->Schedule(std::move(task));
822 }
823 
CreateAllocRecord(void * opaque,uint64 bytes)824 void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
825   if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
826     mutex_lock lock(mu_);
827     mem_allocs_[opaque] = AllocRecord{
828         bytes, ""};
829     mem_alloc_bytes_ += bytes;
830   }
831 }
832 
EraseAllocRecord(void * opaque)833 void StreamExecutor::EraseAllocRecord(void *opaque) {
834   if (FLAGS_check_device_leaks && opaque != nullptr) {
835     mutex_lock lock(mu_);
836     if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
837       LOG(ERROR) << "Deallocating unknown pointer: "
838                  << port::Printf("0x%p", opaque);
839     } else {
840       mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
841       mem_allocs_.erase(opaque);
842     }
843   }
844 }
845 
EnableTracing(bool enabled)846 void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
847 
RegisterTraceListener(TraceListener * listener)848 void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
849   {
850     mutex_lock lock(mu_);
851     if (listeners_.find(listener) != listeners_.end()) {
852       LOG(INFO) << "Attempt to register already-registered listener, "
853                 << listener;
854     } else {
855       listeners_.insert(listener);
856     }
857   }
858 
859   implementation_->RegisterTraceListener(listener);
860 }
861 
UnregisterTraceListener(TraceListener * listener)862 bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
863   {
864     mutex_lock lock(mu_);
865     if (listeners_.find(listener) == listeners_.end()) {
866       LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
867       return false;
868     }
869     listeners_.erase(listener);
870   }
871 
872   implementation_->UnregisterTraceListener(listener);
873   return true;
874 }
875 
GetAllocatorStats()876 absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
877   return implementation_->GetAllocatorStats();
878 }
879 
880 template <typename TraceCallT, typename... ArgsT>
SubmitTrace(TraceCallT trace_call,ArgsT &&...args)881 void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
882   if (tracing_enabled_) {
883     {
884       // instance tracers held in a block to limit the lock lifetime.
885       tf_shared_lock lock(mu_);
886       for (TraceListener *listener : listeners_) {
887         (listener->*trace_call)(std::forward<ArgsT>(args)...);
888       }
889     }
890   }
891 }
892 
implementation()893 internal::StreamExecutorInterface *StreamExecutor::implementation() {
894   return implementation_->GetUnderlyingExecutor();
895 }
896 
897 }  // namespace stream_executor
898