• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implementation of HostExecutor class [of those methods not defined in the
17 // class declaration].
18 #include "tensorflow/stream_executor/host/host_gpu_executor.h"
19 
20 #include <stdint.h>
21 #include <string.h>
22 
23 #include "absl/strings/numbers.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/synchronization/notification.h"
26 #include "tensorflow/core/platform/mem.h"
27 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
28 #include "tensorflow/stream_executor/host/host_platform_id.h"
29 #include "tensorflow/stream_executor/host/host_stream.h"
30 #include "tensorflow/stream_executor/host/host_timer.h"
31 #include "tensorflow/stream_executor/lib/statusor.h"
32 #include "tensorflow/stream_executor/plugin_registry.h"
33 #include "tensorflow/stream_executor/stream_executor_internal.h"
34 
35 namespace stream_executor {
36 namespace host {
37 
AsHostStream(Stream * stream)38 HostStream *AsHostStream(Stream *stream) {
39   DCHECK(stream != nullptr);
40   return dynamic_cast<HostStream *>(stream->implementation());
41 }
42 
HostExecutor(const PluginConfig & plugin_config)43 HostExecutor::HostExecutor(const PluginConfig &plugin_config)
44     : plugin_config_(plugin_config) {}
45 
~HostExecutor()46 HostExecutor::~HostExecutor() {}
47 
Init(int device_ordinal,DeviceOptions device_options)48 port::Status HostExecutor::Init(int device_ordinal,
49                                 DeviceOptions device_options) {
50   auto it =
51       device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
52   if (it != device_options.non_portable_tags.end()) {
53     if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) {
54       return port::InvalidArgumentError(absl::StrCat(
55           "Unable to parse host_thread_stack_size_in_bytes as an integer: ",
56           it->second));
57     }
58   }
59   return port::Status::OK();
60 }
61 
DeviceMemoryUsage(int64 * free,int64 * total) const62 bool HostExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
63   tensorflow::port::MemoryInfo mem_info = tensorflow::port::GetMemoryInfo();
64   *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1;
65   *total = (mem_info.total != INT64_MAX) ? mem_info.total : -1;
66   return true;
67 }
68 
Allocate(uint64 size,int64 memory_space)69 DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
70   CHECK_EQ(memory_space, 0);
71   // Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
72   // This should probably be kept in sync with
73   // tensorflow::Allocator::kAllocatorAlignment.
74   return DeviceMemoryBase(
75       tensorflow::port::AlignedMalloc(size, /*minimum_alignment=*/64), size);
76 }
77 
GetSubBuffer(DeviceMemoryBase * parent,uint64 offset_bytes,uint64 size_bytes)78 void *HostExecutor::GetSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes,
79                                  uint64 size_bytes) {
80   return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
81 }
82 
Deallocate(DeviceMemoryBase * mem)83 void HostExecutor::Deallocate(DeviceMemoryBase *mem) {
84   tensorflow::port::AlignedFree(mem->opaque());
85 }
86 
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)87 port::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase *location,
88                                               uint64 size) {
89   memset(location->opaque(), 0, size);
90   return port::Status::OK();
91 }
92 
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)93 port::Status HostExecutor::SynchronousMemSet(DeviceMemoryBase *location,
94                                              int value, uint64 size) {
95   memset(location->opaque(), value, size);
96   return port::Status::OK();
97 }
98 
Memcpy(Stream * stream,void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)99 bool HostExecutor::Memcpy(Stream *stream, void *host_dst,
100                           const DeviceMemoryBase &gpu_src, uint64 size) {
101   // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
102   // with the HostExecutor.
103   void *src_mem = const_cast<void *>(gpu_src.opaque());
104   AsHostStream(stream)->EnqueueTask(
105       [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); });
106   return true;
107 }
108 
Memcpy(Stream * stream,DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)109 bool HostExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
110                           const void *host_src, uint64 size) {
111   void *dst_mem = gpu_dst->opaque();
112   // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
113   // with the HostExecutor.
114   AsHostStream(stream)->EnqueueTask(
115       [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); });
116   return true;
117 }
118 
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)119 bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
120                                         DeviceMemoryBase *gpu_dst,
121                                         const DeviceMemoryBase &gpu_src,
122                                         uint64 size) {
123   void *dst_mem = gpu_dst->opaque();
124   void *src_mem = const_cast<void *>(gpu_src.opaque());
125   // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given
126   // the nature of the HostExecutor) memcpy  on the stream (HostStream)
127   // associated with the HostExecutor.
128   AsHostStream(stream)->EnqueueTask(
129       [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); });
130   return true;
131 }
132 
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)133 port::Status HostExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
134                                    uint64 size) {
135   void *gpu_mem = location->opaque();
136   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
137   // with the HostExecutor.
138   AsHostStream(stream)->EnqueueTask(
139       [gpu_mem, size]() { memset(gpu_mem, 0, size); });
140   return port::Status::OK();
141 }
142 
Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)143 port::Status HostExecutor::Memset(Stream *stream, DeviceMemoryBase *location,
144                                   uint8 pattern, uint64 size) {
145   void *gpu_mem = location->opaque();
146   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
147   // with the HostExecutor.
148   AsHostStream(stream)->EnqueueTask(
149       [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); });
150   return port::Status::OK();
151 }
152 
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)153 port::Status HostExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
154                                     uint32 pattern, uint64 size) {
155   void *gpu_mem = location->opaque();
156   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
157   // with the HostExecutor.
158   AsHostStream(stream)->EnqueueTask(
159       [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); });
160   return port::Status::OK();
161 }
162 
SynchronousMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)163 port::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
164                                              const void *host_src,
165                                              uint64 size) {
166   memcpy(gpu_dst->opaque(), host_src, size);
167   return port::Status::OK();
168 }
169 
SynchronousMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)170 port::Status HostExecutor::SynchronousMemcpy(void *host_dst,
171                                              const DeviceMemoryBase &gpu_src,
172                                              uint64 size) {
173   memcpy(host_dst, gpu_src.opaque(), size);
174   return port::Status::OK();
175 }
176 
SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)177 port::Status HostExecutor::SynchronousMemcpyDeviceToDevice(
178     DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) {
179   memcpy(gpu_dst->opaque(), gpu_src.opaque(), size);
180   return port::Status::OK();
181 }
182 
HostCallback(Stream * stream,std::function<port::Status ()> callback)183 bool HostExecutor::HostCallback(Stream *stream,
184                                 std::function<port::Status()> callback) {
185   AsHostStream(stream)->EnqueueTask([callback]() {
186     port::Status s = callback();
187     if (!s.ok()) {
188       LOG(WARNING) << "Host callback failed: " << s;
189     }
190   });
191   return true;
192 }
193 
AllocateStream(Stream * stream)194 bool HostExecutor::AllocateStream(Stream *stream) { return true; }
195 
DeallocateStream(Stream * stream)196 void HostExecutor::DeallocateStream(Stream *stream) {}
197 
CreateStreamDependency(Stream * dependent,Stream * other)198 bool HostExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
199   auto event = std::make_shared<absl::Notification>();
200   AsHostStream(other)->EnqueueTask([event]() { event->Notify(); });
201   AsHostStream(dependent)->EnqueueTask(
202       [event]() { event->WaitForNotification(); });
203   return true;
204 }
205 
206 class HostEvent : public internal::EventInterface {
207  public:
HostEvent()208   HostEvent() : notification_(std::make_shared<absl::Notification>()) {}
209 
notification()210   std::shared_ptr<absl::Notification> &notification() { return notification_; }
211 
212  private:
213   // We use a std::shared_ptr here because the client may delete the HostEvent
214   // object while there are still RecordEvent and WaitForEvent callbacks pending
215   // on a stream.
216   std::shared_ptr<absl::Notification> notification_;
217 };
218 
219 std::unique_ptr<internal::EventInterface>
CreateEventImplementation()220 HostExecutor::CreateEventImplementation() {
221   return std::unique_ptr<internal::EventInterface>(new HostEvent());
222 }
223 
AsHostEvent(Event * event)224 static HostEvent *AsHostEvent(Event *event) {
225   DCHECK(event != nullptr);
226   return static_cast<HostEvent *>(event->implementation());
227 }
228 
AllocateEvent(Event *)229 port::Status HostExecutor::AllocateEvent(Event * /*event*/) {
230   return port::Status::OK();
231 }
232 
DeallocateEvent(Event *)233 port::Status HostExecutor::DeallocateEvent(Event * /*event*/) {
234   return port::Status::OK();
235 }
236 
RecordEvent(Stream * stream,Event * event)237 port::Status HostExecutor::RecordEvent(Stream *stream, Event *event) {
238   std::shared_ptr<absl::Notification> notification =
239       AsHostEvent(event)->notification();
240   AsHostStream(stream)->EnqueueTask([notification]() {
241     CHECK(!notification->HasBeenNotified());
242     notification->Notify();
243   });
244   return port::Status::OK();
245 }
246 
WaitForEvent(Stream * stream,Event * event)247 port::Status HostExecutor::WaitForEvent(Stream *stream, Event *event) {
248   std::shared_ptr<absl::Notification> notification =
249       AsHostEvent(event)->notification();
250   AsHostStream(stream)->EnqueueTask(
251       [notification]() { notification->WaitForNotification(); });
252   return port::Status::OK();
253 }
254 
PollForEventStatus(Event * event)255 Event::Status HostExecutor::PollForEventStatus(Event *event) {
256   absl::Notification &notification = *AsHostEvent(event)->notification();
257   return notification.HasBeenNotified() ? Event::Status::kComplete
258                                         : Event::Status::kPending;
259 }
260 
StartTimer(Stream * stream,Timer * timer)261 bool HostExecutor::StartTimer(Stream *stream, Timer *timer) {
262   dynamic_cast<HostTimer *>(timer->implementation())->Start(stream);
263   return true;
264 }
265 
StopTimer(Stream * stream,Timer * timer)266 bool HostExecutor::StopTimer(Stream *stream, Timer *timer) {
267   dynamic_cast<HostTimer *>(timer->implementation())->Stop(stream);
268   return true;
269 }
270 
BlockHostUntilDone(Stream * stream)271 port::Status HostExecutor::BlockHostUntilDone(Stream *stream) {
272   AsHostStream(stream)->BlockUntilDone();
273   return port::Status::OK();
274 }
275 
276 port::StatusOr<std::unique_ptr<DeviceDescription>>
CreateDeviceDescription(int device_ordinal)277 HostExecutor::CreateDeviceDescription(int device_ordinal) {
278   internal::DeviceDescriptionBuilder builder;
279 
280   builder.set_device_address_bits(64);
281 
282   // TODO(rspringer): How to report a value that's based in reality but that
283   // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily.
284   builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
285 
286   float cycle_counter_frequency = static_cast<float>(
287       tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
288   builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);
289 
290   builder.set_name("Host");
291   builder.set_platform_version("Default Version");
292 
293   return builder.Build();
294 }
295 
SupportsBlas() const296 bool HostExecutor::SupportsBlas() const {
297   return PluginRegistry::Instance()
298       ->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
299                                                 plugin_config_.blas())
300       .ok();
301 }
302 
CreateBlas()303 blas::BlasSupport *HostExecutor::CreateBlas() {
304   PluginRegistry *registry = PluginRegistry::Instance();
305   port::StatusOr<PluginRegistry::BlasFactory> status =
306       registry->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
307                                                         plugin_config_.blas());
308   if (!status.ok()) {
309     LOG(ERROR) << "Unable to retrieve BLAS factory: "
310                << status.status().error_message();
311     return nullptr;
312   }
313 
314   return status.ValueOrDie()(this);
315 }
316 
SupportsFft() const317 bool HostExecutor::SupportsFft() const {
318   return PluginRegistry::Instance()
319       ->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
320                                                plugin_config_.fft())
321       .ok();
322 }
323 
CreateFft()324 fft::FftSupport *HostExecutor::CreateFft() {
325   PluginRegistry *registry = PluginRegistry::Instance();
326   port::StatusOr<PluginRegistry::FftFactory> status =
327       registry->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
328                                                        plugin_config_.fft());
329   if (!status.ok()) {
330     LOG(ERROR) << "Unable to retrieve FFT factory: "
331                << status.status().error_message();
332     return nullptr;
333   }
334 
335   return status.ValueOrDie()(this);
336 }
337 
SupportsRng() const338 bool HostExecutor::SupportsRng() const {
339   return PluginRegistry::Instance()
340       ->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
341                                                plugin_config_.rng())
342       .ok();
343 }
344 
CreateRng()345 rng::RngSupport *HostExecutor::CreateRng() {
346   PluginRegistry *registry = PluginRegistry::Instance();
347   port::StatusOr<PluginRegistry::RngFactory> status =
348       registry->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
349                                                        plugin_config_.rng());
350   if (!status.ok()) {
351     LOG(ERROR) << "Unable to retrieve RNG factory: "
352                << status.status().error_message();
353     return nullptr;
354   }
355 
356   return status.ValueOrDie()(this);
357 }
358 
359 std::unique_ptr<internal::StreamInterface>
GetStreamImplementation()360 HostExecutor::GetStreamImplementation() {
361   return std::unique_ptr<internal::StreamInterface>(
362       new HostStream(thread_stack_size_in_bytes_));
363 }
364 
365 }  // namespace host
366 }  // namespace stream_executor
367