/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
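//
// Most methods below follow the same delegation shape, e.g. (illustrative
// only, not an actual declaration from this file):
//
//   ReturnType StreamExecutor::Foo(Args... args) {
//     return implementation_->Foo(args...);
//   }
//
// with logging, tracing, or bookkeeping layered on top where needed.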

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

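// RAII helper that notifies registered TraceListeners when a traced call
// begins (at construction) and completes (at destruction), tagging both
// events with a fresh correlation id. It is a no-op when tracing is disabled
// on the owning StreamExecutor.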
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

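// Creates a ScopedTracer in the current scope for a TraceListener callback
// pair, e.g. SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream)
// invokes BlockHostUntilDoneBegin on entry and BlockHostUntilDoneComplete on
// scope exit.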
#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

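// Constructs a StreamExecutor that wraps the given platform-specific
// implementation. The platform kind is derived from the platform name, and a
// small background thread pool is created for work enqueued via
// EnqueueOnBackgroundThread().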
StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto &it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemoryBase filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    ScratchAllocator *scratch_allocator,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams &params) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the blas implementation.");
  }
  return blas_support->CreateBlasLtMatmulPlan(params);
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the blas implementation.");
  }
  return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                                 max_algorithm_count);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

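// Returns the platform-specific DNN support interface, creating it lazily on
// first use; the same lazy, mutex-guarded pattern is used by AsBlas(),
// AsFft(), and AsRng() below. Returns nullptr if the implementation does not
// provide DNN support.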
dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

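// Allocates device memory, enforcing the optional
// TF_PER_DEVICE_MEMORY_LIMIT_MB limit and recording the allocation for leak
// checking when FLAGS_check_device_leaks is enabled. Returns a null
// DeviceMemoryBase when the limit would be exceeded.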
DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them
  // to be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const std::string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location="
          << location << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

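// The SynchronousMemcpyD2H/H2D entry points below are the preferred,
// explicitly-directioned copies: they notify TraceListeners via SCOPED_TRACE
// and wrap any failure in an INTERNAL status describing the source,
// destination, and size.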
port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream,
                                     DeviceMemoryBase *location, uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

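// AllocateStream/DeallocateStream keep live_stream_count_ in sync with the
// number of streams the implementation currently has outstanding; the
// destructor warns if any streams are still live when the executor is torn
// down.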
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

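// CreateAllocRecord/EraseAllocRecord maintain the per-pointer allocation map
// (and the mem_alloc_bytes_ running total) used for leak reporting at
// destruction time. Both are no-ops unless FLAGS_check_device_leaks is set.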
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

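// Forwards a trace call to every registered TraceListener while holding mu_
// for reading. Unlike ScopedTracer, this is a one-shot notification with no
// matching "complete" call.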
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64 memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

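// Looks up the StreamExecutor registered for the given device ordinal,
// returning InvalidArgument for negative ordinals and NotFound if no executor
// for that ordinal was provided to the allocator.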
port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

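// Returns a per-device-ordinal stream owned by the allocator, creating and
// initializing it on first request under mutex_ so repeated calls share the
// same stream.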
port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor