/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "SP_Something".
//
// This file also contains stream_executor::Platform registration for pluggable
// device.
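//
// At a high level, a plugin exports `SE_InitPlugin`, which fills in
// SP_Platform and SP_PlatformFns (and, through the latter, SP_DeviceFns,
// SP_StreamExecutor and SP_TimerFns). `InitStreamExecutorPlugin` below loads
// that symbol, validates the returned structs, wraps them in a CPlatform and
// registers the platform with MultiPlatformManager.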
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

#include <string>

#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"

using tensorflow::StatusFromTF_Status;

namespace stream_executor {
using tensorflow::StringPiece;
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;

namespace {

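// `struct_size` is filled by the plugin with the SP_*_STRUCT_SIZE value it was
// compiled against; a zero value means the field was never set, so the struct
// is rejected before any of its other members are read.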
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
  do {                                                                 \
    if (STRUCT_OBJ.struct_size == 0) {                                 \
      return port::FailedPreconditionError(                            \
          "struct_size field in " #STRUCT_NAME                         \
          " must be set to " #SIZE_VALUE_NAME ".");                    \
    }                                                                  \
  } while (0)

#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)           \
  do {                                                           \
    if (STRUCT_OBJ.NAME == 0) {                                  \
      return port::FailedPreconditionError(                      \
          "'" #NAME "' field in " #STRUCT_NAME " must be set."); \
    }                                                            \
  } while (0)

port::Status ValidateDeviceType(StringPiece type) {
  // Validate device type. Device type must start with a capital letter and
  // consist of capital letters and underscores. Reasoning behind this decision:
  // * At the minimum we want to disallow '/' and ':' since
  //   these characters are used in device spec, for e.g.
  //   /job:foo/replica:12/device:GPU:1.
  // * Underscores seem useful, for e.g. XLA_GPU uses underscores.
  // * Allowing lowercase might get confusing. For example, say someone
  //   registers a new type called "Gpu". It might be confusing for users that
  //   "Gpu" is not the same device type as "GPU".
  // Note that lowercase "cpu" and "gpu" are currently supported only for
  // legacy reasons:
  // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
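  // For example, "GPU", "XLA_GPU" and "MY_DEVICE" match the pattern below,
  // while "Gpu", "my_device" and "GPU:1" do not.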
  static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
  bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
  if (!matches) {
    return port::FailedPreconditionError(
        tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
                                    kTfDeviceTypeRegEx->pattern(), "."));
  }
  return port::Status::OK();
}

port::Status ValidateSPPlatform(const SP_Platform& platform) {
  VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_Platform, platform, name);
  VALIDATE_MEMBER(SP_Platform, platform, type);
  TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
  TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
  // `visible_device_count` could be 0 at initialization time.
  return port::Status::OK();
}

port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
  VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
                       SP_PLATFORM_FNS_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns);
  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns);
  return port::Status::OK();
}

port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
  VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds);
  return port::Status::OK();
}

port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
  VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
  VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
                       SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPDevice(const SP_Device& device) {
  VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
  VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
                                      const SP_Platform& platform) {
  VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
  VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
  VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate);
  VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate);
  if (platform.supports_unified_memory) {
    VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate);
    VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate);
  }
  VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status);
  VALIDATE_MEMBER(SP_StreamExecutor, se, record_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh);
  VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod);
  VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh);
  VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod);
  VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity);
  VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback);
  return port::Status::OK();
}

port::Status ValidateSEPlatformRegistrationParams(
    const SE_PlatformRegistrationParams& params) {
  VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
                       SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
  VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
  VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns);
  return port::Status::OK();
}
#undef VALIDATE_MEMBER

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  switch (s) {
    case SE_EVENT_ERROR:
      return Event::Status::kError;
    case SE_EVENT_PENDING:
      return Event::Status::kPending;
    case SE_EVENT_COMPLETE:
      return Event::Status::kComplete;
    default:
      return Event::Status::kUnknown;
  }
}

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
  SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
  // `opaque` field inside SP_DeviceMemoryBase is not const.
  // Therefore, we need to cast away the constness before setting it.
  device_memory_base.opaque = const_cast<void*>(mem->opaque());
  device_memory_base.size = mem->size();
  device_memory_base.payload = mem->payload();
  return device_memory_base;
}

DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
  DeviceMemoryBase base(mem.opaque, mem.size);
  base.SetPayload(mem.payload);
  return base;
}

// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
  std::function<port::Status()> callback;
};
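// A HostCallbackContext is heap-allocated in CStreamExecutor::HostCallback
// below and deleted by HostCallbackTrampoline once the wrapped callback has
// run.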

// This wrapper allows calling `HostCallbackContext::callback` across C API.
// This function matches `SE_StatusCallbackFn` signature and will be passed as
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
  HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
  port::Status s = host_ctx->callback();
  Set_TF_Status_from_Status(status, s);
  delete host_ctx;
}

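// Implements StreamExecutorInterface in terms of an SP_Device and the plugin's
// SP_StreamExecutor function table. Streams, events and timers reaching these
// methods wrap the objects produced by the Create*/Get*Implementation methods
// at the bottom of this class, so the static_casts to CStream/CEvent/CTimer
// below are expected to be safe.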
class CStreamExecutor : public internal::StreamExecutorInterface {
 public:
  explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
                           SP_StreamExecutor* stream_executor,
                           SP_Platform* platform, SP_PlatformFns* platform_fns,
                           SP_TimerFns* timer_fns, const std::string& name,
                           int visible_device_count)
      : device_(std::move(device)),
        device_fns_(device_fns),
        stream_executor_(stream_executor),
        platform_(platform),
        platform_fns_(platform_fns),
        timer_fns_(timer_fns),
        platform_name_(name),
        visible_device_count_(visible_device_count) {}

  ~CStreamExecutor() override {
    platform_fns_->destroy_device(platform_, &device_);
  }

  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
    return port::Status::OK();
  }

  DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override {
    SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
    stream_executor_->allocate(&device_, size, memory_space, &mem);
    port::Status status = ValidateSPDeviceMemoryBase(mem);
    if (!status.ok()) {
      LOG(ERROR) << status.error_message();
    }
    return DeviceMemoryBaseFromC(mem);
  }
  DeviceMemoryBase Allocate(uint64 size) {
    return Allocate(size, /*memory_space=*/0);
  }
  void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
                     uint64 size) override {
    LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
  }

  void Deallocate(DeviceMemoryBase* mem) override {
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
    stream_executor_->deallocate(&device_, &device_memory_base);
  }

  void* HostMemoryAllocate(uint64 size) override {
    return stream_executor_->host_memory_allocate(&device_, size);
  }

  void HostMemoryDeallocate(void* mem) override {
    stream_executor_->host_memory_deallocate(&device_, mem);
  }

  bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
  bool HostMemoryUnregister(void* mem) override { return false; }

  void* UnifiedMemoryAllocate(uint64 size) override {
    CHECK(stream_executor_->unified_memory_allocate);
    return stream_executor_->unified_memory_allocate(&device_, size);
  }

  void UnifiedMemoryDeallocate(void* mem) override {
    CHECK(stream_executor_->unified_memory_deallocate);
    stream_executor_->unified_memory_deallocate(&device_, mem);
  }

  absl::optional<AllocatorStats> GetAllocatorStats() override {
    SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
    TF_Bool has_stats =
        stream_executor_->get_allocator_stats(&device_, &c_stats);
    if (!has_stats) {
      return absl::nullopt;
    }
    port::Status status = ValidateSPAllocatorStats(c_stats);
    if (!status.ok()) {
      LOG(ERROR) << status.error_message();
      return absl::nullopt;
    }
    ::stream_executor::AllocatorStats stats;
    stats.num_allocs = c_stats.num_allocs;
    stats.bytes_in_use = c_stats.bytes_in_use;
    stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
    stats.largest_alloc_size = c_stats.largest_alloc_size;
    if (c_stats.has_bytes_limit) {
      stats.bytes_limit = c_stats.bytes_limit;
    }
    stats.bytes_reserved = c_stats.bytes_reserved;
    stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
    if (c_stats.has_bytes_reservable_limit) {
      stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
    }
    stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
    return stats;
  }
  bool SynchronizeAllActivity() override {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->synchronize_all_activity(&device_, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                  uint64 size) override {
    // TODO(annarev): figure out if we should support memzero/memset
    // functionality by allocating on host and then copying to device.
    return port::UnimplementedError(
        "SynchronousMemZero is not supported by pluggable device.");
  }
  port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                 uint64 size) override {
    return port::UnimplementedError(
        "SynchronousMemSet is not supported by pluggable device.");
  }
  port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                 const void* host_src, uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
    stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
                                       size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status SynchronousMemcpy(void* host_dst,
                                 const DeviceMemoryBase& gpu_src,
                                 uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
                                       size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
                                               const DeviceMemoryBase& gpu_src,
                                               uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
                                       &device_mem_src, size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                       uint64 size) override {
    return port::UnimplementedError(
        "MemZero is not supported by pluggable device.");
  }
  port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
                      uint64 size) override {
    return port::UnimplementedError(
        "Memset is not supported by pluggable device.");
  }
  port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                        uint32 pattern, uint64 size) override {
    return port::UnimplementedError(
        "Memset32 is not supported by pluggable device.");
  }
  bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
                                  host_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
                            const DeviceMemoryBase& gpu_src,
                            uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool HostCallback(Stream* stream,
                    std::function<port::Status()> callback) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    HostCallbackContext* ctx = new HostCallbackContext{callback};
    return stream_executor_->host_callback(&device_, stream_handle,
                                           &HostCallbackTrampoline, ctx);
  }
  port::Status AllocateEvent(Event* event) override {
    DCHECK(event != nullptr);
    return static_cast<CEvent*>(event->implementation())->Create();
  }
  port::Status DeallocateEvent(Event* event) override {
    static_cast<CEvent*>(event->implementation())->Destroy();
    return port::Status::OK();
  }
  port::Status RecordEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
  }
  port::Status WaitForEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
                                     c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    return s;
  }
  Event::Status PollForEventStatus(Event* event) override {
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    SE_EventStatus event_status =
        stream_executor_->get_event_status(&device_, event_handle);
    return SEEventStatusToEventStatus(event_status);
  }
  bool AllocateStream(Stream* stream) override {
    DCHECK(stream != nullptr);
    port::Status status =
        static_cast<CStream*>(stream->implementation())->Create();
    // TODO(annarev): update AllocateStream to return status instead
    // (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateStream(Stream* stream) override {
    static_cast<CStream*>(stream->implementation())->Destroy();
  }
  bool CreateStreamDependency(Stream* dependent, Stream* other) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream dependent_handle =
        static_cast<CStream*>(dependent->implementation())->Handle();
    SP_Stream other_handle =
        static_cast<CStream*>(other->implementation())->Handle();
    stream_executor_->create_stream_dependency(&device_, dependent_handle,
                                               other_handle, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool AllocateTimer(Timer* timer) override {
    port::Status status =
        static_cast<CTimer*>(timer->implementation())->Create();
    // TODO(annarev): change return value of AllocateTimer
    // to status (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateTimer(Timer* timer) override {
    static_cast<CTimer*>(timer->implementation())->Destroy();
  }
  bool StartTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->start_timer(&device_, stream_handle, timer_handle,
                                  c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool StopTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
                                 c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  port::Status BlockHostForEvent(Stream* stream, Event* event) {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status BlockHostUntilDone(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();

    // If `block_host_until_done` is set, use it.
    if (stream_executor_->block_host_until_done != nullptr) {
      stream_executor_->block_host_until_done(&device_, stream_handle,
                                              c_status.get());
      return StatusFromTF_Status(c_status.get());
    }
    // Create and record an event and then wait for it.
    SP_Event event_handle;
    stream_executor_->create_event(&device_, &event_handle, c_status.get());
    TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
    stream_executor_->record_event(&device_, stream_handle, event_handle,
                                   c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    if (!s.ok()) {
      stream_executor_->destroy_event(&device_, event_handle);
      return s;
    }
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    stream_executor_->destroy_event(&device_, event_handle);
    return StatusFromTF_Status(c_status.get());
  }

  port::Status GetStatus(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    stream_executor_->get_stream_status(&device_, stream_handle,
                                        c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  int PlatformDeviceCount() override { return visible_device_count_; }
  port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    return port::UnimplementedError(
        "EnablePeerAccessTo is not supported by pluggable device.");
  }
  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    return false;
  }

  bool DeviceMemoryUsage(int64* free, int64* total) const override {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    return stream_executor_->device_memory_usage(
        &device_, reinterpret_cast<int64_t*>(free),
        reinterpret_cast<int64_t*>(total));
  }

  // Creates a new DeviceDescription object.
  // Ownership is transferred to the caller.
  port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
      const override {
    OwnedTFStatus c_status(TF_NewStatus());

    internal::DeviceDescriptionBuilder builder;
    if (device_.hardware_name != nullptr) {
      builder.set_name(device_.hardware_name);
    }
    if (device_.device_vendor != nullptr) {
      builder.set_device_vendor(device_.device_vendor);
    }
    if (device_.pci_bus_id != nullptr) {
      builder.set_pci_bus_id(device_.pci_bus_id);
    }

    if (device_fns_->get_numa_node != nullptr) {
      int32_t numa_node = device_fns_->get_numa_node(&device_);
      if (numa_node >= 0) {
        builder.set_numa_node(numa_node);
      }
    }

    if (device_fns_->get_memory_bandwidth != nullptr) {
      int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
      if (memory_bandwidth >= 0) {
        builder.set_memory_bandwidth(memory_bandwidth);
      }
    }
    // TODO(annarev): Add gflops field in DeviceDescription and set it here.
    // TODO(annarev): Perhaps add `supports_unified_memory` in
    // DeviceDescription.
    return builder.Build();
  }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
      override {
    return std::unique_ptr<internal::EventInterface>(
        new CEvent(&device_, stream_executor_));
  }
  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
      override {
    LOG(FATAL)
        << "CreateKernelImplementation is not supported by pluggable device.";
  }
  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
      override {
    return std::unique_ptr<internal::StreamInterface>(
        new CStream(&device_, stream_executor_));
  }
  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
    return std::unique_ptr<internal::TimerInterface>(
        new CTimer(&device_, stream_executor_, timer_fns_));
  }

 private:
  SP_Device device_;
  SP_DeviceFns* device_fns_;
  SP_StreamExecutor* stream_executor_;
  SP_Platform* platform_;
  SP_PlatformFns* platform_fns_;
  SP_TimerFns* timer_fns_;
  std::string platform_name_;
  int visible_device_count_;
};
}  // namespace

CPlatform::CPlatform(SP_Platform platform,
                     void (*destroy_platform)(SP_Platform*),
                     SP_PlatformFns platform_fns,
                     void (*destroy_platform_fns)(SP_PlatformFns*),
                     SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
                     SP_TimerFns timer_fns)
    : platform_(std::move(platform)),
      destroy_platform_(destroy_platform),
      platform_fns_(std::move(platform_fns)),
      destroy_platform_fns_(destroy_platform_fns),
      device_fns_(std::move(device_fns)),
      stream_executor_(std::move(stream_executor)),
      timer_fns_(std::move(timer_fns)),
      name_(platform.name) {}

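// Destroy cached executors (each of which destroys its SP_Device) before
// handing the function tables and the platform itself back to the plugin's
// destroy callbacks.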
CPlatform::~CPlatform() {
  executor_cache_.DestroyAllExecutors();
  platform_fns_.destroy_device_fns(&platform_, &device_fns_);
  platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
  platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
  destroy_platform_(&platform_);
  destroy_platform_fns_(&platform_fns_);
}

port::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
  // TODO(annarev): see if we can get StreamExecutor instance
  // and call GetDeviceDescription. executor_cache_.Get would need
  // to be made const for it to work.
  internal::DeviceDescriptionBuilder builder;
  builder.set_name(name_);
  return builder.Build();
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
  stream_executor::StreamExecutorConfig config;
  config.ordinal = ordinal;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
    int ordinal, const PluginConfig& plugin_config) {
  StreamExecutorConfig config;
  config.ordinal = ordinal;
  config.plugin_config = plugin_config;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
    const StreamExecutorConfig& config) {
  return executor_cache_.GetOrCreate(
      config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
    const StreamExecutorConfig& config) {
  // Fill device creation params
  SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
  SP_Device device{SP_DEVICE_STRUCT_SIZE};
  device_params.device = &device;
  device_params.ext = nullptr;
  device_params.ordinal = config.ordinal;
  OwnedTFStatus c_status(TF_NewStatus());

  // Create Device
  platform_fns_.create_device(&platform_, &device_params, c_status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPDevice(device));

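  // `device` is moved into the CStreamExecutor below; its destructor calls
  // `destroy_device` on it.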
  auto executor = absl::make_unique<CStreamExecutor>(
      std::move(device), &device_fns_, &stream_executor_, &platform_,
      &platform_fns_, &timer_fns_, name_, platform_.visible_device_count);
  auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
                                                  config.ordinal);
  return result;
}

port::Status InitStreamExecutorPlugin(void* dso_handle) {
  tensorflow::Env* env = tensorflow::Env::Default();

  // Step 1: Load symbol for `SE_InitPlugin`.
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));

  // Step 2: Call `SE_InitPlugin`.
  auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
  return InitStreamExecutorPlugin(init_fn);
}

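// A plugin's `SE_InitPlugin` fills in the structs passed through
// `SE_PlatformRegistrationParams`, which are then validated and registered by
// the function below. As a rough, hypothetical sketch (the exact signature is
// defined by `SEInitPluginFn` in stream_executor.h; `my_create_device` and
// `my_create_se` stand in for plugin-defined callbacks):
//
//   void SE_InitPlugin(SE_PlatformRegistrationParams* params,
//                      TF_Status* status) {
//     params->platform->name = "MY_DEVICE";
//     params->platform->type = "MY_DEVICE";
//     params->platform_fns->create_device = my_create_device;
//     params->platform_fns->create_stream_executor = my_create_se;
//     ...
//   }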
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
  SE_PlatformRegistrationParams params{
      SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
  SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
  SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
  params.major_version = SE_MAJOR;
  params.minor_version = SE_MINOR;
  params.patch_version = SE_PATCH;
  params.platform = &platform;
  params.platform_fns = &platform_fns;

  OwnedTFStatus c_status(TF_NewStatus());
  init_fn(&params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
  TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
  TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));

  // Fill SP_DeviceFns creation params
  SE_CreateDeviceFnsParams device_fns_params{
      SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
  SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
  device_fns_params.device_fns = &device_fns;

  // Create device functions (SP_DeviceFns)
  platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));

  // Fill stream executor creation params
  SE_CreateStreamExecutorParams se_params{
      SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
  SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
  se_params.stream_executor = &se;

  // Create StreamExecutor
  platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));

  SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
  platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));

  // Register new platform
  std::string platform_name = std::string(platform.name);
  std::unique_ptr<stream_executor::CPlatform> cplatform(
      new stream_executor::CPlatform(
          std::move(platform), params.destroy_platform, std::move(platform_fns),
          params.destroy_platform_fns, std::move(device_fns), std::move(se),
          std::move(timer_fns)));
  SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
      std::move(cplatform)));

  // TODO(annarev): Add pluggable device registration here.
  return port::Status::OK();
}
}  // namespace stream_executor