/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
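// When true, device allocations are tracked in mem_allocs_ and any records
// still live when a StreamExecutor is destroyed are logged (see
// StreamExecutor::CreateAllocRecord and ~StreamExecutor).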
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking
  // a nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kROCm:
      factory = *internal::MakeROCMExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

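// Monotonically increasing source for the correlation IDs handed to trace
// listeners, so a listener can match a *Begin callback with the *Complete
// callback of the same traced call.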
std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

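// RAII helper that notifies registered TraceListeners when a traced
// StreamExecutor call begins (at construction) and completes (at
// destruction), tagging both callbacks with the same correlation ID.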
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

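// Deduces the ScopedTracer template arguments from the callback and argument
// types, so call sites (via SCOPED_TRACE below) don't have to spell them out.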
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...)                      \
  auto tracer = MakeScopedTracer(this, &LOC##Begin, \
                                 &LOC##Complete, ##__VA_ARGS__);
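// For example,
//   SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src,
//                size, host_dst);
// notifies listeners via SynchronousMemcpyD2HBegin(...) immediately and via
// SynchronousMemcpyD2HComplete(..., &result) when `tracer` leaves scope.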

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
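// For example, TF_PER_DEVICE_MEMORY_LIMIT_MB=1024 yields a limit of
// 1024 * (1 << 20) = 1073741824 bytes (1 GiB) per device.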
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock(mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock(mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int batch_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type,
    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

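// The As*() accessors below create the corresponding support interface lazily
// on first use and cache it; mu_ guards the one-time creation. A null return
// means the underlying implementation does not provide that support.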
dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock(mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock(mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock(mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock(mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

void *StreamExecutor::Allocate(uint64 size) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return nullptr;
  }
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

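// Looks up symbol_name in the module identified by module_handle (or among
// the executor's implicitly loaded symbols when module_handle is empty) and
// wraps the result in an unowned DeviceMemoryBase.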
port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If the lookup fails, opaque/bytes are left unchanged. Initialize them to
  // nullptr/0 here for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location="
          << location << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src="
          << host_src << ", size=" << size
          << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

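// AllocateStream/DeallocateStream keep live_stream_count_ in sync so that
// ~StreamExecutor can warn about streams that outlive the executor.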
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

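// The alloc-record bookkeeping below is only active when
// FLAGS_check_device_leaks is set; otherwise these calls reduce to a cheap
// flag check on the allocation path.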
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock(mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock(mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("%p", opaque);
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

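// Forwards a trace event to every registered listener while holding mu_ in
// shared mode; this is a no-op unless EnableTracing(true) has been called.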
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock(mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace stream_executor