/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
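//
// A minimal sketch of the pass-through pattern used throughout this file
// (illustrative only; `SomeCall` and `Args` are hypothetical placeholders,
// not real names from this API):
//
//   port::Status StreamExecutor::SomeCall(Args... args) {
//     return implementation_->SomeCall(args...);
//   }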

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);
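// Illustrative expansion of the macro above (a sketch inferred from the uses
// later in this file, not additional machinery):
// `SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream)` builds a
// ScopedTracer that, when tracing is enabled, calls
// listener->BlockHostUntilDoneBegin(correlation_id, stream) on each registered
// listener immediately and
// listener->BlockHostUntilDoneComplete(correlation_id, &result) when `tracer`
// goes out of scope.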

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64_t value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}
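// Illustrative usage (assumes a POSIX-style shell; the variable name comes
// from the code above): running a process with
//   TF_PER_DEVICE_MEMORY_LIMIT_MB=4096
// caps allocations made through StreamExecutor::Allocate at 4096 * 2^20 bytes
// per device; leaving the variable unset (or 0) disables the limit.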

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto &it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetConvolveExecutionPlans(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> *out_exec_plans) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveExecutionPlans(
      kind, element_type, stream, input_descriptor, filter_descriptor,
      output_descriptor, convolution_descriptor, out_exec_plans);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor &output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    ScratchAllocator *scratch_allocator,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams &params) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->CreateBlasLtMatmulPlan(params);
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                                 max_algorithm_count);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64_t memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
  // be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const std::string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64_t size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64_t size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

bool StreamExecutor::ClearAllocatorStats() {
  return implementation_->ClearAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64_t memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor