• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file extends/implements core stream executor base classes in terms of
16 // the C API defined in stream_executor.h. A class "CSomething" represents a
17 // "Something" that can be manipulated via calls in the C interface and a C
18 // struct called "SP_Something".
19 //
20 // This file also contains stream_executor::Platform registration for pluggable
21 // device.
22 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
23 
24 #include <string>
25 
26 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/regexp.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tensorflow/core/platform/strcat.h"
33 #include "tensorflow/core/platform/stringpiece.h"
34 #include "tensorflow/stream_executor/executor_cache.h"
35 #include "tensorflow/stream_executor/multi_platform_manager.h"
36 #include "tensorflow/stream_executor/platform.h"
37 #include "tensorflow/stream_executor/stream.h"
38 #include "tensorflow/stream_executor/stream_executor_internal.h"
39 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
40 #include "tensorflow/stream_executor/timer.h"
41 
42 using tensorflow::StatusFromTF_Status;
43 
44 namespace stream_executor {
45 using tensorflow::StringPiece;
46 using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
47 
48 namespace {
49 
// Makes the enclosing function return a FailedPrecondition status when
// `STRUCT_OBJ.struct_size` is zero. `STRUCT_NAME` and `SIZE_VALUE_NAME` are
// only stringified into the error message. `STRUCT_OBJ` is parenthesized so
// that expressions such as `*ptr` can be passed safely.
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
  do {                                                                 \
    if ((STRUCT_OBJ).struct_size == 0) {                               \
      return port::FailedPreconditionError(                            \
          "struct_size field in " #STRUCT_NAME                         \
          " must be set to " #SIZE_VALUE_NAME ".");                    \
    }                                                                  \
  } while (0)
58 
// Makes the enclosing function return a FailedPrecondition status when the
// member `NAME` of `STRUCT_OBJ` compares equal to 0 (e.g. a null pointer).
// `STRUCT_NAME` is only stringified into the error message. `STRUCT_OBJ` is
// parenthesized so that expressions such as `*ptr` can be passed safely.
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)           \
  do {                                                           \
    if ((STRUCT_OBJ).NAME == 0) {                                \
      return port::FailedPreconditionError(                      \
          "'" #NAME "' field in " #STRUCT_NAME " must be set."); \
    }                                                            \
  } while (0)
66 
ValidateDeviceType(StringPiece type)67 port::Status ValidateDeviceType(StringPiece type) {
68   // Validate device type. Device type must start with a capital letter and
69   // consist of capital letters and underscores. Reasoning behind this decision:
70   // * At the minimum we want to disallow '/' and ':' since
71   //   these characters are used in device spec, for e.g.
72   //   /job:foo/replica:12/device:GPU:1.
73   // * Underscores seem useful, for e.g. XLA_GPU uses underscores.
74   // * Allowing lowercase might get confusing. For example, say someone
75   //   registers a new type called "Gpu". It might be confusing for users that
76   //   "Gpu" is not the same device type as "GPU".
77   //   Note that lowercase "cpu" and "gpu" are currently supported only for
78   //   legacy reasons:
79   //   https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
80   static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
81   bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
82   if (!matches) {
83     return port::FailedPreconditionError(
84         tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
85                                     kTfDeviceTypeRegEx->pattern(), "."));
86   }
87   return port::Status::OK();
88 }
89 
ValidateSPPlatform(const SP_Platform & platform)90 port::Status ValidateSPPlatform(const SP_Platform& platform) {
91   VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
92   VALIDATE_MEMBER(SP_Platform, platform, name);
93   VALIDATE_MEMBER(SP_Platform, platform, type);
94   TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
95   TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
96   // `visible_device_count` could be 0 at initialization time.
97   return port::Status::OK();
98 }
99 
ValidateSPPlatformFns(const SP_PlatformFns & platform_fns)100 port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
101   VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
102                        SP_PLATFORM_FNS_STRUCT_SIZE);
103   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device);
104   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device);
105   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor);
106   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
107   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
108   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
109   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns);
110   VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns);
111   return port::Status::OK();
112 }
113 
ValidateSPTimerFns(const SP_TimerFns & timer_fns)114 port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
115   VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
116   VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds);
117   return port::Status::OK();
118 }
119 
ValidateSPAllocatorStats(const SP_AllocatorStats & stats)120 port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
121   VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE);
122   // All other fields could theoretically be zero/null.
123   return port::Status::OK();
124 }
125 
ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase & mem)126 port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
127   VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
128                        SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
129   // All other fields could theoretically be zero/null.
130   return port::Status::OK();
131 }
132 
ValidateSPDevice(const SP_Device & device)133 port::Status ValidateSPDevice(const SP_Device& device) {
134   VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
135   // All other fields could theoretically be zero/null.
136   return port::Status::OK();
137 }
138 
ValidateSPDeviceFns(const SP_DeviceFns & device_fns)139 port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
140   VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
141   // All other fields could theoretically be zero/null.
142   return port::Status::OK();
143 }
144 
ValidateSPStreamExecutor(const SP_StreamExecutor & se,const SP_Platform & platform)145 port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
146                                       const SP_Platform& platform) {
147   VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
148   VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
149   VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
150   VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
151   VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate);
152   VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate);
153   if (platform.supports_unified_memory) {
154     VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate);
155     VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate);
156   }
157   VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
158   VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
159   VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
160   VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency);
161   VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status);
162   VALIDATE_MEMBER(SP_StreamExecutor, se, create_event);
163   VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event);
164   VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status);
165   VALIDATE_MEMBER(SP_StreamExecutor, se, record_event);
166   VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event);
167   VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer);
168   VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer);
169   VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer);
170   VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer);
171   VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh);
172   VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod);
173   VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh);
174   VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod);
175   VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event);
176   VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity);
177   VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback);
178   return port::Status::OK();
179 }
180 
ValidateSEPlatformRegistrationParams(const SE_PlatformRegistrationParams & params)181 port::Status ValidateSEPlatformRegistrationParams(
182     const SE_PlatformRegistrationParams& params) {
183   VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
184                        SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
185   VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
186   VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns);
187   return port::Status::OK();
188 }
189 #undef VALIDATE_MEMBER
190 
191 // Converts SE_EventStatus to Event::Status.
SEEventStatusToEventStatus(SE_EventStatus s)192 Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
193   switch (s) {
194     case SE_EVENT_ERROR:
195       return Event::Status::kError;
196     case SE_EVENT_PENDING:
197       return Event::Status::kPending;
198     case SE_EVENT_COMPLETE:
199       return Event::Status::kComplete;
200     default:
201       return Event::Status::kUnknown;
202   }
203 }
204 
205 // Converts DeviceMemoryBase to a C struct.
DeviceMemoryBaseToC(const DeviceMemoryBase * mem)206 SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
207   SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
208   // `opaque` field inside SP_DeviceMemoryBase is not const.
209   // Therefore, we need to cast away the constness before setting it.
210   device_memory_base.opaque = const_cast<void*>(mem->opaque());
211   device_memory_base.size = mem->size();
212   device_memory_base.payload = mem->payload();
213   return device_memory_base;
214 }
215 
DeviceMemoryBaseFromC(const SP_DeviceMemoryBase & mem)216 DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
217   DeviceMemoryBase base(mem.opaque, mem.size);
218   base.SetPayload(mem.payload);
219   return base;
220 }
221 
222 // Wrapper that allows passing std::function across C API.
223 struct HostCallbackContext {
224   std::function<port::Status()> callback;
225 };
226 
227 // This wrapper allows calling `HostCallbackContext::callback` across C API.
228 // This function matches `SE_StatusCallbackFn` signature and will be passed as
229 // `callback_fn` to `host_callback` in `SP_StreamExecutor`.
HostCallbackTrampoline(void * ctx,TF_Status * status)230 void HostCallbackTrampoline(void* ctx, TF_Status* status) {
231   HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
232   port::Status s = host_ctx->callback();
233   Set_TF_Status_from_Status(status, s);
234   delete host_ctx;
235 }
236 
237 class CStreamExecutor : public internal::StreamExecutorInterface {
238  public:
CStreamExecutor(SP_Device device,SP_DeviceFns * device_fns,SP_StreamExecutor * stream_executor,SP_Platform * platform,SP_PlatformFns * platform_fns,SP_TimerFns * timer_fns,const std::string & name,int visible_device_count)239   explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
240                            SP_StreamExecutor* stream_executor,
241                            SP_Platform* platform, SP_PlatformFns* platform_fns,
242                            SP_TimerFns* timer_fns, const std::string& name,
243                            int visible_device_count)
244       : device_(std::move(device)),
245         device_fns_(device_fns),
246         stream_executor_(stream_executor),
247         platform_(platform),
248         platform_fns_(platform_fns),
249         timer_fns_(timer_fns),
250         platform_name_(name),
251         visible_device_count_(visible_device_count) {}
252 
// Hands the stored device back to the plugin via `destroy_device`.
~CStreamExecutor() override {
  SP_Device* const owned_device = &device_;
  platform_fns_->destroy_device(platform_, owned_device);
}
256 
Init(int device_ordinal,DeviceOptions device_options)257   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
258     return port::Status::OK();
259   }
260 
Allocate(uint64 size,int64 memory_space)261   DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override {
262     SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
263     stream_executor_->allocate(&device_, size, memory_space, &mem);
264     port::Status status = ValidateSPDeviceMemoryBase(mem);
265     if (!status.ok()) {
266       LOG(ERROR) << status.error_message();
267     }
268     return DeviceMemoryBaseFromC(mem);
269   }
Allocate(uint64 size)270   DeviceMemoryBase Allocate(uint64 size) {
271     return Allocate(size, /*memory_space=*/0);
272   }
GetSubBuffer(DeviceMemoryBase * parent,uint64 offset,uint64 size)273   void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
274                      uint64 size) override {
275     LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
276   }
277 
Deallocate(DeviceMemoryBase * mem)278   void Deallocate(DeviceMemoryBase* mem) override {
279     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
280     stream_executor_->deallocate(&device_, &device_memory_base);
281   }
282 
HostMemoryAllocate(uint64 size)283   void* HostMemoryAllocate(uint64 size) override {
284     return stream_executor_->host_memory_allocate(&device_, size);
285   }
286 
HostMemoryDeallocate(void * mem)287   void HostMemoryDeallocate(void* mem) override {
288     stream_executor_->host_memory_deallocate(&device_, mem);
289   }
290 
HostMemoryRegister(void * mem,uint64 size)291   bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
HostMemoryUnregister(void * mem)292   bool HostMemoryUnregister(void* mem) override { return false; }
293 
UnifiedMemoryAllocate(uint64 size)294   void* UnifiedMemoryAllocate(uint64 size) override {
295     CHECK(stream_executor_->unified_memory_allocate);
296     return stream_executor_->unified_memory_allocate(&device_, size);
297   }
298 
UnifiedMemoryDeallocate(void * mem)299   void UnifiedMemoryDeallocate(void* mem) override {
300     CHECK(stream_executor_->unified_memory_deallocate);
301     stream_executor_->unified_memory_deallocate(&device_, mem);
302   }
303 
GetAllocatorStats()304   absl::optional<AllocatorStats> GetAllocatorStats() override {
305     SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
306     TF_Bool has_stats =
307         stream_executor_->get_allocator_stats(&device_, &c_stats);
308     if (!has_stats) {
309       return absl::nullopt;
310     }
311     port::Status status = ValidateSPAllocatorStats(c_stats);
312     if (!status.ok()) {
313       LOG(ERROR) << status.error_message();
314       return absl::nullopt;
315     }
316     ::stream_executor::AllocatorStats stats;
317     stats.num_allocs = c_stats.num_allocs;
318     stats.bytes_in_use = c_stats.bytes_in_use;
319     stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
320     stats.largest_alloc_size = c_stats.largest_alloc_size;
321     if (c_stats.has_bytes_limit) {
322       stats.bytes_limit = c_stats.bytes_limit;
323     }
324     stats.bytes_reserved = c_stats.bytes_reserved;
325     stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
326     if (c_stats.has_bytes_reservable_limit) {
327       stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
328     }
329     stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
330     return stats;
331   }
SynchronizeAllActivity()332   bool SynchronizeAllActivity() override {
333     OwnedTFStatus c_status(TF_NewStatus());
334     stream_executor_->synchronize_all_activity(&device_, c_status.get());
335     if (TF_GetCode(c_status.get()) != TF_OK) {
336       LOG(ERROR) << TF_Message(c_status.get());
337       return false;
338     }
339     return true;
340   }
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)341   port::Status SynchronousMemZero(DeviceMemoryBase* location,
342                                   uint64 size) override {
343     // TODO(annarev): figure out if we should support memzero/memset
344     // functionality by allocating on host and then copying to device.
345     return port::UnimplementedError(
346         "SynchronousMemZero is not supported by pluggable device.");
347   }
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)348   port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
349                                  uint64 size) override {
350     return port::UnimplementedError(
351         "SynchronousMemSet is not supported by pluggable device.");
352   }
SynchronousMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)353   port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
354                                  const void* host_src, uint64 size) override {
355     OwnedTFStatus c_status(TF_NewStatus());
356     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
357     stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
358                                        size, c_status.get());
359     return StatusFromTF_Status(c_status.get());
360   }
SynchronousMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)361   port::Status SynchronousMemcpy(void* host_dst,
362                                  const DeviceMemoryBase& gpu_src,
363                                  uint64 size) override {
364     OwnedTFStatus c_status(TF_NewStatus());
365     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
366     stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
367                                        size, c_status.get());
368     return StatusFromTF_Status(c_status.get());
369   }
SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)370   port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
371                                                const DeviceMemoryBase& gpu_src,
372                                                uint64 size) override {
373     OwnedTFStatus c_status(TF_NewStatus());
374     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
375     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
376     stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
377                                        &device_mem_src, size, c_status.get());
378     return StatusFromTF_Status(c_status.get());
379   }
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)380   port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
381                        uint64 size) override {
382     return port::UnimplementedError(
383         "MemZero is not supported by pluggable device.");
384   }
Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)385   port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
386                       uint64 size) override {
387     return port::UnimplementedError(
388         "Memset is not supported by pluggable device.");
389   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)390   port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
391                         uint32 pattern, uint64 size) override {
392     return port::UnimplementedError(
393         "Memset32 is not supported by pluggable device.");
394   }
Memcpy(Stream * stream,void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)395   bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
396               uint64 size) override {
397     OwnedTFStatus c_status(TF_NewStatus());
398     SP_Stream stream_handle =
399         static_cast<CStream*>(stream->implementation())->Handle();
400     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
401     stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
402                                   &device_mem_src, size, c_status.get());
403     if (TF_GetCode(c_status.get()) != TF_OK) {
404       LOG(ERROR) << TF_Message(c_status.get());
405       return false;
406     }
407     return true;
408   }
Memcpy(Stream * stream,DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)409   bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
410               uint64 size) override {
411     OwnedTFStatus c_status(TF_NewStatus());
412     SP_Stream stream_handle =
413         static_cast<CStream*>(stream->implementation())->Handle();
414     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
415     stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
416                                   host_src, size, c_status.get());
417     if (TF_GetCode(c_status.get()) != TF_OK) {
418       LOG(ERROR) << TF_Message(c_status.get());
419       return false;
420     }
421     return true;
422   }
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)423   bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
424                             const DeviceMemoryBase& gpu_src,
425                             uint64 size) override {
426     OwnedTFStatus c_status(TF_NewStatus());
427     SP_Stream stream_handle =
428         static_cast<CStream*>(stream->implementation())->Handle();
429     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
430     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
431     stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
432                                   &device_mem_src, size, c_status.get());
433     if (TF_GetCode(c_status.get()) != TF_OK) {
434       LOG(ERROR) << TF_Message(c_status.get());
435       return false;
436     }
437     return true;
438   }
HostCallback(Stream * stream,std::function<port::Status ()> callback)439   bool HostCallback(Stream* stream,
440                     std::function<port::Status()> callback) override {
441     SP_Stream stream_handle =
442         static_cast<CStream*>(stream->implementation())->Handle();
443     HostCallbackContext* ctx = new HostCallbackContext{callback};
444     return stream_executor_->host_callback(&device_, stream_handle,
445                                            &HostCallbackTrampoline, ctx);
446   }
AllocateEvent(Event * event)447   port::Status AllocateEvent(Event* event) override {
448     DCHECK(event != nullptr);
449     return static_cast<CEvent*>(event->implementation())->Create();
450   }
DeallocateEvent(Event * event)451   port::Status DeallocateEvent(Event* event) override {
452     static_cast<CEvent*>(event->implementation())->Destroy();
453     return port::Status::OK();
454   }
RecordEvent(Stream * stream,Event * event)455   port::Status RecordEvent(Stream* stream, Event* event) override {
456     SP_Stream stream_handle =
457         static_cast<CStream*>(stream->implementation())->Handle();
458     return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
459   }
WaitForEvent(Stream * stream,Event * event)460   port::Status WaitForEvent(Stream* stream, Event* event) override {
461     SP_Stream stream_handle =
462         static_cast<CStream*>(stream->implementation())->Handle();
463     SP_Event event_handle =
464         static_cast<CEvent*>(event->implementation())->Handle();
465     OwnedTFStatus c_status(TF_NewStatus());
466     stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
467                                      c_status.get());
468     port::Status s = StatusFromTF_Status(c_status.get());
469     return s;
470   }
PollForEventStatus(Event * event)471   Event::Status PollForEventStatus(Event* event) override {
472     SP_Event event_handle =
473         static_cast<CEvent*>(event->implementation())->Handle();
474     SE_EventStatus event_status =
475         stream_executor_->get_event_status(&device_, event_handle);
476     return SEEventStatusToEventStatus(event_status);
477   }
AllocateStream(Stream * stream)478   bool AllocateStream(Stream* stream) override {
479     DCHECK(stream != nullptr);
480     port::Status status =
481         static_cast<CStream*>(stream->implementation())->Create();
482     // TODO(annarev): update AllocateStream to return status instead
483     // (similar to AllocateEvent).
484     return status.ok();
485   }
DeallocateStream(Stream * stream)486   void DeallocateStream(Stream* stream) override {
487     static_cast<CStream*>(stream->implementation())->Destroy();
488   }
CreateStreamDependency(Stream * dependent,Stream * other)489   bool CreateStreamDependency(Stream* dependent, Stream* other) override {
490     OwnedTFStatus c_status(TF_NewStatus());
491     SP_Stream dependent_handle =
492         static_cast<CStream*>(dependent->implementation())->Handle();
493     SP_Stream other_handle =
494         static_cast<CStream*>(other->implementation())->Handle();
495     stream_executor_->create_stream_dependency(&device_, dependent_handle,
496                                                other_handle, c_status.get());
497     if (TF_GetCode(c_status.get()) != TF_OK) {
498       LOG(ERROR) << TF_Message(c_status.get());
499       return false;
500     }
501     return true;
502   }
AllocateTimer(Timer * timer)503   bool AllocateTimer(Timer* timer) override {
504     port::Status status =
505         static_cast<CTimer*>(timer->implementation())->Create();
506     // TODO(annarev): change return value of AllocateTimer
507     // to status (similar to AllocateEvent).
508     return status.ok();
509   }
DeallocateTimer(Timer * timer)510   void DeallocateTimer(Timer* timer) override {
511     static_cast<CTimer*>(timer->implementation())->Destroy();
512   }
StartTimer(Stream * stream,Timer * timer)513   bool StartTimer(Stream* stream, Timer* timer) override {
514     OwnedTFStatus c_status(TF_NewStatus());
515     SP_Stream stream_handle =
516         static_cast<CStream*>(stream->implementation())->Handle();
517     SP_Timer timer_handle =
518         static_cast<CTimer*>(timer->implementation())->Handle();
519     stream_executor_->start_timer(&device_, stream_handle, timer_handle,
520                                   c_status.get());
521     if (TF_GetCode(c_status.get()) != TF_OK) {
522       LOG(ERROR) << TF_Message(c_status.get());
523       return false;
524     }
525     return true;
526   }
StopTimer(Stream * stream,Timer * timer)527   bool StopTimer(Stream* stream, Timer* timer) override {
528     OwnedTFStatus c_status(TF_NewStatus());
529     SP_Stream stream_handle =
530         static_cast<CStream*>(stream->implementation())->Handle();
531     SP_Timer timer_handle =
532         static_cast<CTimer*>(timer->implementation())->Handle();
533     stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
534                                  c_status.get());
535     if (TF_GetCode(c_status.get()) != TF_OK) {
536       LOG(ERROR) << TF_Message(c_status.get());
537       return false;
538     }
539     return true;
540   }
BlockHostForEvent(Stream * stream,Event * event)541   port::Status BlockHostForEvent(Stream* stream, Event* event) {
542     OwnedTFStatus c_status(TF_NewStatus());
543     SP_Event event_handle =
544         static_cast<CEvent*>(event->implementation())->Handle();
545     stream_executor_->block_host_for_event(&device_, event_handle,
546                                            c_status.get());
547     return StatusFromTF_Status(c_status.get());
548   }
549 
BlockHostUntilDone(Stream * stream)550   port::Status BlockHostUntilDone(Stream* stream) override {
551     OwnedTFStatus c_status(TF_NewStatus());
552     SP_Stream stream_handle =
553         static_cast<CStream*>(stream->implementation())->Handle();
554 
555     // If `block_host_until_done` is set, use it.
556     if (stream_executor_->block_host_until_done != nullptr) {
557       stream_executor_->block_host_until_done(&device_, stream_handle,
558                                               c_status.get());
559       return StatusFromTF_Status(c_status.get());
560     }
561     // Create and record an event and then wait for it.
562     SP_Event event_handle;
563     stream_executor_->create_event(&device_, &event_handle, c_status.get());
564     TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
565     stream_executor_->record_event(&device_, stream_handle, event_handle,
566                                    c_status.get());
567     port::Status s = StatusFromTF_Status(c_status.get());
568     if (!s.ok()) {
569       stream_executor_->destroy_event(&device_, event_handle);
570       return s;
571     }
572     stream_executor_->block_host_for_event(&device_, event_handle,
573                                            c_status.get());
574     stream_executor_->destroy_event(&device_, event_handle);
575     return StatusFromTF_Status(c_status.get());
576   }
577 
GetStatus(Stream * stream)578   port::Status GetStatus(Stream* stream) override {
579     OwnedTFStatus c_status(TF_NewStatus());
580     SP_Stream stream_handle =
581         static_cast<CStream*>(stream->implementation())->Handle();
582     stream_executor_->get_stream_status(&device_, stream_handle,
583                                         c_status.get());
584     return StatusFromTF_Status(c_status.get());
585   }
PlatformDeviceCount()586   int PlatformDeviceCount() override { return visible_device_count_; }
EnablePeerAccessTo(StreamExecutorInterface * other)587   port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
588     return port::UnimplementedError(
589         "EnablePeerAccessTo is not supported by pluggable device.");
590   }
CanEnablePeerAccessTo(StreamExecutorInterface * other)591   bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
592     return false;
593   }
594 
DeviceMemoryUsage(int64 * free,int64 * total) const595   bool DeviceMemoryUsage(int64* free, int64* total) const override {
596     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
597                   "64-bit int types should match in size");
598     return stream_executor_->device_memory_usage(
599         &device_, reinterpret_cast<int64_t*>(free),
600         reinterpret_cast<int64_t*>(total));
601   }
602 
603   // Creates a new DeviceDescription object.
604   // Ownership is transferred to the caller.
CreateDeviceDescription() const605   port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
606       const override {
607     OwnedTFStatus c_status(TF_NewStatus());
608 
609     internal::DeviceDescriptionBuilder builder;
610     if (device_.hardware_name != nullptr) {
611       builder.set_name(device_.hardware_name);
612     }
613     if (device_.device_vendor != nullptr) {
614       builder.set_device_vendor(device_.device_vendor);
615     }
616     if (device_.pci_bus_id != nullptr) {
617       builder.set_pci_bus_id(device_.pci_bus_id);
618     }
619 
620     if (device_fns_->get_numa_node != nullptr) {
621       int32_t numa_node = device_fns_->get_numa_node(&device_);
622       if (numa_node >= 0) {
623         builder.set_numa_node(numa_node);
624       }
625     }
626 
627     if (device_fns_->get_memory_bandwidth != nullptr) {
628       int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
629       if (memory_bandwidth >= 0) {
630         builder.set_memory_bandwidth(memory_bandwidth);
631       }
632     }
633     // TODO(annarev): Add gflops field in DeviceDescription and set it here.
634     // TODO(annarev): Perhaps add `supports_unified_memory` in
635     // DeviceDescription.
636     return builder.Build();
637   }
638 
639   // Each call creates a new instance of the platform-specific implementation of
640   // the corresponding interface type.
CreateEventImplementation()641   std::unique_ptr<internal::EventInterface> CreateEventImplementation()
642       override {
643     return std::unique_ptr<internal::EventInterface>(
644         new CEvent(&device_, stream_executor_));
645   }
CreateKernelImplementation()646   std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
647       override {
648     LOG(FATAL)
649         << "CreateKernelImplementation is not supported by pluggable device.";
650   }
GetStreamImplementation()651   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
652       override {
653     return std::unique_ptr<internal::StreamInterface>(
654         new CStream(&device_, stream_executor_));
655   }
GetTimerImplementation()656   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
657     return std::unique_ptr<internal::TimerInterface>(
658         new CTimer(&device_, stream_executor_, timer_fns_));
659   }
660 
661  private:
662   SP_Device device_;
663   SP_DeviceFns* device_fns_;
664   SP_StreamExecutor* stream_executor_;
665   SP_Platform* platform_;
666   SP_PlatformFns* platform_fns_;
667   SP_TimerFns* timer_fns_;
668   std::string platform_name_;
669   int visible_device_count_;
670 };
671 }  // namespace
672 
CPlatform(SP_Platform platform,void (* destroy_platform)(SP_Platform *),SP_PlatformFns platform_fns,void (* destroy_platform_fns)(SP_PlatformFns *),SP_DeviceFns device_fns,SP_StreamExecutor stream_executor,SP_TimerFns timer_fns)673 CPlatform::CPlatform(SP_Platform platform,
674                      void (*destroy_platform)(SP_Platform*),
675                      SP_PlatformFns platform_fns,
676                      void (*destroy_platform_fns)(SP_PlatformFns*),
677                      SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
678                      SP_TimerFns timer_fns)
679     : platform_(std::move(platform)),
680       destroy_platform_(destroy_platform),
681       platform_fns_(std::move(platform_fns)),
682       destroy_platform_fns_(destroy_platform_fns),
683       device_fns_(std::move(device_fns)),
684       stream_executor_(std::move(stream_executor)),
685       timer_fns_(std::move(timer_fns)),
686       name_(platform.name) {}
687 
~CPlatform()688 CPlatform::~CPlatform() {
689   executor_cache_.DestroyAllExecutors();
690   platform_fns_.destroy_device_fns(&platform_, &device_fns_);
691   platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
692   platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
693   destroy_platform_(&platform_);
694   destroy_platform_fns_(&platform_fns_);
695 }
696 
697 port::StatusOr<std::unique_ptr<DeviceDescription>>
DescriptionForDevice(int ordinal) const698 CPlatform::DescriptionForDevice(int ordinal) const {
699   // TODO(annarev): see if we can get StreamExecutor instance
700   // and call GetDeviceDescription. executor_cache_.Get would need
701   // to be made const for it to work.
702   internal::DeviceDescriptionBuilder builder;
703   builder.set_name(name_);
704   return builder.Build();
705 }
ExecutorForDevice(int ordinal)706 port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
707   stream_executor::StreamExecutorConfig config;
708   config.ordinal = ordinal;
709   return GetExecutor(config);
710 }
ExecutorForDeviceWithPluginConfig(int ordinal,const PluginConfig & plugin_config)711 port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
712     int ordinal, const PluginConfig& plugin_config) {
713   StreamExecutorConfig config;
714   config.ordinal = ordinal;
715   config.plugin_config = plugin_config;
716   return GetExecutor(config);
717 }
GetExecutor(const StreamExecutorConfig & config)718 port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
719     const StreamExecutorConfig& config) {
720   return executor_cache_.GetOrCreate(
721       config, [&]() { return GetUncachedExecutor(config); });
722 }
GetUncachedExecutor(const StreamExecutorConfig & config)723 port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
724     const StreamExecutorConfig& config) {
725   // Fill device creation params
726   SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
727   SP_Device device{SP_DEVICE_STRUCT_SIZE};
728   device_params.device = &device;
729   device_params.ext = nullptr;
730   device_params.ordinal = config.ordinal;
731   OwnedTFStatus c_status(TF_NewStatus());
732 
733   // Create Device
734   platform_fns_.create_device(&platform_, &device_params, c_status.get());
735   TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
736   TF_RETURN_IF_ERROR(ValidateSPDevice(device));
737 
738   auto executor = absl::make_unique<CStreamExecutor>(
739       std::move(device), &device_fns_, &stream_executor_, &platform_,
740       &platform_fns_, &timer_fns_, name_, platform_.visible_device_count);
741   auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
742                                                   config.ordinal);
743   return result;
744 }
745 
InitStreamExecutorPlugin(void * dso_handle)746 port::Status InitStreamExecutorPlugin(void* dso_handle) {
747   tensorflow::Env* env = tensorflow::Env::Default();
748 
749   // Step 1: Load symbol for `TF_InitPlugin`
750   void* dso_symbol;
751   TF_RETURN_IF_ERROR(
752       env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
753 
754   // Step 2: Call `TF_InitPlugin`
755   auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
756   return InitStreamExecutorPlugin(init_fn);
757 }
758 
InitStreamExecutorPlugin(SEInitPluginFn init_fn)759 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
760   SE_PlatformRegistrationParams params{
761       SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
762   SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
763   SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
764   params.major_version = SE_MAJOR;
765   params.minor_version = SE_MINOR;
766   params.patch_version = SE_PATCH;
767   params.platform = &platform;
768   params.platform_fns = &platform_fns;
769 
770   OwnedTFStatus c_status(TF_NewStatus());
771   init_fn(&params, c_status.get());
772   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
773   TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
774   TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
775   TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
776 
777   // Fill SP_DeviceFns creation params
778   SE_CreateDeviceFnsParams device_fns_params{
779       SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
780   SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
781   device_fns_params.device_fns = &device_fns;
782 
783   // Create StreamExecutor
784   platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
785   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
786   TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));
787 
788   // Fill stream executor creation params
789   SE_CreateStreamExecutorParams se_params{
790       SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
791   SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
792   se_params.stream_executor = &se;
793 
794   // Create StreamExecutor
795   platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
796   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
797   TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));
798 
799   SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
800   platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
801   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
802   TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
803 
804   // Register new platform
805   std::string platform_name = std::string(platform.name);
806   std::unique_ptr<stream_executor::CPlatform> cplatform(
807       new stream_executor::CPlatform(
808           std::move(platform), params.destroy_platform, std::move(platform_fns),
809           params.destroy_platform_fns, std::move(device_fns), std::move(se),
810           std::move(timer_fns)));
811   SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
812       std::move(cplatform)));
813 
814   // TODO(annarev): Add pluggable device registration here.
815   return port::Status::OK();
816 }
817 }  // namespace stream_executor
818