1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Classes and utilities that work with StreamExecutor C API for internal use. 16 // This includes functions used for device registration and interfaces needed 17 // for testing. 18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 20 21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h" 22 #include "tensorflow/c/tf_status_helper.h" 23 #include "tensorflow/stream_executor/executor_cache.h" 24 #include "tensorflow/stream_executor/lib/status.h" 25 #include "tensorflow/stream_executor/platform.h" 26 27 namespace stream_executor { 28 29 // Plugin initialization function that a device plugin 30 // must define. 31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, 32 TF_Status* const); 33 34 // Registers StreamExecutor platform. `device_type` and `platform_name` are 35 // output parameters. 36 port::Status InitStreamExecutorPlugin(void* dso_handle, 37 std::string* device_type, 38 std::string* platform_name); 39 40 // Allow registering a StreamExecutor plugin using a function (used for 41 // testing). 42 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, 43 std::string* device_type, 44 std::string* platform_name); 45 46 struct TFStatusDeleter { operatorTFStatusDeleter47 void operator()(TF_Status* s) const { TF_DeleteStatus(s); } 48 }; 49 50 // This file implements core stream executor base classes in terms of 51 // the C API defined in stream_executor.h. A class "CSomething" represents a 52 // "Something" that can be manipulated via calls in the C interface. 53 class CPlatform : public Platform { 54 public: 55 explicit CPlatform(SP_Platform platform, 56 void (*destroy_platform)(SP_Platform*), 57 SP_PlatformFns platform_fns, 58 void (*destroy_platform_fns)(SP_PlatformFns*), 59 SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, 60 SP_TimerFns timer_fns); 61 ~CPlatform() override; 62 id()63 Id id() const override { return const_cast<int*>(&plugin_id_value_); } Name()64 const std::string& Name() const override { return name_; } VisibleDeviceCount()65 int VisibleDeviceCount() const override { 66 int visible_device_count = 0; 67 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 68 platform_fns_.get_device_count(&platform_, &visible_device_count, 69 c_status.get()); 70 if (TF_GetCode(c_status.get()) != TF_OK) { 71 LOG(ERROR) << TF_Message(c_status.get()); 72 return 0; 73 } 74 return visible_device_count; 75 } UseBfcAllocator()76 bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } 77 port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice( 78 int ordinal) const override; 79 port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override; 80 port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig( 81 int ordinal, const PluginConfig& plugin_config) override; 82 port::StatusOr<StreamExecutor*> GetExecutor( 83 const StreamExecutorConfig& config) override; 84 port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor( 85 const StreamExecutorConfig& config) override; 86 87 // Trace listener is not supported RegisterTraceListener(std::unique_ptr<TraceListener> listener)88 void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override { 89 LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device"; 90 } UnregisterTraceListener(TraceListener * listener)91 void UnregisterTraceListener(TraceListener* listener) override {} 92 DestroyAllExecutors()93 void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } 94 95 private: 96 SP_Platform platform_; 97 void (*destroy_platform_)(SP_Platform*); 98 SP_PlatformFns platform_fns_; 99 void (*destroy_platform_fns_)(SP_PlatformFns*); 100 SP_DeviceFns device_fns_; 101 SP_StreamExecutor stream_executor_; 102 SP_TimerFns timer_fns_; 103 const std::string name_; 104 int plugin_id_value_; 105 stream_executor::ExecutorCache executor_cache_; 106 }; 107 108 class CStream : public internal::StreamInterface { 109 public: CStream(SP_Device * device,SP_StreamExecutor * stream_executor)110 CStream(SP_Device* device, SP_StreamExecutor* stream_executor) 111 : device_(device), 112 stream_executor_(stream_executor), 113 stream_handle_(nullptr) {} ~CStream()114 ~CStream() override { Destroy(); } 115 Create()116 port::Status Create() { 117 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 118 stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); 119 port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); 120 return s; 121 } 122 Destroy()123 void Destroy() { 124 if (stream_handle_ != nullptr) { 125 stream_executor_->destroy_stream(device_, stream_handle_); 126 stream_handle_ = nullptr; 127 } 128 } 129 Handle()130 SP_Stream Handle() { return stream_handle_; } 131 132 private: 133 SP_Device* device_; 134 SP_StreamExecutor* stream_executor_; 135 SP_Stream stream_handle_; 136 }; 137 138 class CEvent : public internal::EventInterface { 139 public: CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)140 CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) 141 : device_(device), 142 stream_executor_(stream_executor), 143 event_handle_(nullptr) {} ~CEvent()144 ~CEvent() override { Destroy(); } 145 Create()146 port::Status Create() { 147 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 148 stream_executor_->create_event(device_, &event_handle_, c_status.get()); 149 return tensorflow::StatusFromTF_Status(c_status.get()); 150 } 151 Record(SP_Stream stream_handle)152 port::Status Record(SP_Stream stream_handle) { 153 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 154 stream_executor_->record_event(device_, stream_handle, event_handle_, 155 c_status.get()); 156 return tensorflow::StatusFromTF_Status(c_status.get()); 157 } 158 Destroy()159 void Destroy() { 160 if (event_handle_ != nullptr) { 161 stream_executor_->destroy_event(device_, event_handle_); 162 event_handle_ = nullptr; 163 } 164 } 165 Handle()166 SP_Event Handle() { return event_handle_; } 167 168 private: 169 SP_Device* device_; 170 SP_StreamExecutor* stream_executor_; 171 SP_Event event_handle_; 172 }; 173 174 class CTimer : public internal::TimerInterface { 175 public: CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)176 CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, 177 SP_TimerFns* timer_fns) 178 : device_(device), 179 stream_executor_(stream_executor), 180 timer_handle_(nullptr), 181 timer_fns_(timer_fns) {} ~CTimer()182 ~CTimer() override { Destroy(); } 183 Create()184 port::Status Create() { 185 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 186 stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); 187 return tensorflow::StatusFromTF_Status(c_status.get()); 188 } 189 Destroy()190 void Destroy() { 191 if (timer_handle_ != nullptr) { 192 stream_executor_->destroy_timer(device_, timer_handle_); 193 timer_handle_ = nullptr; 194 } 195 } 196 Handle()197 SP_Timer Handle() { return timer_handle_; } 198 Microseconds()199 uint64 Microseconds() const override { 200 return timer_fns_->nanoseconds(timer_handle_) / 1000; 201 } 202 Nanoseconds()203 uint64 Nanoseconds() const override { 204 return timer_fns_->nanoseconds(timer_handle_); 205 } 206 207 private: 208 SP_Device* device_; 209 SP_StreamExecutor* stream_executor_; 210 SP_Timer timer_handle_; 211 SP_TimerFns* timer_fns_; 212 }; 213 214 } // namespace stream_executor 215 #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 216