/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
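//
// Nearly every public method below follows the same delegation shape; for
// example, this pass-through (taken verbatim from this file):
//
//   bool StreamExecutor::SupportsBlas() const {
//     return implementation_->SupportsBlas();
//   }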

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace perftools {
namespace gputools {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return port::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking
  // a nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}
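// Monotonically increasing counter used to hand each traced call a unique
// correlation id, so a TraceListener can match a *Begin callback with its
// corresponding *Complete callback.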
std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

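// RAII helper that reports a traced call to the executor's registered
// TraceListeners: the Begin callback fires in the constructor and the
// Complete callback (carrying the call's result) fires in the destructor,
// so a stack-allocated instance brackets the traced region automatically.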
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Hold the shared lock in a block to limit its lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin, \
                                 &LOC ## Complete, ## __VA_ARGS__);
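
// As a concrete illustration, the invocation in BlockHostUntilDone below,
//
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
//
// token-pastes into:
//
//   auto tracer = MakeScopedTracer(this, &TraceListener::BlockHostUntilDoneBegin,
//                                  &TraceListener::BlockHostUntilDoneComplete,
//                                  &result, stream);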

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock{mu_};
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status{port::error::INVALID_ARGUMENT, error_msg};
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

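// The device description is populated lazily on first request and cached for
// the executor's lifetime; mu_ guards the cached pointer.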
const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock{mu_};
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
      data_type, dropout, seed, state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
                                                        data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

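// The Blas/Dnn/Fft/Rng support objects below are created lazily on first use
// via the implementation's factory methods and then cached under mu_; a null
// result means the platform provides no such plugin.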
dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock{mu_};
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock{mu_};
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock{mu_};
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock{mu_};
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

void *StreamExecutor::Allocate(uint64 size) {
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, mem, bytes);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location="
          << location << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status{port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str())};
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src="
          << host_src << ", size=" << size
          << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status{
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str())};
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

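// Stream allocations are tracked with live_stream_count_ so the destructor
// can warn when streams outlive the executor; the count is incremented before
// the platform allocation is attempted and rolled back if it fails.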
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock{mu_};
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock{mu_};
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("%p", opaque);
    } else {
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Hold the shared lock in a block to limit its lifetime.
      tf_shared_lock lock{mu_};
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace gputools
}  // namespace perftools