/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
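//
// For example, most methods in this file are thin pass-throughs of the form:
//
//   port::Status StreamExecutor::Init(DeviceOptions device_options) {
//     return implementation_->Init(device_ordinal_, std::move(device_options));
//   }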

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);
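
// Example usage, as in BlockHostUntilDone() below:
//
//   port::Status result;
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);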

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
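// For example, TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 yields a limit of
// 4096 * 2^20 bytes (4 GiB); Allocate() below refuses requests that would
// push tracked usage past that limit.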
static int64 GetMemoryLimitBytes() {
  int64_t value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto &it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetConvolveExecutionPlans(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> *out_exec_plans) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveExecutionPlans(
      kind, element_type, stream, input_descriptor, filter_descriptor,
      output_descriptor, convolution_descriptor, out_exec_plans);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor &output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    ScratchAllocator *scratch_allocator,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      GetDeviceDescription().cuda_compute_capability(), out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams &params) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the blas implementation.");
  }
  return blas_support->CreateBlasLtMatmulPlan(params);
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the blas implementation.");
  }
  return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                                 max_algorithm_count);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64_t memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
  // be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const std::string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64_t size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64_t size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

bool StreamExecutor::ClearAllocatorStats() {
  return implementation_->ClearAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64_t memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor