/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation of HostExecutor class [of those methods not defined in the
// class declaration].
#include "tensorflow/stream_executor/host/host_gpu_executor.h"

#include <stdint.h>
#include <string.h>

#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/host/host_stream.h"
#include "tensorflow/stream_executor/host/host_timer.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace stream_executor {
namespace host {

HostStream *AsHostStream(Stream *stream) {
  DCHECK(stream != nullptr);
  return dynamic_cast<HostStream *>(stream->implementation());
}

HostExecutor::HostExecutor(const PluginConfig &plugin_config)
    : plugin_config_(plugin_config) {}

HostExecutor::~HostExecutor() {}

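// Reads the optional "host_thread_stack_size_in_bytes" entry from
// DeviceOptions::non_portable_tags and remembers it for the streams created
// later by GetStreamImplementation(). As an illustration only (not code in
// this file), a caller could request a roughly 1 MiB worker-thread stack with
// something like:
//
//   DeviceOptions options;
//   options.non_portable_tags["host_thread_stack_size_in_bytes"] = "1048576";
//   executor->Init(/*device_ordinal=*/0, options);
//
// where `executor` is assumed to be a HostExecutor reachable by the caller.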
port::Status HostExecutor::Init(int device_ordinal,
                                DeviceOptions device_options) {
  auto it =
      device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
  if (it != device_options.non_portable_tags.end()) {
    if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) {
      return port::InvalidArgumentError(absl::StrCat(
          "Unable to parse host_thread_stack_size_in_bytes as an integer: ",
          it->second));
    }
  }
  return port::Status::OK();
}

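// Reports host memory as the "device" memory: free/total bytes from the
// platform, or -1 for a value the platform could not determine.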
bool HostExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  tensorflow::port::MemoryInfo mem_info = tensorflow::port::GetMemoryInfo();
  *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1;
  *total = (mem_info.total != INT64_MAX) ? mem_info.total : -1;
  return true;
}

DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
  CHECK_EQ(memory_space, 0);
  // Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
  // This should probably be kept in sync with
  // tensorflow::Allocator::kAllocatorAlignment.
  return DeviceMemoryBase(
      tensorflow::port::AlignedMalloc(size, /*minimum_alignment=*/64), size);
}

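// Returns a pointer offset_bytes into the parent allocation; the sub-buffer
// does not own the memory and must not outlive the parent allocation.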
void *HostExecutor::GetSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes,
                                 uint64 size_bytes) {
  return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
}

void HostExecutor::Deallocate(DeviceMemoryBase *mem) {
  tensorflow::port::AlignedFree(mem->opaque());
}

port::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                              uint64 size) {
  memset(location->opaque(), 0, size);
  return port::Status::OK();
}

port::Status HostExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                             int value, uint64 size) {
  memset(location->opaque(), value, size);
  return port::Status::OK();
}

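// The asynchronous Memcpy/MemZero/Memset variants below simply enqueue the
// corresponding libc operation on the HostStream's work queue; they complete
// in stream order and can be observed via events or BlockHostUntilDone().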
bool HostExecutor::Memcpy(Stream *stream, void *host_dst,
                          const DeviceMemoryBase &gpu_src, uint64 size) {
  // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
  // with the HostExecutor.
  void *src_mem = const_cast<void *>(gpu_src.opaque());
  AsHostStream(stream)->EnqueueTask(
      [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); });
  return true;
}

bool HostExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
                          const void *host_src, uint64 size) {
  void *dst_mem = gpu_dst->opaque();
  // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
  // with the HostExecutor.
  AsHostStream(stream)->EnqueueTask(
      [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); });
  return true;
}

bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
                                        DeviceMemoryBase *gpu_dst,
                                        const DeviceMemoryBase &gpu_src,
                                        uint64 size) {
  void *dst_mem = gpu_dst->opaque();
  void *src_mem = const_cast<void *>(gpu_src.opaque());
  // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given
  // the nature of the HostExecutor) memcpy on the stream (HostStream)
  // associated with the HostExecutor.
  AsHostStream(stream)->EnqueueTask(
      [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); });
  return true;
}

port::Status HostExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                   uint64 size) {
  void *gpu_mem = location->opaque();
  // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
  // with the HostExecutor.
  AsHostStream(stream)->EnqueueTask(
      [gpu_mem, size]() { memset(gpu_mem, 0, size); });
  return port::Status::OK();
}

port::Status HostExecutor::Memset(Stream *stream, DeviceMemoryBase *location,
                                  uint8 pattern, uint64 size) {
  void *gpu_mem = location->opaque();
  // Enqueue the [asynchronous] memset on the stream (HostStream) associated
  // with the HostExecutor.
  AsHostStream(stream)->EnqueueTask(
      [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); });
  return port::Status::OK();
}

port::Status HostExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                                    uint32 pattern, uint64 size) {
  void *gpu_mem = location->opaque();
  // Enqueue the [asynchronous] memset on the stream (HostStream) associated
  // with the HostExecutor, replicating the full 32-bit pattern rather than
  // truncating it to a single byte.
  AsHostStream(stream)->EnqueueTask([gpu_mem, size, pattern]() {
    uint32 *dst = static_cast<uint32 *>(gpu_mem);
    for (uint64 i = 0; i < size / sizeof(uint32); ++i) {
      dst[i] = pattern;
    }
  });
  return port::Status::OK();
}

port::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
                                             const void *host_src,
                                             uint64 size) {
  memcpy(gpu_dst->opaque(), host_src, size);
  return port::Status::OK();
}

port::Status HostExecutor::SynchronousMemcpy(void *host_dst,
                                             const DeviceMemoryBase &gpu_src,
                                             uint64 size) {
  memcpy(host_dst, gpu_src.opaque(), size);
  return port::Status::OK();
}

port::Status HostExecutor::SynchronousMemcpyDeviceToDevice(
    DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) {
  memcpy(gpu_dst->opaque(), gpu_src.opaque(), size);
  return port::Status::OK();
}

bool HostExecutor::HostCallback(Stream *stream,
                                std::function<port::Status()> callback) {
  AsHostStream(stream)->EnqueueTask([callback]() {
    port::Status s = callback();
    if (!s.ok()) {
      LOG(WARNING) << "Host callback failed: " << s;
    }
  });
  return true;
}

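// Host streams carry no per-stream device resources, so stream allocation and
// deallocation are no-ops.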
bool HostExecutor::AllocateStream(Stream *stream) { return true; }

void HostExecutor::DeallocateStream(Stream *stream) {}

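// Implements the dependency with a shared Notification: `other` notifies it
// once its currently enqueued work has drained, and `dependent` blocks on it,
// so tasks enqueued on `dependent` afterwards run only after `other` has
// caught up.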
bool HostExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  auto event = std::make_shared<absl::Notification>();
  AsHostStream(other)->EnqueueTask([event]() { event->Notify(); });
  AsHostStream(dependent)->EnqueueTask(
      [event]() { event->WaitForNotification(); });
  return true;
}

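// An Event on the host platform is just an absl::Notification shared between
// the recording and waiting streams.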
class HostEvent : public internal::EventInterface {
 public:
  HostEvent() : notification_(std::make_shared<absl::Notification>()) {}

  std::shared_ptr<absl::Notification> &notification() { return notification_; }

 private:
  // We use a std::shared_ptr here because the client may delete the HostEvent
  // object while there are still RecordEvent and WaitForEvent callbacks
  // pending on a stream.
  std::shared_ptr<absl::Notification> notification_;
};

std::unique_ptr<internal::EventInterface>
HostExecutor::CreateEventImplementation() {
  return std::unique_ptr<internal::EventInterface>(new HostEvent());
}

static HostEvent *AsHostEvent(Event *event) {
  DCHECK(event != nullptr);
  return static_cast<HostEvent *>(event->implementation());
}

port::Status HostExecutor::AllocateEvent(Event * /*event*/) {
  return port::Status::OK();
}

port::Status HostExecutor::DeallocateEvent(Event * /*event*/) {
  return port::Status::OK();
}

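// Recording an event enqueues a Notify on the stream; because an
// absl::Notification cannot be reset, each HostEvent may be recorded at most
// once (enforced by the CHECK below). Waiting and polling simply block on or
// query the shared notification.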
port::Status HostExecutor::RecordEvent(Stream *stream, Event *event) {
  std::shared_ptr<absl::Notification> notification =
      AsHostEvent(event)->notification();
  AsHostStream(stream)->EnqueueTask([notification]() {
    CHECK(!notification->HasBeenNotified());
    notification->Notify();
  });
  return port::Status::OK();
}

port::Status HostExecutor::WaitForEvent(Stream *stream, Event *event) {
  std::shared_ptr<absl::Notification> notification =
      AsHostEvent(event)->notification();
  AsHostStream(stream)->EnqueueTask(
      [notification]() { notification->WaitForNotification(); });
  return port::Status::OK();
}

Event::Status HostExecutor::PollForEventStatus(Event *event) {
  absl::Notification &notification = *AsHostEvent(event)->notification();
  return notification.HasBeenNotified() ? Event::Status::kComplete
                                        : Event::Status::kPending;
}

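// Timers delegate to HostTimer, which records elapsed time between Start and
// Stop on the given stream.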
bool HostExecutor::StartTimer(Stream *stream, Timer *timer) {
  dynamic_cast<HostTimer *>(timer->implementation())->Start(stream);
  return true;
}

bool HostExecutor::StopTimer(Stream *stream, Timer *timer) {
  dynamic_cast<HostTimer *>(timer->implementation())->Stop(stream);
  return true;
}

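// Blocks the caller until every task previously enqueued on the stream's work
// queue has finished executing.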
port::Status HostExecutor::BlockHostUntilDone(Stream *stream) {
  AsHostStream(stream)->BlockUntilDone();
  return port::Status::OK();
}

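// Describes the host as a 64-bit "device": the reported memory size is a
// nominal 4 GiB and the clock rate is derived from the CPU cycle counter
// frequency.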
port::StatusOr<std::unique_ptr<DeviceDescription>>
HostExecutor::CreateDeviceDescription(int device_ordinal) {
  internal::DeviceDescriptionBuilder builder;

  builder.set_device_address_bits(64);

  // TODO(rspringer): How to report a value that's based in reality but that
  // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily.
  builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);

  float cycle_counter_frequency = static_cast<float>(
      tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
  builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);

  builder.set_name("Host");
  builder.set_platform_version("Default Version");

  return builder.Build();
}

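// BLAS, FFT, and RNG support all follow the same pattern: look up a factory
// registered for kHostPlatformId in the PluginRegistry and, if present,
// invoke it with this executor to create the support object.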
bool HostExecutor::SupportsBlas() const {
  return PluginRegistry::Instance()
      ->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
                                                plugin_config_.blas())
      .ok();
}

blas::BlasSupport *HostExecutor::CreateBlas() {
  PluginRegistry *registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::BlasFactory> status =
      registry->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
                                                        plugin_config_.blas());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve BLAS factory: "
               << status.status().error_message();
    return nullptr;
  }

  return status.ValueOrDie()(this);
}

bool HostExecutor::SupportsFft() const {
  return PluginRegistry::Instance()
      ->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
                                               plugin_config_.fft())
      .ok();
}

fft::FftSupport *HostExecutor::CreateFft() {
  PluginRegistry *registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::FftFactory> status =
      registry->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
                                                       plugin_config_.fft());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve FFT factory: "
               << status.status().error_message();
    return nullptr;
  }

  return status.ValueOrDie()(this);
}

bool HostExecutor::SupportsRng() const {
  return PluginRegistry::Instance()
      ->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
                                               plugin_config_.rng())
      .ok();
}

rng::RngSupport *HostExecutor::CreateRng() {
  PluginRegistry *registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::RngFactory> status =
      registry->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
                                                       plugin_config_.rng());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve RNG factory: "
               << status.status().error_message();
    return nullptr;
  }

  return status.ValueOrDie()(this);
}

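// Streams are backed by HostStream worker threads; the stack size parsed in
// Init() is forwarded to each new stream.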
std::unique_ptr<internal::StreamInterface>
HostExecutor::GetStreamImplementation() {
  return std::unique_ptr<internal::StreamInterface>(
      new HostStream(thread_stack_size_in_bytes_));
}

}  // namespace host
}  // namespace stream_executor