/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
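//
// A minimal usage sketch (illustrative only; how the executor is obtained and
// the buffer size are assumptions, not something this file prescribes):
//
//   StreamExecutor *executor = ...;  // obtained from a Platform (not shown).
//   SE_CHECK_OK(executor->Init());
//   DeviceMemoryBase device_buf = executor->Allocate(/*size=*/1024,
//                                                    /*memory_space=*/0);
//   std::vector<uint8> host_buf(1024, 0);
//   SE_CHECK_OK(executor->SynchronousMemcpyH2D(host_buf.data(), /*size=*/1024,
//                                              &device_buf));
//   executor->Deallocate(&device_buf);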

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);
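// Typical use of SCOPED_TRACE (see BlockHostUntilDone below): declare the
// result that the Complete callback will report, expand the macro with the
// TraceListener call prefix and the Begin arguments, then run the traced work.
// The Begin callback fires at the macro and the Complete callback fires when
// "tracer" goes out of scope:
//
//   port::Status result;
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
//   result = implementation_->BlockHostUntilDone(stream);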

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}
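// For example, TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 yields a limit of
// 4096 * (1 << 20) bytes = 4 GiB; Allocate() below rejects any request that
// would push the running allocation total past this value.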

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = absl::StrFormat(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, stream, element_type, input_descriptor, filter_descriptor,
      convolution_descriptor, output_descriptor, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}
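// Note: a default-constructed (null) DeviceMemoryBase is returned both when
// the per-device limit above is exceeded and when the underlying
// implementation fails to allocate, so callers should check the result (e.g.
// with is_null() or a comparison against nullptr) before using it.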

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
  // be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}
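// A hedged usage sketch; "executor", "module_handle" and the symbol name
// "my_symbol" are placeholders, not values defined in this file:
//
//   TF_ASSIGN_OR_RETURN(
//       DeviceMemoryBase sym,
//       executor->GetUntypedSymbol("my_symbol", module_handle));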

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64 memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor