/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
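//
// For example, most methods below are thin forwarders of the form
//
//   port::Status StreamExecutor::GetStatus(Stream *stream) {
//     return implementation_->GetStatus(stream);
//   }
//
// so all platform-specific behavior lives behind
// internal::StreamExecutorInterface.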

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);
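
// For instance, the use in BlockHostUntilDone() below,
//
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
//
// expands to
//
//   auto tracer = MakeScopedTracer(this, &TraceListener::BlockHostUntilDoneBegin,
//                                  &TraceListener::BlockHostUntilDoneComplete,
//                                  &result, stream);
//
// so the Begin callback fires at construction and the Complete callback,
// carrying the final result, fires when the tracer goes out of scope.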

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}
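
// For example, running with TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 yields
// 4096 * (1 << 20) = 4294967296 bytes (4 GiB) as the per-device limit,
// which Allocate() below enforces before forwarding to the implementation.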

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = absl::StrFormat(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, stream, element_type, input_descriptor, filter_descriptor,
      convolution_descriptor, output_descriptor, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
  // be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}
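
// A minimal usage sketch of the tracing hooks (MyTraceListener is a
// hypothetical TraceListener subclass, not defined in this file):
//
//   MyTraceListener listener;
//   executor->RegisterTraceListener(&listener);
//   executor->EnableTracing(true);
//   ...  // traced work; the SCOPED_TRACE sites above invoke the listener
//   executor->EnableTracing(false);
//   executor->UnregisterTraceListener(&listener);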

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64 memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor